aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--Pipfile2
-rw-r--r--Pipfile.lock25
-rw-r--r--bot/__main__.py16
-rw-r--r--bot/decorators.py99
-rw-r--r--bot/errors.py20
-rw-r--r--bot/exts/backend/alias.py87
-rw-r--r--bot/exts/backend/error_handler.py3
-rw-r--r--bot/exts/backend/sync/_syncers.py98
-rw-r--r--bot/exts/fun/duck_pond.py13
-rw-r--r--bot/exts/fun/off_topic_names.py5
-rw-r--r--bot/exts/help_channels.py2
-rw-r--r--bot/exts/info/doc.py2
-rw-r--r--bot/exts/info/help.py2
-rw-r--r--bot/exts/info/information.py51
-rw-r--r--bot/exts/info/reddit.py8
-rw-r--r--bot/exts/info/site.py21
-rw-r--r--bot/exts/info/source.py24
-rw-r--r--bot/exts/info/stats.py37
-rw-r--r--bot/exts/info/tags.py2
-rw-r--r--bot/exts/moderation/dm_relay.py6
-rw-r--r--bot/exts/moderation/infraction/infractions.py32
-rw-r--r--bot/exts/moderation/infraction/management.py4
-rw-r--r--bot/exts/moderation/infraction/superstarify.py5
-rw-r--r--bot/exts/moderation/modlog.py2
-rw-r--r--bot/exts/moderation/verification.py47
-rw-r--r--bot/exts/utils/bot.py2
-rw-r--r--bot/exts/utils/internal.py (renamed from bot/exts/utils/eval.py)42
-rw-r--r--bot/exts/utils/ping.py2
-rw-r--r--bot/exts/utils/reminders.py61
-rw-r--r--bot/exts/utils/snekbox.py4
-rw-r--r--bot/exts/utils/utils.py4
-rw-r--r--bot/patches/__init__.py6
-rw-r--r--bot/patches/message_edited_at.py32
-rw-r--r--bot/utils/function.py75
-rw-r--r--bot/utils/lock.py114
-rw-r--r--bot/utils/messages.py37
-rw-r--r--tests/bot/exts/backend/sync/test_cog.py4
-rw-r--r--tests/bot/exts/backend/sync/test_users.py120
-rw-r--r--tests/bot/exts/info/test_information.py175
-rw-r--r--tests/bot/exts/moderation/test_silence.py4
-rw-r--r--tests/bot/exts/utils/test_snekbox.py14
-rw-r--r--tests/bot/patches/__init__.py0
43 files changed, 695 insertions, 615 deletions
diff --git a/.gitignore b/.gitignore
index fb3156ab1..2074887ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,6 +110,7 @@ ENV/
# Logfiles
log.*
+*.log.*
# Custom user configuration
config.yml
diff --git a/Pipfile b/Pipfile
index e6f84d911..99fc70b46 100644
--- a/Pipfile
+++ b/Pipfile
@@ -14,7 +14,7 @@ beautifulsoup4 = "~=4.9"
colorama = {version = "~=0.4.3",sys_platform = "== 'win32'"}
coloredlogs = "~=14.0"
deepdiff = "~=4.0"
-discord.py = "~=1.4.0"
+"discord.py" = "~=1.5.0"
feedparser = "~=5.2"
fuzzywuzzy = "~=0.17"
lxml = "~=4.4"
diff --git a/Pipfile.lock b/Pipfile.lock
index 4c63277de..becd85c55 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
- "sha256": "644012a1c3fa3e3a30f8b8f8e672c468dfaa155d9e43d26e2be8713c8dc5ebb3"
+ "sha256": "073fd0c51749aafa188fdbe96c5b90dd157cb1d23bdd144801fb0d0a369ffa88"
},
"pipfile-spec": 6,
"requires": {
@@ -18,11 +18,11 @@
"default": {
"aio-pika": {
"hashes": [
- "sha256:4a20d4d941e1f113a950ea529a90bd9159c8d7aafaa1c71e9c707c8c2b526ea6",
- "sha256:7bf3f183df1eb348d007210a0c1a3c5c755f1b3def1a9a395e93f30b91da1daf"
+ "sha256:9773440a89840941ac3099a7720bf9d51e8764a484066b82ede4d395660ff430",
+ "sha256:a8065be3c722eb8f9fff8c0e7590729e7782202cdb9363d9830d7d5d47b45c7c"
],
"index": "pypi",
- "version": "==6.7.0"
+ "version": "==6.7.1"
},
"aiodns": {
"hashes": [
@@ -205,22 +205,13 @@
"index": "pypi",
"version": "==4.3.2"
},
- "discord": {
- "hashes": [
- "sha256:9d4debb4a37845543bd4b92cb195bc53a302797333e768e70344222857ff1559",
- "sha256:ff6653655e342e7721dfb3f10421345fd852c2a33f2cca912b1c39b3778a9429"
- ],
- "index": "pypi",
- "py": "~=1.4.0",
- "version": "==1.0.1"
- },
"discord.py": {
"hashes": [
- "sha256:98ea3096a3585c9c379209926f530808f5fcf4930928d8cfb579d2562d119570",
- "sha256:f9decb3bfa94613d922376288617e6a6f969260923643e2897f4540c34793442"
+ "sha256:3acb61fde0d862ed346a191d69c46021e6063673f63963bc984ae09a685ab211",
+ "sha256:e71089886aa157341644bdecad63a72ff56b44406b1a6467b66db31c8e5a5a15"
],
- "markers": "python_full_version >= '3.5.3'",
- "version": "==1.4.1"
+ "index": "pypi",
+ "version": "==1.5.0"
},
"docutils": {
"hashes": [
diff --git a/bot/__main__.py b/bot/__main__.py
index a07bc21d6..da042a5ed 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -9,7 +9,7 @@ from sentry_sdk.integrations.aiohttp import AioHttpIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
from sentry_sdk.integrations.redis import RedisIntegration
-from bot import constants, patches
+from bot import constants
from bot.bot import Bot
from bot.utils.extensions import EXTENSIONS
@@ -47,6 +47,13 @@ loop.run_until_complete(redis_session.connect())
# Instantiate the bot.
allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES]
+intents = discord.Intents().all()
+intents.presences = False
+intents.dm_typing = False
+intents.dm_reactions = False
+intents.invites = False
+intents.webhooks = False
+intents.integrations = False
bot = Bot(
redis_session=redis_session,
loop=loop,
@@ -54,7 +61,8 @@ bot = Bot(
activity=discord.Game(name="Commands: !help"),
case_insensitive=True,
max_messages=10_000,
- allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles)
+ allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles),
+ intents=intents,
)
# Load extensions.
@@ -65,8 +73,4 @@ if not constants.HelpChannels.enable:
for extension in extensions:
bot.load_extension(extension)
-# Apply `message_edited_at` patch if discord.py did not yet release a bug fix.
-if not hasattr(discord.message.Message, '_handle_edited_timestamp'):
- patches.message_edited_at.apply_patch()
-
bot.run(constants.Bot.token)
diff --git a/bot/decorators.py b/bot/decorators.py
index 2518124da..063c8f878 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,16 +1,15 @@
+import asyncio
import logging
-import random
-from asyncio import Lock, create_task, sleep
+import typing as t
from contextlib import suppress
from functools import wraps
-from typing import Callable, Container, Optional, Union
-from weakref import WeakValueDictionary
-from discord import Colour, Embed, Member, NotFound
+from discord import Member, NotFound
from discord.ext import commands
from discord.ext.commands import Cog, Context
-from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
+from bot.constants import Channels, RedirectOutput
+from bot.utils import function
from bot.utils.checks import in_whitelist_check
log = logging.getLogger(__name__)
@@ -18,12 +17,12 @@ log = logging.getLogger(__name__)
def in_whitelist(
*,
- channels: Container[int] = (),
- categories: Container[int] = (),
- roles: Container[int] = (),
- redirect: Optional[int] = Channels.bot_commands,
+ channels: t.Container[int] = (),
+ categories: t.Container[int] = (),
+ roles: t.Container[int] = (),
+ redirect: t.Optional[int] = Channels.bot_commands,
fail_silently: bool = False,
-) -> Callable:
+) -> t.Callable:
"""
Check if a command was issued in a whitelisted context.
@@ -31,7 +30,7 @@ def in_whitelist(
- `channels`: a container with channel ids for whitelisted channels
- `categories`: a container with category ids for whitelisted categories
- - `roles`: a container with with role ids for whitelisted roles
+ - `roles`: a container with role ids for whitelisted roles
If the command was invoked in a context that was not whitelisted, the member is either
redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
@@ -44,7 +43,7 @@ def in_whitelist(
return commands.check(predicate)
-def has_no_roles(*roles: Union[str, int]) -> Callable:
+def has_no_roles(*roles: t.Union[str, int]) -> t.Callable:
"""
Returns True if the user does not have any of the roles specified.
@@ -63,39 +62,7 @@ def has_no_roles(*roles: Union[str, int]) -> Callable:
return commands.check(predicate)
-def locked() -> Callable:
- """
- Allows the user to only run one instance of the decorated command at a time.
-
- Subsequent calls to the command from the same author are ignored until the command has completed invocation.
-
- This decorator must go before (below) the `command` decorator.
- """
- def wrap(func: Callable) -> Callable:
- func.__locks = WeakValueDictionary()
-
- @wraps(func)
- async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- lock = func.__locks.setdefault(ctx.author.id, Lock())
- if lock.locked():
- embed = Embed()
- embed.colour = Colour.red()
-
- log.debug("User tried to invoke a locked command.")
- embed.description = (
- "You're already using this command. Please wait until it is done before you use it again."
- )
- embed.title = random.choice(ERROR_REPLIES)
- await ctx.send(embed=embed)
- return
-
- async with func.__locks.setdefault(ctx.author.id, Lock()):
- await func(self, ctx, *args, **kwargs)
- return inner
- return wrap
-
-
-def redirect_output(destination_channel: int, bypass_roles: Container[int] = None) -> Callable:
+def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable:
"""
Changes the channel in the context of the command to redirect the output to a certain channel.
@@ -103,7 +70,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def wrap(func: t.Callable) -> t.Callable:
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
if ctx.channel.id == destination_channel:
@@ -122,14 +89,14 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
log.trace(f"Redirecting output of {ctx.author}'s command '{ctx.command.name}' to {redirect_channel.name}")
ctx.channel = redirect_channel
await ctx.channel.send(f"Here's the output of your command, {ctx.author.mention}")
- create_task(func(self, ctx, *args, **kwargs))
+ asyncio.create_task(func(self, ctx, *args, **kwargs))
message = await old_channel.send(
f"Hey, {ctx.author.mention}, you can find the output of your command here: "
f"{redirect_channel.mention}"
)
if RedirectOutput.delete_invocation:
- await sleep(RedirectOutput.delete_delay)
+ await asyncio.sleep(RedirectOutput.delete_delay)
with suppress(NotFound):
await message.delete()
@@ -143,38 +110,35 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
return wrap
-def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
+def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable:
"""
Ensure the highest role of the invoking member is greater than that of the target member.
If the condition fails, a warning is sent to the invoking context. A target which is not an
instance of discord.Member will always pass.
- A value of 0 (i.e. position 0) for `target_arg` corresponds to the argument which comes after
- `ctx`. If the target argument is a kwarg, its name can instead be given.
+ `member_arg` is the keyword name or position index of the parameter of the decorated command
+ whose value is the target member.
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def decorator(func: t.Callable) -> t.Callable:
@wraps(func)
- async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- try:
- target = kwargs[target_arg]
- except KeyError:
- try:
- target = args[target_arg]
- except IndexError:
- raise ValueError(f"Could not find target argument at position {target_arg}")
- except TypeError:
- raise ValueError(f"Could not find target kwarg with key {target_arg!r}")
+ async def wrapper(*args, **kwargs) -> None:
+ log.trace(f"{func.__name__}: respect role hierarchy decorator called")
+
+ bound_args = function.get_bound_args(func, args, kwargs)
+ target = function.get_arg_value(member_arg, bound_args)
if not isinstance(target, Member):
log.trace("The target is not a discord.Member; skipping role hierarchy check.")
- await func(self, ctx, *args, **kwargs)
+ await func(*args, **kwargs)
return
+ ctx = function.get_arg_value(1, bound_args)
cmd = ctx.command.name
actor = ctx.author
+
if target.top_role >= actor.top_role:
log.info(
f"{actor} ({actor.id}) attempted to {cmd} "
@@ -185,6 +149,7 @@ def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
"someone with an equal or higher top role."
)
else:
- await func(self, ctx, *args, **kwargs)
- return inner
- return wrap
+ log.trace(f"{func.__name__}: {target.top_role=} < {actor.top_role=}; calling func")
+ await func(*args, **kwargs)
+ return wrapper
+ return decorator
diff --git a/bot/errors.py b/bot/errors.py
new file mode 100644
index 000000000..65d715203
--- /dev/null
+++ b/bot/errors.py
@@ -0,0 +1,20 @@
+from typing import Hashable
+
+
+class LockedResourceError(RuntimeError):
+ """
+ Exception raised when an operation is attempted on a locked resource.
+
+ Attributes:
+ `type` -- name of the locked resource's type
+ `id` -- ID of the locked resource
+ """
+
+ def __init__(self, resource_type: str, resource_id: Hashable):
+ self.type = resource_type
+ self.id = resource_id
+
+ super().__init__(
+ f"Cannot operate on {self.type.lower()} `{self.id}`; "
+ "it is currently locked and in use by another operation."
+ )
diff --git a/bot/exts/backend/alias.py b/bot/exts/backend/alias.py
deleted file mode 100644
index c6ba8d6f3..000000000
--- a/bot/exts/backend/alias.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import inspect
-import logging
-
-from discord import Colour, Embed
-from discord.ext.commands import (
- Cog, Command, Context,
- clean_content, command, group,
-)
-
-from bot.bot import Bot
-from bot.converters import TagNameConverter
-from bot.pagination import LinePaginator
-
-log = logging.getLogger(__name__)
-
-
-class Alias (Cog):
- """Aliases for commonly used commands."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
-
- async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None:
- """Invokes a command with args and kwargs."""
- log.debug(f"{cmd_name} was invoked through an alias")
- cmd = self.bot.get_command(cmd_name)
- if not cmd:
- return log.info(f'Did not find command "{cmd_name}" to invoke.')
- elif not await cmd.can_run(ctx):
- return log.info(
- f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.'
- )
-
- await ctx.invoke(cmd, *args, **kwargs)
-
- @command(name='aliases')
- async def aliases_command(self, ctx: Context) -> None:
- """Show configured aliases on the bot."""
- embed = Embed(
- title='Configured aliases',
- colour=Colour.blue()
- )
- await LinePaginator.paginate(
- (
- f"• `{ctx.prefix}{value.name}` "
- f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`"
- for name, value in inspect.getmembers(self)
- if isinstance(value, Command) and name.endswith('_alias')
- ),
- ctx, embed, empty=False, max_lines=20
- )
-
- @command(name="exception", hidden=True)
- async def tags_get_traceback_alias(self, ctx: Context) -> None:
- """Alias for invoking <prefix>tags get traceback."""
- await self.invoke(ctx, "tags get", tag_name="traceback")
-
- @group(name="get",
- aliases=("show", "g"),
- hidden=True,
- invoke_without_command=True)
- async def get_group_alias(self, ctx: Context) -> None:
- """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`."""
- pass
-
- @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True)
- async def tags_get_alias(
- self, ctx: Context, *, tag_name: TagNameConverter = None
- ) -> None:
- """
- Alias for invoking <prefix>tags get [tag_name].
-
- tag_name: str - tag to be viewed.
- """
- await self.invoke(ctx, "tags get", tag_name=tag_name)
-
- @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True)
- async def docs_get_alias(
- self, ctx: Context, symbol: clean_content = None
- ) -> None:
- """Alias for invoking <prefix>docs get [symbol]."""
- await self.invoke(ctx, "docs get", symbol)
-
-
-def setup(bot: Bot) -> None:
- """Load the Alias cog."""
- bot.add_cog(Alias(bot))
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index f9d4de638..c643d346e 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -10,6 +10,7 @@ from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels, Colours
from bot.converters import TagNameConverter
+from bot.errors import LockedResourceError
from bot.utils.checks import InWhitelistCheckFailure
log = logging.getLogger(__name__)
@@ -75,6 +76,8 @@ class ErrorHandler(Cog):
elif isinstance(e, errors.CommandInvokeError):
if isinstance(e.original, ResponseCodeError):
await self.handle_api_error(ctx, e.original)
+ elif isinstance(e.original, LockedResourceError):
+ await ctx.send(f"{e.original} Please wait for it to finish and try again later.")
else:
await self.handle_unexpected_error(ctx, e.original)
return # Exit early to avoid logging.
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py
index 3d4a09df3..38468c2b1 100644
--- a/bot/exts/backend/sync/_syncers.py
+++ b/bot/exts/backend/sync/_syncers.py
@@ -14,7 +14,6 @@ log = logging.getLogger(__name__)
# These objects are declared as namedtuples because tuples are hashable,
# something that we make use of when diffing site roles against guild roles.
_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
-_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild'))
_Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
@@ -134,61 +133,76 @@ class UserSyncer(Syncer):
async def _get_diff(self, guild: Guild) -> _Diff:
"""Return the difference of users between the cache of `guild` and the database."""
log.trace("Getting the diff for users.")
- users = await self.bot.api_client.get('bot/users')
- # Pack DB roles and guild roles into one common, hashable format.
- # They're hashable so that they're easily comparable with sets later.
- db_users = {
- user_dict['id']: _User(
- roles=tuple(sorted(user_dict.pop('roles'))),
- **user_dict
- )
- for user_dict in users
- }
- guild_users = {
- member.id: _User(
- id=member.id,
- name=member.name,
- discriminator=int(member.discriminator),
- roles=tuple(sorted(role.id for role in member.roles)),
- in_guild=True
- )
- for member in guild.members
- }
+ users_to_create = []
+ users_to_update = []
+ seen_guild_users = set()
+
+ async for db_user in self._get_users():
+ # Store user fields which are to be updated.
+ updated_fields = {}
- users_to_create = set()
- users_to_update = set()
+ def maybe_update(db_field: str, guild_value: t.Union[str, int]) -> None:
+ # Equalize DB user and guild user attributes.
+ if db_user[db_field] != guild_value:
+ updated_fields[db_field] = guild_value
- for db_user in db_users.values():
- guild_user = guild_users.get(db_user.id)
- if guild_user is not None:
- if db_user != guild_user:
- users_to_update.add(guild_user)
+ if guild_user := guild.get_member(db_user["id"]):
+ seen_guild_users.add(guild_user.id)
- elif db_user.in_guild:
+ maybe_update("name", guild_user.name)
+ maybe_update("discriminator", int(guild_user.discriminator))
+ maybe_update("in_guild", True)
+
+ guild_roles = [role.id for role in guild_user.roles]
+ if set(db_user["roles"]) != set(guild_roles):
+ updated_fields["roles"] = guild_roles
+
+ elif db_user["in_guild"]:
# The user is known in the DB but not the guild, and the
# DB currently specifies that the user is a member of the guild.
# This means that the user has left since the last sync.
# Update the `in_guild` attribute of the user on the site
# to signify that the user left.
- new_api_user = db_user._replace(in_guild=False)
- users_to_update.add(new_api_user)
-
- new_user_ids = set(guild_users.keys()) - set(db_users.keys())
- for user_id in new_user_ids:
- # The user is known on the guild but not on the API. This means
- # that the user has joined since the last sync. Create it.
- new_user = guild_users[user_id]
- users_to_create.add(new_user)
+ updated_fields["in_guild"] = False
+
+ if updated_fields:
+ updated_fields["id"] = db_user["id"]
+ users_to_update.append(updated_fields)
+
+ for member in guild.members:
+ if member.id not in seen_guild_users:
+ # The user is known on the guild but not on the API. This means
+ # that the user has joined since the last sync. Create it.
+ new_user = {
+ "id": member.id,
+ "name": member.name,
+ "discriminator": int(member.discriminator),
+ "roles": [role.id for role in member.roles],
+ "in_guild": True
+ }
+ users_to_create.append(new_user)
return _Diff(users_to_create, users_to_update, None)
+ async def _get_users(self) -> t.AsyncIterable:
+ """GET users from database."""
+ query_params = {
+ "page": 1
+ }
+ while query_params["page"]:
+ res = await self.bot.api_client.get("bot/users", params=query_params)
+ for user in res["results"]:
+ yield user
+
+ query_params["page"] = res["next_page_no"]
+
async def _sync(self, diff: _Diff) -> None:
"""Synchronise the database with the user cache of `guild`."""
log.trace("Syncing created users...")
- for user in diff.created:
- await self.bot.api_client.post('bot/users', json=user._asdict())
+ if diff.created:
+ await self.bot.api_client.post("bot/users", json=diff.created)
log.trace("Syncing updated users...")
- for user in diff.updated:
- await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict())
+ if diff.updated:
+ await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated)
diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py
index 6c2d22b9c..82084ea88 100644
--- a/bot/exts/fun/duck_pond.py
+++ b/bot/exts/fun/duck_pond.py
@@ -145,6 +145,10 @@ class DuckPond(Cog):
amount of ducks specified in the config under duck_pond/threshold, it will
send the message off to the duck pond.
"""
+ # Ignore other guilds and DMs.
+ if payload.guild_id != constants.Guild.id:
+ return
+
# Was this reaction issued in a blacklisted channel?
if payload.channel_id in constants.DuckPond.channel_blacklist:
return
@@ -154,6 +158,9 @@ class DuckPond(Cog):
return
channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id)
+ if channel is None:
+ return
+
message = await channel.fetch_message(payload.message_id)
member = discord.utils.get(message.guild.members, id=payload.user_id)
@@ -175,7 +182,13 @@ class DuckPond(Cog):
@Cog.listener()
async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None:
"""Ensure that people don't remove the green checkmark from duck ponded messages."""
+ # Ignore other guilds and DMs.
+ if payload.guild_id != constants.Guild.id:
+ return
+
channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id)
+ if channel is None:
+ return
# Prevent the green checkmark from being removed
if payload.emoji.name == "✅":
diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py
index b9d235fa2..7fc93b88c 100644
--- a/bot/exts/fun/off_topic_names.py
+++ b/bot/exts/fun/off_topic_names.py
@@ -1,10 +1,10 @@
-import asyncio
import difflib
import logging
from datetime import datetime, timedelta
from discord import Colour, Embed
from discord.ext.commands import Cog, Context, group, has_any_role
+from discord.utils import sleep_until
from bot.api import ResponseCodeError
from bot.bot import Bot
@@ -23,8 +23,7 @@ async def update_names(bot: Bot) -> None:
# we go past midnight in the `seconds_to_sleep` set below.
today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0)
next_midnight = today_at_midnight + timedelta(days=1)
- seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1
- await asyncio.sleep(seconds_to_sleep)
+ await sleep_until(next_midnight)
try:
channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get(
diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py
index 9e33a6aba..f5c9a5dd0 100644
--- a/bot/exts/help_channels.py
+++ b/bot/exts/help_channels.py
@@ -494,7 +494,7 @@ class HelpChannels(commands.Cog):
If `options` are provided, the channel will be edited after the move is completed. This is the
same order of operations that `discord.TextChannel.edit` uses. For information on available
- options, see the documention on `discord.TextChannel.edit`. While possible, position-related
+ options, see the documentation on `discord.TextChannel.edit`. While possible, position-related
options should be avoided, as it may interfere with the category move we perform.
"""
# Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had.
diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py
index e50b9b32b..c16a99225 100644
--- a/bot/exts/info/doc.py
+++ b/bot/exts/info/doc.py
@@ -345,7 +345,7 @@ class Doc(commands.Cog):
@commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True)
async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None:
"""Lookup documentation for Python symbols."""
- await ctx.invoke(self.get_command, symbol)
+ await self.get_command(ctx, symbol)
@docs_group.command(name='get', aliases=('g',))
async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None:
diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py
index 99d503f5c..599c5d5c0 100644
--- a/bot/exts/info/help.py
+++ b/bot/exts/info/help.py
@@ -229,7 +229,7 @@ class CustomHelpCommand(HelpCommand):
async def send_cog_help(self, cog: Cog) -> None:
"""Send help for a cog."""
- # sort commands by name, and remove any the user cant run or are hidden.
+ # sort commands by name, and remove any the user can't run or are hidden.
commands_ = await self.filter_commands(cog.get_commands(), sort=True)
embed = Embed()
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 0f074c45d..1f5c513f9 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -7,10 +7,9 @@ from string import Template
from typing import Any, Mapping, Optional, Tuple, Union
import fuzzywuzzy
-from discord import ChannelType, Colour, CustomActivity, Embed, Guild, Member, Message, Role, Status
+from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils
from discord.abc import GuildChannel
from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role
-from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
@@ -160,7 +159,9 @@ class Information(Cog):
channel_counts = self.get_channel_type_counts(ctx.guild)
# How many of each user status?
- statuses = Counter(member.status for member in ctx.guild.members)
+ py_invite = await self.bot.fetch_invite(constants.Guild.invite)
+ online_presences = py_invite.approximate_presence_count
+ offline_presences = py_invite.approximate_member_count - online_presences
embed = Embed(colour=Colour.blurple())
# How many staff members and staff channels do we have?
@@ -168,9 +169,9 @@ class Information(Cog):
staff_channel_count = self.get_staff_channel_count(ctx.guild)
# Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the
- # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting
- # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts
- # after the dedent is made.
+ # f-string. While this is correctly formatted by Discord, it makes unit testing difficult. To keep the
+ # formatting without joining a tuple of strings we can use a Template string to insert the already-formatted
+ # channel_counts after the dedent is made.
embed.description = Template(
textwrap.dedent(f"""
**Server information**
@@ -188,10 +189,8 @@ class Information(Cog):
Roles: {roles}
**Member statuses**
- {constants.Emojis.status_online} {statuses[Status.online]:,}
- {constants.Emojis.status_idle} {statuses[Status.idle]:,}
- {constants.Emojis.status_dnd} {statuses[Status.dnd]:,}
- {constants.Emojis.status_offline} {statuses[Status.offline]:,}
+ {constants.Emojis.status_online} {online_presences:,}
+ {constants.Emojis.status_offline} {offline_presences:,}
""")
).substitute({"channel_counts": channel_counts})
embed.set_thumbnail(url=ctx.guild.icon_url)
@@ -218,25 +217,6 @@ class Information(Cog):
"""Creates an embed containing information on the `user`."""
created = time_since(user.created_at, max_units=3)
- # Custom status
- custom_status = ''
- for activity in user.activities:
- if isinstance(activity, CustomActivity):
- state = ""
-
- if activity.name:
- state = escape_markdown(activity.name)
-
- emoji = ""
- if activity.emoji:
- # If an emoji is unicode use the emoji, else write the emote like :abc:
- if not activity.emoji.id:
- emoji += activity.emoji.name + " "
- else:
- emoji += f"`:{activity.emoji.name}:` "
-
- custom_status = f'Status: {emoji}{state}\n'
-
name = str(user)
if user.nick:
name = f"{user.nick} ({name})"
@@ -250,10 +230,6 @@ class Information(Cog):
joined = time_since(user.joined_at, max_units=3)
roles = ", ".join(role.mention for role in user.roles[1:])
- desktop_status = STATUS_EMOTES.get(user.desktop_status, constants.Emojis.status_online)
- web_status = STATUS_EMOTES.get(user.web_status, constants.Emojis.status_online)
- mobile_status = STATUS_EMOTES.get(user.mobile_status, constants.Emojis.status_online)
-
fields = [
(
"User information",
@@ -261,7 +237,6 @@ class Information(Cog):
Created: {created}
Profile: {user.mention}
ID: {user.id}
- {custom_status}
""").strip()
),
(
@@ -271,14 +246,6 @@ class Information(Cog):
Roles: {roles or None}
""").strip()
),
- (
- "Status",
- textwrap.dedent(f"""
- {desktop_status} Desktop
- {web_status} Web
- {mobile_status} Mobile
- """).strip()
- )
]
# Use getattr to future-proof for commands invoked via DMs.
diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py
index 635162308..debe40c82 100644
--- a/bot/exts/info/reddit.py
+++ b/bot/exts/info/reddit.py
@@ -10,7 +10,7 @@ from aiohttp import BasicAuth, ClientError
from discord import Colour, Embed, TextChannel
from discord.ext.commands import Cog, Context, group, has_any_role
from discord.ext.tasks import loop
-from discord.utils import escape_markdown
+from discord.utils import escape_markdown, sleep_until
from bot.bot import Bot
from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks
@@ -203,13 +203,13 @@ class Reddit(Cog):
@loop()
async def auto_poster_loop(self) -> None:
"""Post the top 5 posts daily, and the top 5 posts weekly."""
- # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter
+ # once d.py get support for `time` parameter in loop decorator,
+ # this can be removed and the loop can use the `time=datetime.time.min` parameter
now = datetime.utcnow()
tomorrow = now + timedelta(days=1)
midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0)
- seconds_until = (midnight_tomorrow - now).total_seconds()
- await asyncio.sleep(seconds_until)
+ await sleep_until(midnight_tomorrow)
await self.bot.wait_until_guild_available()
if not self.webhook:
diff --git a/bot/exts/info/site.py b/bot/exts/info/site.py
index 2d3a3d9f3..fb5b99086 100644
--- a/bot/exts/info/site.py
+++ b/bot/exts/info/site.py
@@ -1,7 +1,7 @@
import logging
from discord import Colour, Embed
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, Greedy, group
from bot.bot import Bot
from bot.constants import URLs
@@ -105,10 +105,9 @@ class Site(Cog):
await ctx.send(embed=embed)
@site_group.command(name="rules", aliases=("r", "rule"), root_aliases=("rules", "rule"))
- async def site_rules(self, ctx: Context, *rules: int) -> None:
+ async def site_rules(self, ctx: Context, rules: Greedy[int]) -> None:
"""Provides a link to all rules or, if specified, displays specific rule(s)."""
- rules_embed = Embed(title='Rules', color=Colour.blurple())
- rules_embed.url = f"{PAGES_URL}/rules"
+ rules_embed = Embed(title='Rules', color=Colour.blurple(), url=f'{PAGES_URL}/rules')
if not rules:
# Rules were not submitted. Return the default description.
@@ -122,15 +121,13 @@ class Site(Cog):
return
full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'})
- invalid_indices = tuple(
- pick
- for pick in rules
- if pick < 1 or pick > len(full_rules)
- )
- if invalid_indices:
- indices = ', '.join(map(str, invalid_indices))
- await ctx.send(f":x: Invalid rule indices: {indices}")
+ # Remove duplicates and sort the rule indices
+ rules = sorted(set(rules))
+ invalid = ', '.join(str(index) for index in rules if index < 1 or index > len(full_rules))
+
+ if invalid:
+ await ctx.send(f":x: Invalid rule indices: {invalid}")
return
for rule in rules:
diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py
index 205e0ba81..7b41352d4 100644
--- a/bot/exts/info/source.py
+++ b/bot/exts/info/source.py
@@ -2,7 +2,7 @@ import inspect
from pathlib import Path
from typing import Optional, Tuple, Union
-from discord import Embed
+from discord import Embed, utils
from discord.ext import commands
from bot.bot import Bot
@@ -35,8 +35,10 @@ class SourceConverter(commands.Converter):
elif argument.lower() in tags_cog._cache:
return argument.lower()
+ escaped_arg = utils.escape_markdown(argument)
+
raise commands.BadArgument(
- f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog."
+ f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
)
@@ -66,14 +68,8 @@ class BotSource(commands.Cog):
Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval).
"""
if isinstance(source_item, commands.Command):
- if source_item.cog_name == "Alias":
- cmd_name = source_item.callback.__name__.replace("_alias", "")
- cmd = self.bot.get_command(cmd_name.replace("_", " "))
- src = cmd.callback.__code__
- filename = src.co_filename
- else:
- src = source_item.callback.__code__
- filename = src.co_filename
+ src = source_item.callback.__code__
+ filename = src.co_filename
elif isinstance(source_item, str):
tags_cog = self.bot.get_cog("Tags")
filename = tags_cog._cache[source_item]["location"]
@@ -113,13 +109,7 @@ class BotSource(commands.Cog):
title = "Help Command"
description = source_object.__doc__.splitlines()[1]
elif isinstance(source_object, commands.Command):
- if source_object.cog_name == "Alias":
- cmd_name = source_object.callback.__name__.replace("_alias", "")
- cmd = self.bot.get_command(cmd_name.replace("_", " "))
- description = cmd.short_doc
- else:
- description = source_object.short_doc
-
+ description = source_object.short_doc
title = f"Command: {source_object.qualified_name}"
elif isinstance(source_object, str):
title = f"Tag: {source_object}"
diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py
index d42f55466..21aa91873 100644
--- a/bot/exts/info/stats.py
+++ b/bot/exts/info/stats.py
@@ -1,12 +1,11 @@
import string
-from datetime import datetime
-from discord import Member, Message, Status
+from discord import Member, Message
from discord.ext.commands import Cog, Context
from discord.ext.tasks import loop
from bot.bot import Bot
-from bot.constants import Categories, Channels, Guild, Stats as StatConf
+from bot.constants import Categories, Channels, Guild
CHANNEL_NAME_OVERRIDES = {
@@ -79,38 +78,6 @@ class Stats(Cog):
self.bot.stats.gauge("guild.total_members", len(member.guild.members))
- @Cog.listener()
- async def on_member_update(self, _before: Member, after: Member) -> None:
- """Update presence estimates on member update."""
- if after.guild.id != Guild.id:
- return
-
- if self.last_presence_update:
- if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout:
- return
-
- self.last_presence_update = datetime.now()
-
- online = 0
- idle = 0
- dnd = 0
- offline = 0
-
- for member in after.guild.members:
- if member.status is Status.online:
- online += 1
- elif member.status is Status.dnd:
- dnd += 1
- elif member.status is Status.idle:
- idle += 1
- elif member.status is Status.offline:
- offline += 1
-
- self.bot.stats.gauge("guild.status.online", online)
- self.bot.stats.gauge("guild.status.idle", idle)
- self.bot.stats.gauge("guild.status.do_not_disturb", dnd)
- self.bot.stats.gauge("guild.status.offline", offline)
-
@loop(hours=1)
async def update_guild_boost(self) -> None:
"""Post the server boost level and tier every hour."""
diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py
index d01647312..ae95ac1ef 100644
--- a/bot/exts/info/tags.py
+++ b/bot/exts/info/tags.py
@@ -160,7 +160,7 @@ class Tags(Cog):
@group(name='tags', aliases=('tag', 't'), invoke_without_command=True)
async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None:
"""Show all known tags, a single tag, or run a subcommand."""
- await ctx.invoke(self.get_command, tag_name=tag_name)
+ await self.get_command(ctx, tag_name=tag_name)
@tags_group.group(name='search', invoke_without_command=True)
async def search_tag_content(self, ctx: Context, *, keywords: str) -> None:
diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py
index 14263e004..4d5142b55 100644
--- a/bot/exts/moderation/dm_relay.py
+++ b/bot/exts/moderation/dm_relay.py
@@ -90,7 +90,11 @@ class DMRelay(Cog):
# Handle any attachments
if message.attachments:
try:
- await send_attachments(message, self.webhook)
+ await send_attachments(
+ message,
+ self.webhook,
+ username=f"{message.author.display_name} ({message.author.id})"
+ )
except (discord.errors.Forbidden, discord.errors.NotFound):
e = discord.Embed(
description=":x: **This message contained an attachment, but it could not be retrieved**",
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index ef6f6e3c6..7cf7075e6 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -71,6 +71,23 @@ class Infractions(InfractionScheduler, commands.Cog):
"""Permanently ban a user for the given reason and stop watching them with Big Brother."""
await self.apply_ban(ctx, user, reason)
+ @command(aliases=('pban',))
+ async def purgeban(
+ self,
+ ctx: Context,
+ user: FetchedMember,
+ purge_days: t.Optional[int] = 1,
+ *,
+ reason: t.Optional[str] = None
+ ) -> None:
+ """
+ Same as ban but removes all their messages for the given number of days, default being 1.
+
+ `purge_days` can only be values between 0 and 7.
+ Anything outside these bounds are automatically adjusted to their respective limits.
+ """
+ await self.apply_ban(ctx, user, reason, max(min(purge_days, 7), 0))
+
# endregion
# region: Temporary infractions
@@ -230,7 +247,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_infraction(ctx, infraction, user, action())
- @respect_role_hierarchy()
+ @respect_role_hierarchy(member_arg=2)
async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None:
"""Apply a kick infraction with kwargs passed to `post_infraction`."""
infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs)
@@ -245,8 +262,15 @@ class Infractions(InfractionScheduler, commands.Cog):
action = user.kick(reason=reason)
await self.apply_infraction(ctx, infraction, user, action)
- @respect_role_hierarchy()
- async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None:
+ @respect_role_hierarchy(member_arg=2)
+ async def apply_ban(
+ self,
+ ctx: Context,
+ user: UserSnowflake,
+ reason: t.Optional[str],
+ purge_days: t.Optional[int] = 0,
+ **kwargs
+ ) -> None:
"""
Apply a ban infraction with kwargs passed to `post_infraction`.
@@ -278,7 +302,7 @@ class Infractions(InfractionScheduler, commands.Cog):
if reason:
reason = textwrap.shorten(reason, width=512, placeholder="...")
- action = ctx.guild.ban(user, reason=reason, delete_message_days=0)
+ action = ctx.guild.ban(user, reason=reason, delete_message_days=purge_days)
await self.apply_infraction(ctx, infraction, user, action)
if infraction.get('expires_at') is not None:
diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py
index 856a4e1a2..cdab1a6c7 100644
--- a/bot/exts/moderation/infraction/management.py
+++ b/bot/exts/moderation/infraction/management.py
@@ -179,9 +179,9 @@ class ModManagement(commands.Cog):
async def infraction_search_group(self, ctx: Context, query: t.Union[UserMention, Snowflake, str]) -> None:
"""Searches for infractions in the database."""
if isinstance(query, int):
- await ctx.invoke(self.search_user, discord.Object(query))
+ await self.search_user(ctx, discord.Object(query))
else:
- await ctx.invoke(self.search_reason, query)
+ await self.search_reason(ctx, query)
@infraction_search_group.command(name="user", aliases=("member", "id"))
async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None:
diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py
index eec63f5b3..adfe42fcd 100644
--- a/bot/exts/moderation/infraction/superstarify.py
+++ b/bot/exts/moderation/infraction/superstarify.py
@@ -135,7 +135,8 @@ class Superstarify(InfractionScheduler, Cog):
return
# Post the infraction to the API
- reason = reason or f"old nick: {member.display_name}"
+ old_nick = member.display_name
+ reason = reason or f"old nick: {old_nick}"
infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True)
id_ = infraction["id"]
@@ -148,7 +149,7 @@ class Superstarify(InfractionScheduler, Cog):
await member.edit(nick=forced_nick, reason=reason)
self.schedule_expiration(infraction)
- old_nick = escape_markdown(member.display_name)
+ old_nick = escape_markdown(old_nick)
forced_nick = escape_markdown(forced_nick)
# Send a DM to the user to notify them of their new infraction.
diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py
index 41ed46b69..b01de0ee3 100644
--- a/bot/exts/moderation/modlog.py
+++ b/bot/exts/moderation/modlog.py
@@ -63,7 +63,7 @@ class ModLog(Cog, name="ModLog"):
'id': message.id,
'author': message.author.id,
'channel_id': message.channel.id,
- 'content': message.content,
+ 'content': message.content.replace("\0", ""), # Null chars cause 400.
'embeds': [embed.to_dict() for embed in message.embeds],
'attachments': attachment,
}
diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py
index 206556483..c3ad8687e 100644
--- a/bot/exts/moderation/verification.py
+++ b/bot/exts/moderation/verification.py
@@ -53,6 +53,23 @@ If you'd like to unsubscribe from the announcement notifications, simply send `!
<#{constants.Channels.bot_commands}>.
"""
+ALTERNATE_VERIFIED_MESSAGE = f"""
+Thanks for accepting our rules!
+
+You can find a copy of our rules for reference at <https://pythondiscord.com/pages/rules>.
+
+Additionally, if you'd like to receive notifications for the announcements \
+we post in <#{constants.Channels.announcements}>
+from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \
+to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement.
+
+If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \
+<#{constants.Channels.bot_commands}>.
+
+To introduce you to our community, we've made the following video:
+https://youtu.be/ZH26PuX3re0
+"""
+
# Sent via DMs to users kicked for failing to verify
KICKED_MESSAGE = f"""
Hi! You have been automatically kicked from Python Discord as you have failed to accept our rules \
@@ -156,6 +173,9 @@ class Verification(Cog):
# ]
task_cache = RedisCache()
+ # Create a cache for storing recipients of the alternate welcome DM.
+ member_gating_cache = RedisCache()
+
def __init__(self, bot: Bot) -> None:
"""Start internal tasks."""
self.bot = bot
@@ -519,6 +539,16 @@ class Verification(Cog):
if member.guild.id != constants.Guild.id:
return # Only listen for PyDis events
+ raw_member = await self.bot.http.get_member(member.guild.id, member.id)
+
+ # If the user has the is_pending flag set, they will be using the alternate
+ # gate and will not need a welcome DM with verification instructions.
+ # We will send them an alternate DM once they verify with the welcome
+ # video.
+ if raw_member.get("is_pending"):
+ await self.member_gating_cache.set(member.id, True)
+ return
+
log.trace(f"Sending on join message to new member: {member.id}")
try:
await safe_dm(member.send(ON_JOIN_MESSAGE))
@@ -526,6 +556,23 @@ class Verification(Cog):
log.exception("DM dispatch failed on unexpected error code")
@Cog.listener()
+ async def on_member_update(self, before: discord.Member, after: discord.Member) -> None:
+ """Check if we need to send a verification DM to a gated user."""
+ before_roles = [role.id for role in before.roles]
+ after_roles = [role.id for role in after.roles]
+
+ if constants.Roles.verified not in before_roles and constants.Roles.verified in after_roles:
+ if await self.member_gating_cache.pop(after.id):
+ try:
+ # If the member has not received a DM from our !accept command
+ # and has gone through the alternate gating system we should send
+ # our alternate welcome DM which includes info such as our welcome
+ # video.
+ await safe_dm(after.send(ALTERNATE_VERIFIED_MESSAGE))
+ except discord.HTTPException:
+ log.exception("DM dispatch failed on unexpected error code")
+
+ @Cog.listener()
async def on_message(self, message: discord.Message) -> None:
"""Check new message event for messages to the checkpoint channel & process."""
if message.channel.id != constants.Channels.verification:
diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py
index 7ed487d47..ba1fd2a5c 100644
--- a/bot/exts/utils/bot.py
+++ b/bot/exts/utils/bot.py
@@ -130,7 +130,7 @@ class BotCog(Cog, name="Bot"):
else:
content = "".join(content[1:])
- # Strip it again to remove any leading whitespace. This is neccessary
+ # Strip it again to remove any leading whitespace. This is necessary
# if the first line of the message looked like ```python <code>
old = content.strip()
diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/internal.py
index 6419b320e..1b4900f42 100644
--- a/bot/exts/utils/eval.py
+++ b/bot/exts/utils/internal.py
@@ -5,6 +5,8 @@ import pprint
import re
import textwrap
import traceback
+from collections import Counter
+from datetime import datetime
from io import StringIO
from typing import Any, Optional, Tuple
@@ -19,8 +21,8 @@ from bot.utils import find_nth_occurrence, send_to_paste_service
log = logging.getLogger(__name__)
-class CodeEval(Cog):
- """Owner and admin feature that evaluates code and returns the result to the channel."""
+class Internal(Cog):
+ """Administrator and Core Developer commands."""
def __init__(self, bot: Bot):
self.bot = bot
@@ -30,6 +32,17 @@ class CodeEval(Cog):
self.interpreter = Interpreter(bot)
+ self.socket_since = datetime.utcnow()
+ self.socket_event_total = 0
+ self.socket_events = Counter()
+
+ @Cog.listener()
+ async def on_socket_response(self, msg: dict) -> None:
+ """When a websocket event is received, increase our counters."""
+ if event_type := msg.get("t"):
+ self.socket_event_total += 1
+ self.socket_events[event_type] += 1
+
def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]:
"""Format the eval output into a string & attempt to format it into an Embed."""
self._ = out
@@ -198,7 +211,7 @@ async def func(): # (None,) -> Any
await ctx.send(f"```py\n{out}```", embed=embed)
@group(name='internal', aliases=('int',))
- @has_any_role(Roles.owners, Roles.admins)
+ @has_any_role(Roles.owners, Roles.admins, Roles.core_developers)
async def internal_group(self, ctx: Context) -> None:
"""Internal commands. Top secret!"""
if not ctx.invoked_subcommand:
@@ -220,7 +233,26 @@ async def func(): # (None,) -> Any
await self._eval(ctx, code)
+ @internal_group.command(name='socketstats', aliases=('socket', 'stats'))
+ @has_any_role(Roles.admins, Roles.owners, Roles.core_developers)
+ async def socketstats(self, ctx: Context) -> None:
+ """Fetch information on the socket events received from Discord."""
+ running_s = (datetime.utcnow() - self.socket_since).total_seconds()
+
+ per_s = self.socket_event_total / running_s
+
+ stats_embed = discord.Embed(
+ title="WebSocket statistics",
+ description=f"Receiving {per_s:0.2f} event per second.",
+ color=discord.Color.blurple()
+ )
+
+ for event_type, count in self.socket_events.most_common(25):
+ stats_embed.add_field(name=event_type, value=count, inline=False)
+
+ await ctx.send(embed=stats_embed)
+
def setup(bot: Bot) -> None:
- """Load the CodeEval cog."""
- bot.add_cog(CodeEval(bot))
+ """Load the Internal cog."""
+ bot.add_cog(Internal(bot))
diff --git a/bot/exts/utils/ping.py b/bot/exts/utils/ping.py
index a9ca3dbeb..572fc934b 100644
--- a/bot/exts/utils/ping.py
+++ b/bot/exts/utils/ping.py
@@ -33,7 +33,7 @@ class Latency(commands.Cog):
"""
# datetime.datetime objects do not have the "milliseconds" attribute.
# It must be converted to seconds before converting to milliseconds.
- bot_ping = (datetime.utcnow() - ctx.message.created_at).total_seconds() / 1000
+ bot_ping = (datetime.utcnow() - ctx.message.created_at).total_seconds() * 1000
bot_ping = f"{bot_ping:.{ROUND_LATENCY}f} ms"
try:
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index 6806f2889..bf4e24661 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -16,12 +16,14 @@ from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Role
from bot.converters import Duration
from bot.pagination import LinePaginator
from bot.utils.checks import has_any_role_check, has_no_roles_check
+from bot.utils.lock import lock_arg
from bot.utils.messages import send_denial
from bot.utils.scheduling import Scheduler
from bot.utils.time import humanize_delta
log = logging.getLogger(__name__)
+NAMESPACE = "reminder" # Used for the mutually_exclusive decorator; constant to prevent typos
WHITELISTED_CHANNELS = Guild.reminder_whitelist
MAXIMUM_REMINDERS = 5
@@ -52,7 +54,7 @@ class Reminders(Cog):
now = datetime.utcnow()
for reminder in response:
- is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False)
+ is_valid, *_ = self.ensure_valid_reminder(reminder)
if not is_valid:
continue
@@ -65,11 +67,7 @@ class Reminders(Cog):
else:
self.schedule_reminder(reminder)
- def ensure_valid_reminder(
- self,
- reminder: dict,
- cancel_task: bool = True
- ) -> t.Tuple[bool, discord.User, discord.TextChannel]:
+ def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.User, discord.TextChannel]:
"""Ensure reminder author and channel can be fetched otherwise delete the reminder."""
user = self.bot.get_user(reminder['author'])
channel = self.bot.get_channel(reminder['channel_id'])
@@ -80,7 +78,7 @@ class Reminders(Cog):
f"Reminder {reminder['id']} invalid: "
f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}."
)
- asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task))
+ asyncio.create_task(self.bot.api_client.delete(f"bot/reminders/{reminder['id']}"))
return is_valid, user, channel
@@ -88,7 +86,7 @@ class Reminders(Cog):
async def _send_confirmation(
ctx: Context,
on_success: str,
- reminder_id: str,
+ reminder_id: t.Union[str, int],
delivery_dt: t.Optional[datetime],
) -> None:
"""Send an embed confirming the reminder change was made successfully."""
@@ -148,24 +146,8 @@ class Reminders(Cog):
def schedule_reminder(self, reminder: dict) -> None:
"""A coroutine which sends the reminder once the time is reached, and cancels the running task."""
- reminder_id = reminder["id"]
reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None)
-
- async def _remind() -> None:
- await self.send_reminder(reminder)
-
- log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).")
- await self._delete_reminder(reminder_id)
-
- self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind())
-
- async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None:
- """Delete a reminder from the database, given its ID, and cancel the running task."""
- await self.bot.api_client.delete('bot/reminders/' + str(reminder_id))
-
- if cancel_task:
- # Now we can remove it from the schedule list
- self.scheduler.cancel(reminder_id)
+ self.scheduler.schedule_at(reminder_datetime, reminder["id"], self.send_reminder(reminder))
async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict:
"""
@@ -188,10 +170,12 @@ class Reminders(Cog):
log.trace(f"Scheduling new task #{reminder['id']}")
self.schedule_reminder(reminder)
+ @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True)
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
is_valid, user, channel = self.ensure_valid_reminder(reminder)
if not is_valid:
+ # No need to cancel the task too; it'll simply be done once this coroutine returns.
return
embed = discord.Embed()
@@ -217,18 +201,17 @@ class Reminders(Cog):
mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"])
)
- await channel.send(
- content=f"{user.mention} {additional_mentions}",
- embed=embed
- )
- await self._delete_reminder(reminder["id"])
+ await channel.send(content=f"{user.mention} {additional_mentions}", embed=embed)
+
+ log.debug(f"Deleting reminder #{reminder['id']} (the user has been reminded).")
+ await self.bot.api_client.delete(f"bot/reminders/{reminder['id']}")
@group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True)
async def remind_group(
self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str
) -> None:
"""Commands for managing your reminders."""
- await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content)
+ await self.new_reminder(ctx, mentions=mentions, expiration=expiration, content=content)
@remind_group.command(name="new", aliases=("add", "create"))
async def new_reminder(
@@ -286,10 +269,11 @@ class Reminders(Cog):
now = datetime.utcnow() - timedelta(seconds=1)
humanized_delta = humanize_delta(relativedelta(expiration, now))
- mention_string = (
- f"Your reminder will arrive in {humanized_delta} "
- f"and will mention {len(mentions)} other(s)!"
- )
+ mention_string = f"Your reminder will arrive in {humanized_delta}"
+
+ if mentions:
+ mention_string += f" and will mention {len(mentions)} other(s)"
+ mention_string += "!"
# Confirm to the user that it worked.
await self._send_confirmation(
@@ -394,6 +378,7 @@ class Reminders(Cog):
mention_ids = [mention.id for mention in mentions]
await self.edit_reminder(ctx, id_, {"mentions": mention_ids})
+ @lock_arg(NAMESPACE, "id_", raise_error=True)
async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None:
"""Edits a reminder with the given payload, then sends a confirmation message."""
if not await self._can_modify(ctx, id_):
@@ -413,11 +398,15 @@ class Reminders(Cog):
await self._reschedule_reminder(reminder)
@remind_group.command("delete", aliases=("remove", "cancel"))
+ @lock_arg(NAMESPACE, "id_", raise_error=True)
async def delete_reminder(self, ctx: Context, id_: int) -> None:
"""Delete one of your active reminders."""
if not await self._can_modify(ctx, id_):
return
- await self._delete_reminder(id_)
+
+ await self.bot.api_client.delete(f"bot/reminders/{id_}")
+ self.scheduler.cancel(id_)
+
await self._send_confirmation(
ctx,
on_success="That reminder has been deleted successfully!",
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 18b9a5014..ca6fbf5cb 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -241,12 +241,12 @@ class Snekbox(Cog):
)
code = await self.get_code(new_message)
- await ctx.message.clear_reactions()
+ await ctx.message.clear_reaction(REEVAL_EMOJI)
with contextlib.suppress(HTTPException):
await response.delete()
except asyncio.TimeoutError:
- await ctx.message.clear_reactions()
+ await ctx.message.clear_reaction(REEVAL_EMOJI)
return None
return code
diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py
index 6b6941064..3e9230414 100644
--- a/bot/exts/utils/utils.py
+++ b/bot/exts/utils/utils.py
@@ -84,7 +84,7 @@ class Utils(Cog):
# Assemble the embed
pep_embed = Embed(
title=f"**PEP {pep_number} - {pep_header['Title']}**",
- description=f"[Link]({self.base_pep_url}{pep_number:04})",
+ url=f"{self.base_pep_url}{pep_number:04}"
)
pep_embed.set_thumbnail(url=ICON_URL)
@@ -250,7 +250,7 @@ class Utils(Cog):
"""Send information about PEP 0."""
pep_embed = Embed(
title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**",
- description="[Link](https://www.python.org/dev/peps/)"
+ url="https://www.python.org/dev/peps/"
)
pep_embed.set_thumbnail(url=ICON_URL)
pep_embed.add_field(name="Status", value="Active")
diff --git a/bot/patches/__init__.py b/bot/patches/__init__.py
deleted file mode 100644
index 60f6becaa..000000000
--- a/bot/patches/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Subpackage that contains patches for discord.py."""
-from . import message_edited_at
-
-__all__ = [
- message_edited_at,
-]
diff --git a/bot/patches/message_edited_at.py b/bot/patches/message_edited_at.py
deleted file mode 100644
index a0154f12d..000000000
--- a/bot/patches/message_edited_at.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""
-# message_edited_at patch.
-
-Date: 2019-09-16
-Author: Scragly
-Added by: Ves Zappa
-
-Due to a bug in our current version of discord.py (1.2.3), the edited_at timestamp of
-`discord.Messages` are not being handled correctly. This patch fixes that until a new
-release of discord.py is released (and we've updated to it).
-"""
-import logging
-
-from discord import message, utils
-
-log = logging.getLogger(__name__)
-
-
-def _handle_edited_timestamp(self: message.Message, value: str) -> None:
- """Helper function that takes care of parsing the edited timestamp."""
- self._edited_timestamp = utils.parse_time(value)
-
-
-def apply_patch() -> None:
- """Applies the `edited_at` patch to the `discord.message.Message` class."""
- message.Message._handle_edited_timestamp = _handle_edited_timestamp
- message.Message._HANDLERS['edited_timestamp'] = message.Message._handle_edited_timestamp
- log.info("Patch applied: message_edited_at")
-
-
-if __name__ == "__main__":
- apply_patch()
diff --git a/bot/utils/function.py b/bot/utils/function.py
new file mode 100644
index 000000000..3ab32fe3c
--- /dev/null
+++ b/bot/utils/function.py
@@ -0,0 +1,75 @@
+"""Utilities for interaction with functions."""
+
+import inspect
+import typing as t
+
+Argument = t.Union[int, str]
+BoundArgs = t.OrderedDict[str, t.Any]
+Decorator = t.Callable[[t.Callable], t.Callable]
+ArgValGetter = t.Callable[[BoundArgs], t.Any]
+
+
+def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any:
+ """
+ Return a value from `arguments` based on a name or position.
+
+ `arguments` is an ordered mapping of parameter names to argument values.
+
+ Raise TypeError if `name_or_pos` isn't a str or int.
+ Raise ValueError if `name_or_pos` does not match any argument.
+ """
+ if isinstance(name_or_pos, int):
+ # Convert arguments to a tuple to make them indexable.
+ arg_values = tuple(arguments.items())
+ arg_pos = name_or_pos
+
+ try:
+ name, value = arg_values[arg_pos]
+ return value
+ except IndexError:
+ raise ValueError(f"Argument position {arg_pos} is out of bounds.")
+ elif isinstance(name_or_pos, str):
+ arg_name = name_or_pos
+ try:
+ return arguments[arg_name]
+ except KeyError:
+ raise ValueError(f"Argument {arg_name!r} doesn't exist.")
+ else:
+ raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")
+
+
+def get_arg_value_wrapper(
+ decorator_func: t.Callable[[ArgValGetter], Decorator],
+ name_or_pos: Argument,
+ func: t.Callable[[t.Any], t.Any] = None,
+) -> Decorator:
+ """
+ Call `decorator_func` with the value of the arg at the given name/position.
+
+ `decorator_func` must accept a callable as a parameter to which it will pass a mapping of
+ parameter names to argument values of the function it's decorating.
+
+ `func` is an optional callable which will return a new value given the argument's value.
+
+ Return the decorator returned by `decorator_func`.
+ """
+ def wrapper(args: BoundArgs) -> t.Any:
+ value = get_arg_value(name_or_pos, args)
+ if func:
+ value = func(value)
+ return value
+
+ return decorator_func(wrapper)
+
+
+def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any]) -> BoundArgs:
+ """
+ Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.
+
+ Default parameter values are also set.
+ """
+ sig = inspect.signature(func)
+ bound_args = sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+
+ return bound_args.arguments
diff --git a/bot/utils/lock.py b/bot/utils/lock.py
new file mode 100644
index 000000000..7aaafbc88
--- /dev/null
+++ b/bot/utils/lock.py
@@ -0,0 +1,114 @@
+import inspect
+import logging
+from collections import defaultdict
+from functools import partial, wraps
+from typing import Any, Awaitable, Callable, Hashable, Union
+from weakref import WeakValueDictionary
+
+from bot.errors import LockedResourceError
+from bot.utils import function
+
+log = logging.getLogger(__name__)
+__lock_dicts = defaultdict(WeakValueDictionary)
+
+_IdCallableReturn = Union[Hashable, Awaitable[Hashable]]
+_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
+ResourceId = Union[Hashable, _IdCallable]
+
+
+class LockGuard:
+ """
+ A context manager which acquires and releases a lock (mutex).
+
+ Raise RuntimeError if trying to acquire a locked lock.
+ """
+
+ def __init__(self):
+ self._locked = False
+
+ @property
+ def locked(self) -> bool:
+ """Return True if currently locked or False if unlocked."""
+ return self._locked
+
+ def __enter__(self):
+ if self._locked:
+ raise RuntimeError("Cannot acquire a locked lock.")
+
+ self._locked = True
+
+ def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001
+ self._locked = False
+ return False # Indicate any raised exception shouldn't be suppressed.
+
+
+def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable:
+ """
+ Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
+
+ If any other mutually exclusive function currently holds the lock for a resource, do not run the
+ decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if
+ the lock cannot be acquired.
+
+ `namespace` is an identifier used to prevent collisions among resource IDs.
+
+ `resource_id` identifies a resource on which to perform a mutually exclusive operation.
+ It may also be a callable or awaitable which will return the resource ID given an ordered
+ mapping of the parameters' names to arguments' values.
+
+ If decorating a command, this decorator must go before (below) the `command` decorator.
+ """
+ def decorator(func: Callable) -> Callable:
+ name = func.__name__
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs) -> Any:
+ log.trace(f"{name}: mutually exclusive decorator called")
+
+ if callable(resource_id):
+ log.trace(f"{name}: binding args to signature")
+ bound_args = function.get_bound_args(func, args, kwargs)
+
+ log.trace(f"{name}: calling the given callable to get the resource ID")
+ id_ = resource_id(bound_args)
+
+ if inspect.isawaitable(id_):
+ log.trace(f"{name}: awaiting to get resource ID")
+ id_ = await id_
+ else:
+ id_ = resource_id
+
+ log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}")
+
+ # Get the lock for the ID. Create a lock if one doesn't exist yet.
+ locks = __lock_dicts[namespace]
+ lock_guard = locks.setdefault(id_, LockGuard())
+
+ if not lock_guard.locked:
+ log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...")
+ with lock_guard:
+ return await func(*args, **kwargs)
+ else:
+ log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
+ if raise_error:
+ raise LockedResourceError(str(namespace), id_)
+
+ return wrapper
+ return decorator
+
+
+def lock_arg(
+ namespace: Hashable,
+ name_or_pos: function.Argument,
+ func: Callable[[Any], _IdCallableReturn] = None,
+ *,
+ raise_error: bool = False,
+) -> Callable:
+ """
+ Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
+
+ `func` is an optional callable or awaitable which will return the ID given the argument value.
+ See `lock` docs for more information.
+ """
+ decorator_func = partial(lock, namespace, raise_error=raise_error)
+ return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index 9cc0d8a34..b6c7cab50 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -34,7 +34,11 @@ async def wait_for_deletion(
if attach_emojis:
for emoji in deletion_emojis:
- await message.add_reaction(emoji)
+ try:
+ await message.add_reaction(emoji)
+ except discord.NotFound:
+ log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.")
+ return
def check(reaction: discord.Reaction, user: discord.Member) -> bool:
"""Check that the deletion emoji is reacted by the appropriate user."""
@@ -52,15 +56,24 @@ async def wait_for_deletion(
async def send_attachments(
message: discord.Message,
destination: Union[discord.TextChannel, discord.Webhook],
- link_large: bool = True
+ link_large: bool = True,
+ use_cached: bool = False,
+ **kwargs
) -> List[str]:
"""
Re-upload the message's attachments to the destination and return a list of their new URLs.
Each attachment is sent as a separate message to more easily comply with the request/file size
limit. If link_large is True, attachments which are too large are instead grouped into a single
- embed which links to them.
+ embed which links to them. Extra kwargs will be passed to send() when sending the attachment.
"""
+ webhook_send_kwargs = {
+ 'username': message.author.display_name,
+ 'avatar_url': message.author.avatar_url,
+ }
+ webhook_send_kwargs.update(kwargs)
+ webhook_send_kwargs['username'] = sub_clyde(webhook_send_kwargs['username'])
+
large = []
urls = []
for attachment in message.attachments:
@@ -74,18 +87,14 @@ async def send_attachments(
# but some may get through hence the try-catch.
if attachment.size <= destination.guild.filesize_limit - 512:
with BytesIO() as file:
- await attachment.save(file, use_cached=True)
+ await attachment.save(file, use_cached=use_cached)
attachment_file = discord.File(file, filename=attachment.filename)
if isinstance(destination, discord.TextChannel):
- msg = await destination.send(file=attachment_file)
+ msg = await destination.send(file=attachment_file, **kwargs)
urls.append(msg.attachments[0].url)
else:
- await destination.send(
- file=attachment_file,
- username=sub_clyde(message.author.display_name),
- avatar_url=message.author.avatar_url
- )
+ await destination.send(file=attachment_file, **webhook_send_kwargs)
elif link_large:
large.append(attachment)
else:
@@ -102,13 +111,9 @@ async def send_attachments(
embed.set_footer(text="Attachments exceed upload size limit.")
if isinstance(destination, discord.TextChannel):
- await destination.send(embed=embed)
+ await destination.send(embed=embed, **kwargs)
else:
- await destination.send(
- embed=embed,
- username=sub_clyde(message.author.display_name),
- avatar_url=message.author.avatar_url
- )
+ await destination.send(embed=embed, **webhook_send_kwargs)
return urls
diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py
index 1b89564f2..063a82754 100644
--- a/tests/bot/exts/backend/sync/test_cog.py
+++ b/tests/bot/exts/backend/sync/test_cog.py
@@ -392,14 +392,14 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
async def test_sync_roles_command(self):
"""sync() should be called on the RoleSyncer."""
ctx = helpers.MockContext()
- await self.cog.sync_roles_command.callback(self.cog, ctx)
+ await self.cog.sync_roles_command(self.cog, ctx)
self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
async def test_sync_users_command(self):
"""sync() should be called on the UserSyncer."""
ctx = helpers.MockContext()
- await self.cog.sync_users_command.callback(self.cog, ctx)
+ await self.cog.sync_users_command(self.cog, ctx)
self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index c0a1da35c..9f380a15d 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,7 +1,6 @@
import unittest
-from unittest import mock
-from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User
+from bot.exts.backend.sync._syncers import UserSyncer, _Diff
from tests import helpers
@@ -10,7 +9,7 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("roles", [666])
kwargs.setdefault("in_guild", True)
return kwargs
@@ -40,22 +39,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
return guild
+ @staticmethod
+ def get_mock_member(member: dict):
+ member = member.copy()
+ del member["in_guild"]
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+ return mock_member
+
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": []
+ }
guild = self.get_guild()
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_identical_users(self):
"""No differences should be found if the users in the guild and DB are identical."""
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.return_value = self.get_mock_member(fake_user())
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -63,59 +82,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(id=99, name="old"), fake_user()]
+ }
guild = self.get_guild(updated_user, fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(updated_user),
+ self.get_mock_member(fake_user())
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**updated_user)}, None)
+ expected_diff = ([], [{"id": 99, "name": "new"}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_users(self):
- """Only new users should be added to the 'created' set of the diff."""
+ """Only new users should be added to the 'created' list of the diff."""
new_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user(), new_user)
-
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(new_user)
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, set(), None)
+ expected_diff = ([new_user], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- leaving_user = fake_user(id=63, in_guild=False)
-
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**leaving_user)}, None)
+ expected_diff = ([], [{"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_updated_and_leaving_users(self):
"""When users are added, updated, and removed, all of them are returned properly."""
new_user = fake_user(id=99, name="new")
+
updated_user = fake_user(id=55, name="updated")
- leaving_user = fake_user(id=63, in_guild=False)
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=55), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user(), new_user, updated_user)
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(updated_user),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+ expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_db_users_not_in_guild(self):
- """When the DB knows a user the guild doesn't, no difference is found."""
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ """When the DB knows a user, but the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=63, in_guild=False)]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -131,13 +193,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(user_tuples, set(), None)
+ diff = _Diff(users, [], None)
await self.syncer._sync(diff)
- calls = [mock.call("bot/users", json=user) for user in users]
- self.bot.api_client.post.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.post.call_count, len(users))
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
@@ -146,13 +205,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only PUT requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(set(), user_tuples, None)
+ diff = _Diff([], users, None)
await self.syncer._sync(diff)
- calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
- self.bot.api_client.put.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.put.call_count, len(users))
+ self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index 7bc7dbb5d..c149f5745 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -1,4 +1,3 @@
-import asyncio
import textwrap
import unittest
import unittest.mock
@@ -13,7 +12,7 @@ from tests import helpers
COG_PATH = "bot.exts.info.information.Information"
-class InformationCogTests(unittest.TestCase):
+class InformationCogTests(unittest.IsolatedAsyncioTestCase):
"""Tests the Information cog."""
@classmethod
@@ -29,16 +28,14 @@ class InformationCogTests(unittest.TestCase):
self.ctx = helpers.MockContext()
self.ctx.author.roles.append(self.moderator_role)
- def test_roles_command_command(self):
+ async def test_roles_command_command(self):
"""Test if the `role_info` command correctly returns the `moderator_role`."""
self.ctx.guild.roles.append(self.moderator_role)
self.cog.roles_info.can_run = unittest.mock.AsyncMock()
self.cog.roles_info.can_run.return_value = True
- coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
-
- self.assertIsNone(asyncio.run(coroutine))
+ self.assertIsNone(await self.cog.roles_info(self.cog, self.ctx))
self.ctx.send.assert_called_once()
_, kwargs = self.ctx.send.call_args
@@ -48,7 +45,7 @@ class InformationCogTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour.blurple())
self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
- def test_role_info_command(self):
+ async def test_role_info_command(self):
"""Tests the `role info` command."""
dummy_role = helpers.MockRole(
name="Dummy",
@@ -73,9 +70,7 @@ class InformationCogTests(unittest.TestCase):
self.cog.role_info.can_run = unittest.mock.AsyncMock()
self.cog.role_info.can_run.return_value = True
- coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
-
- self.assertIsNone(asyncio.run(coroutine))
+ self.assertIsNone(await self.cog.role_info(self.cog, self.ctx, dummy_role, admin_role))
self.assertEqual(self.ctx.send.call_count, 2)
@@ -97,80 +92,8 @@ class InformationCogTests(unittest.TestCase):
self.assertEqual(admin_embed.title, "Admins info")
self.assertEqual(admin_embed.colour, discord.Colour.red())
- @unittest.mock.patch('bot.exts.info.information.time_since')
- def test_server_info_command(self, time_since_patch):
- time_since_patch.return_value = '2 days ago'
-
- self.ctx.guild = helpers.MockGuild(
- features=('lemons', 'apples'),
- region="The Moon",
- roles=[self.moderator_role],
- channels=[
- discord.TextChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'}
- ),
- discord.CategoryChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'}
- ),
- discord.VoiceChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'}
- )
- ],
- members=[
- *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
- *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
- *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
- *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
- ],
- member_count=1_234,
- icon_url='a-lemon.jpg',
- )
-
- coroutine = self.cog.server_info.callback(self.cog, self.ctx)
- self.assertIsNone(asyncio.run(coroutine))
-
- time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days')
- _, kwargs = self.ctx.send.call_args
- embed = kwargs.pop('embed')
- self.assertEqual(embed.colour, discord.Colour.blurple())
- self.assertEqual(
- embed.description,
- textwrap.dedent(
- f"""
- **Server information**
- Created: {time_since_patch.return_value}
- Voice region: {self.ctx.guild.region}
- Features: {', '.join(self.ctx.guild.features)}
-
- **Channel counts**
- Category channels: 1
- Text channels: 1
- Voice channels: 1
- Staff channels: 0
-
- **Member counts**
- Members: {self.ctx.guild.member_count:,}
- Staff members: 0
- Roles: {len(self.ctx.guild.roles)}
-
- **Member statuses**
- {constants.Emojis.status_online} 2
- {constants.Emojis.status_idle} 1
- {constants.Emojis.status_dnd} 4
- {constants.Emojis.status_offline} 3
- """
- )
- )
- self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg')
-
-class UserInfractionHelperMethodTests(unittest.TestCase):
+class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the helper methods of the `!user` command."""
def setUp(self):
@@ -180,7 +103,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
self.cog = information.Information(self.bot)
self.member = helpers.MockMember(id=1234)
- def test_user_command_helper_method_get_requests(self):
+ async def test_user_command_helper_method_get_requests(self):
"""The helper methods should form the correct get requests."""
test_values = (
{
@@ -202,11 +125,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
endpoint, params = test_value["expected_args"]
with self.subTest(method=helper_method, endpoint=endpoint, params=params):
- asyncio.run(helper_method(self.member))
+ await helper_method(self.member)
self.bot.api_client.get.assert_called_once_with(endpoint, params=params)
self.bot.api_client.get.reset_mock()
- def _method_subtests(self, method, test_values, default_header):
+ async def _method_subtests(self, method, test_values, default_header):
"""Helper method that runs the subtests for the different helper methods."""
for test_value in test_values:
api_response = test_value["api response"]
@@ -216,11 +139,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
self.bot.api_client.get.return_value = api_response
expected_output = "\n".join(expected_lines)
- actual_output = asyncio.run(method(self.member))
+ actual_output = await method(self.member)
self.assertEqual((default_header, expected_output), actual_output)
- def test_basic_user_infraction_counts_returns_correct_strings(self):
+ async def test_basic_user_infraction_counts_returns_correct_strings(self):
"""The method should correctly list both the total and active number of non-hidden infractions."""
test_values = (
# No infractions means zero counts
@@ -251,9 +174,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
header = "Infractions"
- self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header)
+ await self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header)
- def test_expanded_user_infraction_counts_returns_correct_strings(self):
+ async def test_expanded_user_infraction_counts_returns_correct_strings(self):
"""The method should correctly list the total and active number of all infractions split by infraction type."""
test_values = (
{
@@ -306,9 +229,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
header = "Infractions"
- self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header)
+ await self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header)
- def test_user_nomination_counts_returns_correct_strings(self):
+ async def test_user_nomination_counts_returns_correct_strings(self):
"""The method should list the number of active and historical nominations for the user."""
test_values = (
{
@@ -336,12 +259,12 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
header = "Nominations"
- self._method_subtests(self.cog.user_nomination_counts, test_values, header)
+ await self._method_subtests(self.cog.user_nomination_counts, test_values, header)
@unittest.mock.patch("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago"))
@unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50])
-class UserEmbedTests(unittest.TestCase):
+class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the creation of the `!user` embed."""
def setUp(self):
@@ -354,14 +277,14 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
+ async def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
user.nick = None
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.title, "Mr. Hemlock")
@@ -369,14 +292,14 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_nick_in_title_if_available(self):
+ async def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
user.nick = "Cat lover"
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
@@ -384,7 +307,7 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_ignores_everyone_role(self):
+ async def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
admins_role = helpers.MockRole(name='Admins')
@@ -393,14 +316,18 @@ class UserEmbedTests(unittest.TestCase):
# A `MockMember` has the @Everyone role by default; we add the Admins to that.
user = helpers.MockMember(roles=[admins_role], top_role=admins_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertIn("&Admins", embed.fields[1].value)
self.assertNotIn("&Everyone", embed.fields[1].value)
@unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
@unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
- def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
+ async def test_create_user_embed_expanded_information_in_moderation_channels(
+ self,
+ nomination_counts,
+ infraction_counts
+ ):
"""The embed should contain expanded infractions and nomination info in mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50))
@@ -411,7 +338,7 @@ class UserEmbedTests(unittest.TestCase):
nomination_counts.return_value = ("Nominations", "nomination info")
user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
infraction_counts.assert_called_once_with(user)
nomination_counts.assert_called_once_with(user)
@@ -434,7 +361,7 @@ class UserEmbedTests(unittest.TestCase):
)
@unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
- def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
+ async def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
@@ -444,7 +371,7 @@ class UserEmbedTests(unittest.TestCase):
infraction_counts.return_value = ("Infractions", "basic infractions info")
user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
infraction_counts.assert_called_once_with(user)
@@ -467,14 +394,14 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(
"basic infractions info",
- embed.fields[3].value
+ embed.fields[2].value
)
@unittest.mock.patch(
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
+ async def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -482,7 +409,7 @@ class UserEmbedTests(unittest.TestCase):
moderators_role.colour = 100
user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
@@ -490,12 +417,12 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
+ async def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
"""The embed should be created with a blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
user = helpers.MockMember(id=217)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.colour, discord.Colour.blurple())
@@ -503,20 +430,20 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
+ async def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
user = helpers.MockMember(id=217)
user.avatar_url_as.return_value = "avatar url"
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
user.avatar_url_as.assert_called_once_with(static_format="png")
self.assertEqual(embed.thumbnail.url, "avatar url")
@unittest.mock.patch("bot.exts.info.information.constants")
-class UserCommandTests(unittest.TestCase):
+class UserCommandTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the `!user` command."""
def setUp(self):
@@ -536,16 +463,16 @@ class UserCommandTests(unittest.TestCase):
# used as a default value for a parameter, which gets defined upon import.
self.bot_command_channel = helpers.MockTextChannel(id=constants.Channels.bot_commands)
- def test_regular_member_cannot_target_another_member(self, constants):
+ async def test_regular_member_cannot_target_another_member(self, constants):
"""A regular user should not be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.author)
- asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target))
+ await self.cog.user_info(self.cog, ctx, self.target)
ctx.send.assert_called_once_with("You may not use this command on users other than yourself.")
- def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants):
+ async def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants):
"""A regular user should not be able to use this command outside of bot-commands."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -553,49 +480,49 @@ class UserCommandTests(unittest.TestCase):
msg = "Sorry, but you may only use this command within <#50>."
with self.assertRaises(InWhitelistCheckFailure, msg=msg):
- asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+ await self.cog.user_info(self.cog, ctx)
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
+ async def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel)
- asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+ await self.cog.user_info(self.cog, ctx)
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_regular_user_can_explicitly_target_themselves(self, create_embed, _):
+ async def test_regular_user_can_explicitly_target_themselves(self, create_embed, _):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel)
- asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author))
+ await self.cog.user_info(self.cog, ctx, self.author)
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
+ async def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))
- asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+ await self.cog.user_info(self.cog, ctx)
create_embed.assert_called_once_with(ctx, self.moderator)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_moderators_can_target_another_member(self, create_embed, constants):
+ async def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50))
- asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target))
+ await self.cog.user_info(self.cog, ctx, self.target)
create_embed.assert_called_once_with(ctx, self.target)
ctx.send.assert_called_once()
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index e2d44c637..3c2d52ae0 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -122,7 +122,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
starting_unsilenced_state=_silence_patch_return
):
with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return):
- await self.cog.silence.callback(self.cog, self.ctx, duration)
+ await self.cog.silence(self.cog, self.ctx, duration)
self.ctx.send.assert_called_once_with(result_message)
self.ctx.reset_mock()
@@ -138,7 +138,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
result_message=result_message
):
with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return):
- await self.cog.unsilence.callback(self.cog, self.ctx)
+ await self.cog.unsilence(self.cog, self.ctx)
self.ctx.send.assert_called_once_with(result_message)
self.ctx.reset_mock()
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 40b2202aa..6601fad2c 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -154,7 +154,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.send_eval = AsyncMock(return_value=response)
self.cog.continue_eval = AsyncMock(return_value=None)
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
self.cog.continue_eval.assert_called_once_with(ctx, response)
@@ -168,7 +168,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.continue_eval = AsyncMock()
self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None)
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
self.cog.continue_eval.assert_called_with(ctx, response)
@@ -180,7 +180,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.author.mention = '@LemonLemonishBeard#0042'
ctx.send = AsyncMock()
self.cog.jobs = (42,)
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
ctx.send.assert_called_once_with(
"@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
)
@@ -188,8 +188,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_eval_command_call_help(self):
"""Test if the eval command call the help command if no code is provided."""
ctx = MockContext(command="sentinel")
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
- ctx.send_help.assert_called_once_with("sentinel")
+ await self.cog.eval_command(self.cog, ctx=ctx, code='')
+ ctx.send_help.assert_called_once_with(ctx.command)
async def test_send_eval(self):
"""Test the send_eval function."""
@@ -290,7 +290,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
)
ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
- ctx.message.clear_reactions.assert_called_once()
+ ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
response.delete.assert_called_once()
async def test_continue_eval_does_not_continue(self):
@@ -299,7 +299,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
actual = await self.cog.continue_eval(ctx, MockMessage())
self.assertEqual(actual, None)
- ctx.message.clear_reactions.assert_called_once()
+ ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
async def test_get_code(self):
"""Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
diff --git a/tests/bot/patches/__init__.py b/tests/bot/patches/__init__.py
deleted file mode 100644
index e69de29bb..000000000
--- a/tests/bot/patches/__init__.py
+++ /dev/null