aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/api.py45
-rw-r--r--bot/bot.py54
-rw-r--r--bot/cogs/antispam.py2
-rw-r--r--bot/cogs/defcon.py2
-rw-r--r--bot/cogs/doc.py2
-rw-r--r--bot/cogs/duck_pond.py2
-rw-r--r--bot/cogs/logging.py2
-rw-r--r--bot/cogs/moderation/scheduler.py2
-rw-r--r--bot/cogs/off_topic_names.py2
-rw-r--r--bot/cogs/reddit.py4
-rw-r--r--bot/cogs/reminders.py2
-rw-r--r--bot/cogs/sync/cog.py100
-rw-r--r--bot/cogs/sync/syncers.py551
-rw-r--r--bot/cogs/verification.py2
-rw-r--r--bot/cogs/watchchannels/watchchannel.py2
-rw-r--r--bot/constants.py10
-rw-r--r--config-default.yml7
-rw-r--r--tests/base.py34
-rw-r--r--tests/bot/cogs/sync/test_base.py412
-rw-r--r--tests/bot/cogs/sync/test_cog.py395
-rw-r--r--tests/bot/cogs/sync/test_roles.py287
-rw-r--r--tests/bot/cogs/sync/test_users.py241
-rw-r--r--tests/bot/cogs/test_duck_pond.py4
-rw-r--r--tests/helpers.py35
24 files changed, 1658 insertions, 541 deletions
diff --git a/bot/api.py b/bot/api.py
index fb126b384..e59916114 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -32,6 +32,11 @@ class ResponseCodeError(ValueError):
class APIClient:
"""Django Site API wrapper."""
+ # These are class attributes so they can be seen when being mocked for tests.
+ # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details.
+ session: Optional[aiohttp.ClientSession] = None
+ loop: asyncio.AbstractEventLoop = None
+
def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):
auth_headers = {
'Authorization': f"Token {Keys.site_api}"
@@ -42,7 +47,7 @@ class APIClient:
else:
kwargs['headers'] = auth_headers
- self.session: Optional[aiohttp.ClientSession] = None
+ self.session = None
self.loop = loop
self._ready = asyncio.Event(loop=loop)
@@ -85,43 +90,35 @@ class APIClient:
response_text = await response.text()
raise ResponseCodeError(response=response, response_text=response_text)
- async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
- """Site API GET."""
+ async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Send an HTTP request to the site API and return the JSON response."""
await self._ready.wait()
- async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp:
+ async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp:
await self.maybe_raise_for_status(resp, raise_for_status)
return await resp.json()
- async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
- """Site API PATCH."""
- await self._ready.wait()
+ async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Site API GET."""
+ return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs)
- async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp:
- await self.maybe_raise_for_status(resp, raise_for_status)
- return await resp.json()
+ async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Site API PATCH."""
+ return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs)
- async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
+ async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API POST."""
- await self._ready.wait()
-
- async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp:
- await self.maybe_raise_for_status(resp, raise_for_status)
- return await resp.json()
+ return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs)
- async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
+ async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API PUT."""
- await self._ready.wait()
-
- async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp:
- await self.maybe_raise_for_status(resp, raise_for_status)
- return await resp.json()
+ return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs)
- async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]:
+ async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]:
"""Site API DELETE."""
await self._ready.wait()
- async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp:
+ async with self.session.delete(self._url_for(endpoint), **kwargs) as resp:
if resp.status == 204:
return None
diff --git a/bot/bot.py b/bot/bot.py
index cecee7b68..19b9035c4 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,11 +1,14 @@
+import asyncio
import logging
import socket
from typing import Optional
import aiohttp
+import discord
from discord.ext import commands
from bot import api
+from bot import constants
log = logging.getLogger('bot')
@@ -17,15 +20,17 @@ class Bot(commands.Bot):
# Use asyncio for DNS resolution instead of threads so threads aren't spammed.
# Use AF_INET as its socket family to prevent HTTPS related problems both locally
# and in production.
- self.connector = aiohttp.TCPConnector(
+ self._connector = aiohttp.TCPConnector(
resolver=aiohttp.AsyncResolver(),
family=socket.AF_INET,
)
- super().__init__(*args, connector=self.connector, **kwargs)
+ super().__init__(*args, connector=self._connector, **kwargs)
+
+ self._guild_available = asyncio.Event()
self.http_session: Optional[aiohttp.ClientSession] = None
- self.api_client = api.APIClient(loop=self.loop, connector=self.connector)
+ self.api_client = api.APIClient(loop=self.loop, connector=self._connector)
def add_cog(self, cog: commands.Cog) -> None:
"""Adds a "cog" to the bot and logs the operation."""
@@ -46,6 +51,47 @@ class Bot(commands.Bot):
async def start(self, *args, **kwargs) -> None:
"""Open an aiohttp session before logging in and connecting to Discord."""
- self.http_session = aiohttp.ClientSession(connector=self.connector)
+ self.http_session = aiohttp.ClientSession(connector=self._connector)
await super().start(*args, **kwargs)
+
+ async def on_guild_available(self, guild: discord.Guild) -> None:
+ """
+ Set the internal guild available event when constants.Guild.id becomes available.
+
+ If the cache appears to still be empty (no members, no channels, or no roles), the event
+ will not be set.
+ """
+ if guild.id != constants.Guild.id:
+ return
+
+ if not guild.roles or not guild.members or not guild.channels:
+ msg = "Guild available event was dispatched but the cache appears to still be empty!"
+ log.warning(msg)
+
+ try:
+ webhook = await self.fetch_webhook(constants.Webhooks.dev_log)
+ except discord.HTTPException as e:
+ log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}")
+ else:
+ await webhook.send(f"<@&{constants.Roles.admin}> {msg}")
+
+ return
+
+ self._guild_available.set()
+
+ async def on_guild_unavailable(self, guild: discord.Guild) -> None:
+ """Clear the internal guild available event when constants.Guild.id becomes unavailable."""
+ if guild.id != constants.Guild.id:
+ return
+
+ self._guild_available.clear()
+
+ async def wait_until_guild_available(self) -> None:
+ """
+ Wait until the constants.Guild.id guild is available (and the cache is ready).
+
+ The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE
+ gateway event before giving up and thus not populating the cache for unavailable guilds.
+ """
+ await self._guild_available.wait()
diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py
index f67ef6f05..baa6b9459 100644
--- a/bot/cogs/antispam.py
+++ b/bot/cogs/antispam.py
@@ -123,7 +123,7 @@ class AntiSpam(Cog):
async def alert_on_validation_error(self) -> None:
"""Unloads the cog and alerts admins if configuration validation failed."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if self.validation_errors:
body = "**The following errors were encountered:**\n"
body += "\n".join(f"- {error}" for error in self.validation_errors.values())
diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py
index a0d8fedd5..20961e0a2 100644
--- a/bot/cogs/defcon.py
+++ b/bot/cogs/defcon.py
@@ -59,7 +59,7 @@ class Defcon(Cog):
async def sync_settings(self) -> None:
"""On cog load, try to synchronize DEFCON settings to the API."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
self.channel = await self.bot.fetch_channel(Channels.defcon)
try:
diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py
index 6e7c00b6a..204cffb37 100644
--- a/bot/cogs/doc.py
+++ b/bot/cogs/doc.py
@@ -157,7 +157,7 @@ class Doc(commands.Cog):
async def init_refresh_inventory(self) -> None:
"""Refresh documentation inventory on cog initialization."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
await self.refresh_inventory()
async def update_single(
diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py
index 345d2856c..1f84a0609 100644
--- a/bot/cogs/duck_pond.py
+++ b/bot/cogs/duck_pond.py
@@ -22,7 +22,7 @@ class DuckPond(Cog):
async def fetch_webhook(self) -> None:
"""Fetches the webhook object, so we can post to it."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
try:
self.webhook = await self.bot.fetch_webhook(self.webhook_id)
diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py
index d1b7dcab3..dbd76672f 100644
--- a/bot/cogs/logging.py
+++ b/bot/cogs/logging.py
@@ -20,7 +20,7 @@ class Logging(Cog):
async def startup_greeting(self) -> None:
"""Announce our presence to the configured devlog channel."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
log.info("Bot connected!")
embed = Embed(description="Connected!")
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py
index c0de0e4da..3c5185468 100644
--- a/bot/cogs/moderation/scheduler.py
+++ b/bot/cogs/moderation/scheduler.py
@@ -38,7 +38,7 @@ class InfractionScheduler(Scheduler):
async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None:
"""Schedule expiration for previous infractions."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
log.trace(f"Rescheduling infractions for {self.__class__.__name__}.")
diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py
index bf777ea5a..81511f99d 100644
--- a/bot/cogs/off_topic_names.py
+++ b/bot/cogs/off_topic_names.py
@@ -88,7 +88,7 @@ class OffTopicNames(Cog):
async def init_offtopic_updater(self) -> None:
"""Start off-topic channel updating event loop if it hasn't already started."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if self.updater_task is None:
coro = update_names(self.bot)
self.updater_task = self.bot.loop.create_task(coro)
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index aa487f18e..4f6584aba 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -48,7 +48,7 @@ class Reddit(Cog):
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if not self.webhook:
self.webhook = await self.bot.fetch_webhook(Webhooks.reddit)
@@ -208,7 +208,7 @@ class Reddit(Cog):
await asyncio.sleep(seconds_until)
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if not self.webhook:
await self.bot.fetch_webhook(Webhooks.reddit)
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 42229123b..a642cbfdb 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -36,7 +36,7 @@ class Reminders(Scheduler, Cog):
async def reschedule_reminders(self) -> None:
"""Get all current reminders from the API and reschedule them."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
response = await self.bot.api_client.get(
'bot/reminders',
params={'active': 'true'}
diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py
index 4e6ed156b..5708be3f4 100644
--- a/bot/cogs/sync/cog.py
+++ b/bot/cogs/sync/cog.py
@@ -1,7 +1,7 @@
import logging
-from typing import Callable, Dict, Iterable, Union
+from typing import Any, Dict
-from discord import Guild, Member, Role, User
+from discord import Member, Role, User
from discord.ext import commands
from discord.ext.commands import Cog, Context
@@ -16,45 +16,28 @@ log = logging.getLogger(__name__)
class Sync(Cog):
"""Captures relevant events and sends them to the site."""
- # The server to synchronize events on.
- # Note that setting this wrongly will result in things getting deleted
- # that possibly shouldn't be.
- SYNC_SERVER_ID = constants.Guild.id
-
- # An iterable of callables that are called when the bot is ready.
- ON_READY_SYNCERS: Iterable[Callable[[Bot, Guild], None]] = (
- syncers.sync_roles,
- syncers.sync_users
- )
-
def __init__(self, bot: Bot) -> None:
self.bot = bot
+ self.role_syncer = syncers.RoleSyncer(self.bot)
+ self.user_syncer = syncers.UserSyncer(self.bot)
self.bot.loop.create_task(self.sync_guild())
async def sync_guild(self) -> None:
"""Syncs the roles/users of the guild with the database."""
- await self.bot.wait_until_ready()
- guild = self.bot.get_guild(self.SYNC_SERVER_ID)
- if guild is not None:
- for syncer in self.ON_READY_SYNCERS:
- syncer_name = syncer.__name__[5:] # drop off `sync_`
- log.info("Starting `%s` syncer.", syncer_name)
- total_created, total_updated, total_deleted = await syncer(self.bot, guild)
- if total_deleted is None:
- log.info(
- f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`."
- )
- else:
- log.info(
- f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`, "
- f"deleted `{total_deleted}`."
- )
-
- async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None:
+ await self.bot.wait_until_guild_available()
+
+ guild = self.bot.get_guild(constants.Guild.id)
+ if guild is None:
+ return
+
+ for syncer in (self.role_syncer, self.user_syncer):
+ await syncer.sync(guild)
+
+ async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None:
"""Send a PATCH request to partially update a user in the database."""
try:
- await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information)
+ await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information)
except ResponseCodeError as e:
if e.response.status != 404:
raise
@@ -82,12 +65,14 @@ class Sync(Cog):
@Cog.listener()
async def on_guild_role_update(self, before: Role, after: Role) -> None:
"""Syncs role with the database if any of the stored attributes were updated."""
- if (
- before.name != after.name
- or before.colour != after.colour
- or before.permissions != after.permissions
- or before.position != after.position
- ):
+ was_updated = (
+ before.name != after.name
+ or before.colour != after.colour
+ or before.permissions != after.permissions
+ or before.position != after.position
+ )
+
+ if was_updated:
await self.bot.api_client.put(
f'bot/roles/{after.id}',
json={
@@ -137,18 +122,8 @@ class Sync(Cog):
@Cog.listener()
async def on_member_remove(self, member: Member) -> None:
- """Updates the user information when a member leaves the guild."""
- await self.bot.api_client.put(
- f'bot/users/{member.id}',
- json={
- 'avatar_hash': member.avatar,
- 'discriminator': int(member.discriminator),
- 'id': member.id,
- 'in_guild': False,
- 'name': member.name,
- 'roles': sorted(role.id for role in member.roles)
- }
- )
+ """Set the in_guild field to False when a member leaves the guild."""
+ await self.patch_user(member.id, updated_information={"in_guild": False})
@Cog.listener()
async def on_member_update(self, before: Member, after: Member) -> None:
@@ -160,7 +135,8 @@ class Sync(Cog):
@Cog.listener()
async def on_user_update(self, before: User, after: User) -> None:
"""Update the user information in the database if a relevant change is detected."""
- if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")):
+ attrs = ("name", "discriminator", "avatar")
+ if any(getattr(before, attr) != getattr(after, attr) for attr in attrs):
updated_information = {
"name": after.name,
"discriminator": int(after.discriminator),
@@ -176,25 +152,11 @@ class Sync(Cog):
@sync_group.command(name='roles')
@commands.has_permissions(administrator=True)
async def sync_roles_command(self, ctx: Context) -> None:
- """Manually synchronize the guild's roles with the roles on the site."""
- initial_response = await ctx.send("📊 Synchronizing roles.")
- total_created, total_updated, total_deleted = await syncers.sync_roles(self.bot, ctx.guild)
- await initial_response.edit(
- content=(
- f"👌 Role synchronization complete, created **{total_created}** "
- f", updated **{total_created}** roles, and deleted **{total_deleted}** roles."
- )
- )
+ """Manually synchronise the guild's roles with the roles on the site."""
+ await self.role_syncer.sync(ctx.guild, ctx)
@sync_group.command(name='users')
@commands.has_permissions(administrator=True)
async def sync_users_command(self, ctx: Context) -> None:
- """Manually synchronize the guild's users with the users on the site."""
- initial_response = await ctx.send("📊 Synchronizing users.")
- total_created, total_updated, total_deleted = await syncers.sync_users(self.bot, ctx.guild)
- await initial_response.edit(
- content=(
- f"👌 User synchronization complete, created **{total_created}** "
- f"and updated **{total_created}** users."
- )
- )
+ """Manually synchronise the guild's users with the users on the site."""
+ await self.user_syncer.sync(ctx.guild, ctx)
diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py
index 14cf51383..6715ad6fb 100644
--- a/bot/cogs/sync/syncers.py
+++ b/bot/cogs/sync/syncers.py
@@ -1,235 +1,342 @@
+import abc
+import logging
+import typing as t
from collections import namedtuple
-from typing import Dict, Set, Tuple
+from functools import partial
-from discord import Guild
+from discord import Guild, HTTPException, Member, Message, Reaction, User
+from discord.ext.commands import Context
+from bot import constants
+from bot.api import ResponseCodeError
from bot.bot import Bot
+log = logging.getLogger(__name__)
+
# These objects are declared as namedtuples because tuples are hashable,
# something that we make use of when diffing site roles against guild roles.
-Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
-User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild'))
-
-
-def get_roles_for_sync(
- guild_roles: Set[Role], api_roles: Set[Role]
-) -> Tuple[Set[Role], Set[Role], Set[Role]]:
- """
- Determine which roles should be created or updated on the site.
-
- Arguments:
- guild_roles (Set[Role]):
- Roles that were found on the guild at startup.
-
- api_roles (Set[Role]):
- Roles that were retrieved from the API at startup.
-
- Returns:
- Tuple[Set[Role], Set[Role]. Set[Role]]:
- A tuple with three elements. The first element represents
- roles to be created on the site, meaning that they were
- present on the cached guild but not on the API. The second
- element represents roles to be updated, meaning they were
- present on both the cached guild and the API but non-ID
- fields have changed inbetween. The third represents roles
- to be deleted on the site, meaning the roles are present on
- the API but not in the cached guild.
- """
- guild_role_ids = {role.id for role in guild_roles}
- api_role_ids = {role.id for role in api_roles}
- new_role_ids = guild_role_ids - api_role_ids
- deleted_role_ids = api_role_ids - guild_role_ids
-
- # New roles are those which are on the cached guild but not on the
- # API guild, going by the role ID. We need to send them in for creation.
- roles_to_create = {role for role in guild_roles if role.id in new_role_ids}
- roles_to_update = guild_roles - api_roles - roles_to_create
- roles_to_delete = {role for role in api_roles if role.id in deleted_role_ids}
- return roles_to_create, roles_to_update, roles_to_delete
-
-
-async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]:
- """
- Synchronize roles found on the given `guild` with the ones on the API.
-
- Arguments:
- bot (bot.bot.Bot):
- The bot instance that we're running with.
-
- guild (discord.Guild):
- The guild instance from the bot's cache
- to synchronize roles with.
-
- Returns:
- Tuple[int, int, int]:
- A tuple with three integers representing how many roles were created
- (element `0`) , how many roles were updated (element `1`), and how many
- roles were deleted (element `2`) on the API.
- """
- roles = await bot.api_client.get('bot/roles')
-
- # Pack API roles and guild roles into one common format,
- # which is also hashable. We need hashability to be able
- # to compare these easily later using sets.
- api_roles = {Role(**role_dict) for role_dict in roles}
- guild_roles = {
- Role(
- id=role.id, name=role.name,
- colour=role.colour.value, permissions=role.permissions.value,
- position=role.position,
- )
- for role in guild.roles
- }
- roles_to_create, roles_to_update, roles_to_delete = get_roles_for_sync(guild_roles, api_roles)
-
- for role in roles_to_create:
- await bot.api_client.post(
- 'bot/roles',
- json={
- 'id': role.id,
- 'name': role.name,
- 'colour': role.colour,
- 'permissions': role.permissions,
- 'position': role.position,
- }
- )
+_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
+_User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild'))
+_Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
- for role in roles_to_update:
- await bot.api_client.put(
- f'bot/roles/{role.id}',
- json={
- 'id': role.id,
- 'name': role.name,
- 'colour': role.colour,
- 'permissions': role.permissions,
- 'position': role.position,
- }
- )
- for role in roles_to_delete:
- await bot.api_client.delete(f'bot/roles/{role.id}')
-
- return len(roles_to_create), len(roles_to_update), len(roles_to_delete)
-
-
-def get_users_for_sync(
- guild_users: Dict[int, User], api_users: Dict[int, User]
-) -> Tuple[Set[User], Set[User]]:
- """
- Determine which users should be created or updated on the website.
-
- Arguments:
- guild_users (Dict[int, User]):
- A mapping of user IDs to user data, populated from the
- guild cached on the running bot instance.
-
- api_users (Dict[int, User]):
- A mapping of user IDs to user data, populated from the API's
- current inventory of all users.
-
- Returns:
- Tuple[Set[User], Set[User]]:
- Two user sets as a tuple. The first element represents users
- to be created on the website, these are users that are present
- in the cached guild data but not in the API at all, going by
- their ID. The second element represents users to update. It is
- populated by users which are present on both the API and the
- guild, but where the attribute of a user on the API is not
- equal to the attribute of the user on the guild.
- """
- users_to_create = set()
- users_to_update = set()
-
- for api_user in api_users.values():
- guild_user = guild_users.get(api_user.id)
- if guild_user is not None:
- if api_user != guild_user:
- users_to_update.add(guild_user)
-
- elif api_user.in_guild:
- # The user is known on the API but not the guild, and the
- # API currently specifies that the user is a member of the guild.
- # This means that the user has left since the last sync.
- # Update the `in_guild` attribute of the user on the site
- # to signify that the user left.
- new_api_user = api_user._replace(in_guild=False)
- users_to_update.add(new_api_user)
-
- new_user_ids = set(guild_users.keys()) - set(api_users.keys())
- for user_id in new_user_ids:
- # The user is known on the guild but not on the API. This means
- # that the user has joined since the last sync. Create it.
- new_user = guild_users[user_id]
- users_to_create.add(new_user)
-
- return users_to_create, users_to_update
-
-
-async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]:
- """
- Synchronize users found in the given `guild` with the ones in the API.
-
- Arguments:
- bot (bot.bot.Bot):
- The bot instance that we're running with.
-
- guild (discord.Guild):
- The guild instance from the bot's cache
- to synchronize roles with.
-
- Returns:
- Tuple[int, int, None]:
- A tuple with two integers, representing how many users were created
- (element `0`) and how many users were updated (element `1`), and `None`
- to indicate that a user sync never deletes entries from the API.
- """
- current_users = await bot.api_client.get('bot/users')
-
- # Pack API users and guild users into one common format,
- # which is also hashable. We need hashability to be able
- # to compare these easily later using sets.
- api_users = {
- user_dict['id']: User(
- roles=tuple(sorted(user_dict.pop('roles'))),
- **user_dict
- )
- for user_dict in current_users
- }
- guild_users = {
- member.id: User(
- id=member.id, name=member.name,
- discriminator=int(member.discriminator), avatar_hash=member.avatar,
- roles=tuple(sorted(role.id for role in member.roles)), in_guild=True
- )
- for member in guild.members
- }
-
- users_to_create, users_to_update = get_users_for_sync(guild_users, api_users)
-
- for user in users_to_create:
- await bot.api_client.post(
- 'bot/users',
- json={
- 'avatar_hash': user.avatar_hash,
- 'discriminator': user.discriminator,
- 'id': user.id,
- 'in_guild': user.in_guild,
- 'name': user.name,
- 'roles': list(user.roles)
- }
+class Syncer(abc.ABC):
+ """Base class for synchronising the database with objects in the Discord cache."""
+
+ _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developer}> "
+ _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark)
+
+ def __init__(self, bot: Bot) -> None:
+ self.bot = bot
+
+ @property
+ @abc.abstractmethod
+ def name(self) -> str:
+ """The name of the syncer; used in output messages and logging."""
+ raise NotImplementedError # pragma: no cover
+
+ async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]:
+ """
+ Send a prompt to confirm or abort a sync using reactions and return the sent message.
+
+ If a message is given, it is edited to display the prompt and reactions. Otherwise, a new
+ message is sent to the dev-core channel and mentions the core developers role. If the
+ channel cannot be retrieved, return None.
+ """
+ log.trace(f"Sending {self.name} sync confirmation prompt.")
+
+ msg_content = (
+ f'Possible cache issue while syncing {self.name}s. '
+ f'More than {constants.Sync.max_diff} {self.name}s were changed. '
+ f'React to confirm or abort the sync.'
)
- for user in users_to_update:
- await bot.api_client.put(
- f'bot/users/{user.id}',
- json={
- 'avatar_hash': user.avatar_hash,
- 'discriminator': user.discriminator,
- 'id': user.id,
- 'in_guild': user.in_guild,
- 'name': user.name,
- 'roles': list(user.roles)
- }
+ # Send to core developers if it's an automatic sync.
+ if not message:
+ log.trace("Message not provided for confirmation; creating a new one in dev-core.")
+ channel = self.bot.get_channel(constants.Channels.devcore)
+
+ if not channel:
+ log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.")
+ try:
+ channel = await self.bot.fetch_channel(constants.Channels.devcore)
+ except HTTPException:
+ log.exception(
+ f"Failed to fetch channel for sending sync confirmation prompt; "
+ f"aborting {self.name} sync."
+ )
+ return None
+
+ message = await channel.send(f"{self._CORE_DEV_MENTION}{msg_content}")
+ else:
+ await message.edit(content=msg_content)
+
+ # Add the initial reactions.
+ log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.")
+ for emoji in self._REACTION_EMOJIS:
+ await message.add_reaction(emoji)
+
+ return message
+
+ def _reaction_check(
+ self,
+ author: Member,
+ message: Message,
+ reaction: Reaction,
+ user: t.Union[Member, User]
+ ) -> bool:
+ """
+ Return True if the `reaction` is a valid confirmation or abort reaction on `message`.
+
+ If the `author` of the prompt is a bot, then a reaction by any core developer will be
+ considered valid. Otherwise, the author of the reaction (`user`) will have to be the
+ `author` of the prompt.
+ """
+ # For automatic syncs, check for the core dev role instead of an exact author
+ has_role = any(constants.Roles.core_developer == role.id for role in user.roles)
+ return (
+ reaction.message.id == message.id
+ and not user.bot
+ and (has_role if author.bot else user == author)
+ and str(reaction.emoji) in self._REACTION_EMOJIS
)
- return len(users_to_create), len(users_to_update), None
+ async def _wait_for_confirmation(self, author: Member, message: Message) -> bool:
+ """
+ Wait for a confirmation reaction by `author` on `message` and return True if confirmed.
+
+ Uses the `_reaction_check` function to determine if a reaction is valid.
+
+ If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False.
+ To acknowledge the reaction (or lack thereof), `message` will be edited.
+ """
+ # Preserve the core-dev role mention in the message edits so users aren't confused about
+ # where notifications came from.
+ mention = self._CORE_DEV_MENTION if author.bot else ""
+
+ reaction = None
+ try:
+ log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.")
+ reaction, _ = await self.bot.wait_for(
+ 'reaction_add',
+ check=partial(self._reaction_check, author, message),
+ timeout=constants.Sync.confirm_timeout
+ )
+ except TimeoutError:
+ # reaction will remain none thus sync will be aborted in the finally block below.
+ log.debug(f"The {self.name} syncer confirmation prompt timed out.")
+ finally:
+ if str(reaction) == constants.Emojis.check_mark:
+ log.trace(f"The {self.name} syncer was confirmed.")
+ await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.')
+ return True
+ else:
+ log.warning(f"The {self.name} syncer was aborted or timed out!")
+ await message.edit(
+ content=f':warning: {mention}{self.name} sync aborted or timed out!'
+ )
+ return False
+
+ @abc.abstractmethod
+ async def _get_diff(self, guild: Guild) -> _Diff:
+ """Return the difference between the cache of `guild` and the database."""
+ raise NotImplementedError # pragma: no cover
+
+ @abc.abstractmethod
+ async def _sync(self, diff: _Diff) -> None:
+ """Perform the API calls for synchronisation."""
+ raise NotImplementedError # pragma: no cover
+
+ async def _get_confirmation_result(
+ self,
+ diff_size: int,
+ author: Member,
+ message: t.Optional[Message] = None
+ ) -> t.Tuple[bool, t.Optional[Message]]:
+ """
+ Prompt for confirmation and return a tuple of the result and the prompt message.
+
+ `diff_size` is the size of the diff of the sync. If it is greater than
+ `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the
+ sync and the `message` is an extant message to edit to display the prompt.
+
+ If confirmed or no confirmation was needed, the result is True. The returned message will
+ either be the given `message` or a new one which was created when sending the prompt.
+ """
+ log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.")
+ if diff_size > constants.Sync.max_diff:
+ message = await self._send_prompt(message)
+ if not message:
+ return False, None # Couldn't get channel.
+
+ confirmed = await self._wait_for_confirmation(author, message)
+ if not confirmed:
+ return False, message # Sync aborted.
+
+ return True, message
+
+ async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None:
+ """
+ Synchronise the database with the cache of `guild`.
+
+ If the differences between the cache and the database are greater than
+ `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core
+ channel. The confirmation can be optionally redirect to `ctx` instead.
+ """
+ log.info(f"Starting {self.name} syncer.")
+
+ message = None
+ author = self.bot.user
+ if ctx:
+ message = await ctx.send(f"📊 Synchronising {self.name}s.")
+ author = ctx.author
+
+ diff = await self._get_diff(guild)
+ diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict
+ totals = {k: len(v) for k, v in diff_dict.items() if v is not None}
+ diff_size = sum(totals.values())
+
+ confirmed, message = await self._get_confirmation_result(diff_size, author, message)
+ if not confirmed:
+ return
+
+ # Preserve the core-dev role mention in the message edits so users aren't confused about
+ # where notifications came from.
+ mention = self._CORE_DEV_MENTION if author.bot else ""
+
+ try:
+ await self._sync(diff)
+ except ResponseCodeError as e:
+ log.exception(f"{self.name} syncer failed!")
+
+ # Don't show response text because it's probably some really long HTML.
+ results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```"
+ content = f":x: {mention}Synchronisation of {self.name}s failed: {results}"
+ else:
+ results = ", ".join(f"{name} `{total}`" for name, total in totals.items())
+ log.info(f"{self.name} syncer finished: {results}.")
+ content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}"
+
+ if message:
+ await message.edit(content=content)
+
+
+class RoleSyncer(Syncer):
+ """Synchronise the database with roles in the cache."""
+
+ name = "role"
+
+ async def _get_diff(self, guild: Guild) -> _Diff:
+ """Return the difference of roles between the cache of `guild` and the database."""
+ log.trace("Getting the diff for roles.")
+ roles = await self.bot.api_client.get('bot/roles')
+
+ # Pack DB roles and guild roles into one common, hashable format.
+ # They're hashable so that they're easily comparable with sets later.
+ db_roles = {_Role(**role_dict) for role_dict in roles}
+ guild_roles = {
+ _Role(
+ id=role.id,
+ name=role.name,
+ colour=role.colour.value,
+ permissions=role.permissions.value,
+ position=role.position,
+ )
+ for role in guild.roles
+ }
+
+ guild_role_ids = {role.id for role in guild_roles}
+ api_role_ids = {role.id for role in db_roles}
+ new_role_ids = guild_role_ids - api_role_ids
+ deleted_role_ids = api_role_ids - guild_role_ids
+
+ # New roles are those which are on the cached guild but not on the
+ # DB guild, going by the role ID. We need to send them in for creation.
+ roles_to_create = {role for role in guild_roles if role.id in new_role_ids}
+ roles_to_update = guild_roles - db_roles - roles_to_create
+ roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids}
+
+ return _Diff(roles_to_create, roles_to_update, roles_to_delete)
+
+ async def _sync(self, diff: _Diff) -> None:
+ """Synchronise the database with the role cache of `guild`."""
+ log.trace("Syncing created roles...")
+ for role in diff.created:
+ await self.bot.api_client.post('bot/roles', json=role._asdict())
+
+ log.trace("Syncing updated roles...")
+ for role in diff.updated:
+ await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict())
+
+ log.trace("Syncing deleted roles...")
+ for role in diff.deleted:
+ await self.bot.api_client.delete(f'bot/roles/{role.id}')
+
+
+class UserSyncer(Syncer):
+ """Synchronise the database with users in the cache."""
+
+ name = "user"
+
+ async def _get_diff(self, guild: Guild) -> _Diff:
+ """Return the difference of users between the cache of `guild` and the database."""
+ log.trace("Getting the diff for users.")
+ users = await self.bot.api_client.get('bot/users')
+
+ # Pack DB roles and guild roles into one common, hashable format.
+ # They're hashable so that they're easily comparable with sets later.
+ db_users = {
+ user_dict['id']: _User(
+ roles=tuple(sorted(user_dict.pop('roles'))),
+ **user_dict
+ )
+ for user_dict in users
+ }
+ guild_users = {
+ member.id: _User(
+ id=member.id,
+ name=member.name,
+ discriminator=int(member.discriminator),
+ avatar_hash=member.avatar,
+ roles=tuple(sorted(role.id for role in member.roles)),
+ in_guild=True
+ )
+ for member in guild.members
+ }
+
+ users_to_create = set()
+ users_to_update = set()
+
+ for db_user in db_users.values():
+ guild_user = guild_users.get(db_user.id)
+ if guild_user is not None:
+ if db_user != guild_user:
+ users_to_update.add(guild_user)
+
+ elif db_user.in_guild:
+ # The user is known in the DB but not the guild, and the
+ # DB currently specifies that the user is a member of the guild.
+ # This means that the user has left since the last sync.
+ # Update the `in_guild` attribute of the user on the site
+ # to signify that the user left.
+ new_api_user = db_user._replace(in_guild=False)
+ users_to_update.add(new_api_user)
+
+ new_user_ids = set(guild_users.keys()) - set(db_users.keys())
+ for user_id in new_user_ids:
+ # The user is known on the guild but not on the API. This means
+ # that the user has joined since the last sync. Create it.
+ new_user = guild_users[user_id]
+ users_to_create.add(new_user)
+
+ return _Diff(users_to_create, users_to_update, None)
+
+ async def _sync(self, diff: _Diff) -> None:
+ """Synchronise the database with the user cache of `guild`."""
+ log.trace("Syncing created users...")
+ for user in diff.created:
+ await self.bot.api_client.post('bot/users', json=user._asdict())
+
+ log.trace("Syncing updated users...")
+ for user in diff.updated:
+ await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict())
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index 582237374..e3c396863 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -219,7 +219,7 @@ class Verification(Cog):
@periodic_ping.before_loop
async def before_ping(self) -> None:
"""Only start the loop when the bot is ready."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
def cog_unload(self) -> None:
"""Cancel the periodic ping task when the cog is unloaded."""
diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py
index eb787b083..3667a80e8 100644
--- a/bot/cogs/watchchannels/watchchannel.py
+++ b/bot/cogs/watchchannels/watchchannel.py
@@ -91,7 +91,7 @@ class WatchChannel(metaclass=CogABCMeta):
async def start_watchchannel(self) -> None:
"""Starts the watch channel by getting the channel, webhook, and user cache ready."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
try:
self.channel = await self.bot.fetch_channel(self.destination)
diff --git a/bot/constants.py b/bot/constants.py
index 681d8da49..9bc331dc4 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -263,6 +263,7 @@ class Emojis(metaclass=YAMLGetter):
new: str
pencil: str
cross_mark: str
+ check_mark: str
ducky_yellow: int
ducky_blurple: int
@@ -366,6 +367,7 @@ class Channels(metaclass=YAMLGetter):
checkpoint_test: int
defcon: int
devcontrib: int
+ devcore: int
devlog: int
devtest: int
esoteric: int
@@ -405,6 +407,7 @@ class Webhooks(metaclass=YAMLGetter):
big_brother: int
reddit: int
duck_pond: int
+ dev_log: int
class Roles(metaclass=YAMLGetter):
@@ -538,6 +541,13 @@ class RedirectOutput(metaclass=YAMLGetter):
delete_delay: int
+class Sync(metaclass=YAMLGetter):
+ section = 'sync'
+
+ confirm_timeout: int
+ max_diff: int
+
+
class Event(Enum):
"""
Event names. This does not include every event (for example, raw
diff --git a/config-default.yml b/config-default.yml
index 379475907..f70fe3c34 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -35,6 +35,7 @@ style:
pencil: "\u270F"
new: "\U0001F195"
cross_mark: "\u274C"
+ check_mark: "\u2705"
ducky_yellow: &DUCKY_YELLOW 574951975574175744
ducky_blurple: &DUCKY_BLURPLE 574951975310065675
@@ -123,6 +124,7 @@ guild:
checkpoint_test: 422077681434099723
defcon: &DEFCON 464469101889454091
devcontrib: &DEV_CONTRIB 635950537262759947
+ devcore: 411200599653351425
devlog: &DEVLOG 622895325144940554
devtest: &DEVTEST 414574275865870337
esoteric: 470884583684964352
@@ -180,6 +182,7 @@ guild:
big_brother: 569133704568373283
reddit: 635408384794951680
duck_pond: 637821475327311927
+ dev_log: 680501655111729222
filter:
@@ -432,6 +435,10 @@ redirect_output:
delete_invocation: true
delete_delay: 15
+sync:
+ confirm_timeout: 300
+ max_diff: 10
+
duck_pond:
threshold: 5
custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA]
diff --git a/tests/base.py b/tests/base.py
index 21a57716a..21613110e 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -1,5 +1,11 @@
import logging
from contextlib import contextmanager
+from typing import Dict
+
+import discord
+from discord.ext import commands
+
+from tests import helpers
class _CaptureLogHandler(logging.Handler):
@@ -69,3 +75,31 @@ class LoggingTestsMixin:
standard_message = self._truncateMessage(base_message, record_message)
msg = self._formatMessage(msg, standard_message)
self.fail(msg)
+
+
+class CommandTestCase(unittest.TestCase):
+ """TestCase with additional assertions that are useful for testing Discord commands."""
+
+ @helpers.async_test
+ async def assertHasPermissionsCheck(
+ self,
+ cmd: commands.Command,
+ permissions: Dict[str, bool],
+ ) -> None:
+ """
+ Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`.
+
+ Every permission in `permissions` is expected to be reported as missing. In other words, do
+ not include permissions which should not raise an exception along with those which should.
+ """
+ # Invert permission values because it's more intuitive to pass to this assertion the same
+ # permissions as those given to the check decorator.
+ permissions = {k: not v for k, v in permissions.items()}
+
+ ctx = helpers.MockContext()
+ ctx.channel.permissions_for.return_value = discord.Permissions(**permissions)
+
+ with self.assertRaises(commands.MissingPermissions) as cm:
+ await cmd.can_run(ctx)
+
+ self.assertCountEqual(permissions.keys(), cm.exception.missing_perms)
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
new file mode 100644
index 000000000..e6a6f9688
--- /dev/null
+++ b/tests/bot/cogs/sync/test_base.py
@@ -0,0 +1,412 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs.sync.syncers import Syncer, _Diff
+from tests import helpers
+
+
+class TestSyncer(Syncer):
+ """Syncer subclass with mocks for abstract methods for testing purposes."""
+
+ name = "test"
+ _get_diff = helpers.AsyncMock()
+ _sync = helpers.AsyncMock()
+
+
+class SyncerBaseTests(unittest.TestCase):
+ """Tests for the syncer base class."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ def test_instantiation_fails_without_abstract_methods(self):
+ """The class must have abstract methods implemented."""
+ with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"):
+ Syncer(self.bot)
+
+
+class SyncerSendPromptTests(unittest.TestCase):
+ """Tests for sending the sync confirmation prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+
+ def mock_get_channel(self):
+ """Fixture to return a mock channel and message for when `get_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ mock_channel.send.return_value = mock_message
+ self.bot.get_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ def mock_fetch_channel(self):
+ """Fixture to return a mock channel and message for when `fetch_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ self.bot.get_channel.return_value = None
+ mock_channel.send.return_value = mock_message
+ self.bot.fetch_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ @helpers.async_test
+ async def test_send_prompt_edits_and_returns_message(self):
+ """The given message should be edited to display the prompt and then should be returned."""
+ msg = helpers.MockMessage()
+ ret_val = await self.syncer._send_prompt(msg)
+
+ msg.edit.assert_called_once()
+ self.assertIn("content", msg.edit.call_args[1])
+ self.assertEqual(ret_val, msg)
+
+ @helpers.async_test
+ async def test_send_prompt_gets_dev_core_channel(self):
+ """The dev-core channel should be retrieved if an extant message isn't given."""
+ subtests = (
+ (self.bot.get_channel, self.mock_get_channel),
+ (self.bot.fetch_channel, self.mock_fetch_channel),
+ )
+
+ for method, mock_ in subtests:
+ with self.subTest(method=method, msg=mock_.__name__):
+ mock_()
+ await self.syncer._send_prompt()
+
+ method.assert_called_once_with(constants.Channels.devcore)
+
+ @helpers.async_test
+ async def test_send_prompt_returns_None_if_channel_fetch_fails(self):
+ """None should be returned if there's an HTTPException when fetching the channel."""
+ self.bot.get_channel.return_value = None
+ self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
+
+ ret_val = await self.syncer._send_prompt()
+
+ self.assertIsNone(ret_val)
+
+ @helpers.async_test
+ async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
+ """A new message mentioning core devs should be sent and returned if message isn't given."""
+ for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
+ with self.subTest(msg=mock_.__name__):
+ mock_channel, mock_message = mock_()
+ ret_val = await self.syncer._send_prompt()
+
+ mock_channel.send.assert_called_once()
+ self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
+ self.assertEqual(ret_val, mock_message)
+
+ @helpers.async_test
+ async def test_send_prompt_adds_reactions(self):
+ """The message should have reactions for confirmation added."""
+ extant_message = helpers.MockMessage()
+ subtests = (
+ (extant_message, lambda: (None, extant_message)),
+ (None, self.mock_get_channel),
+ (None, self.mock_fetch_channel),
+ )
+
+ for message_arg, mock_ in subtests:
+ subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__
+
+ with self.subTest(msg=subtest_msg):
+ _, mock_message = mock_()
+ await self.syncer._send_prompt(message_arg)
+
+ calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS]
+ mock_message.add_reaction.assert_has_calls(calls)
+
+
+class SyncerConfirmationTests(unittest.TestCase):
+ """Tests for waiting for a sync confirmation reaction on the prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+ self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer)
+
+ @staticmethod
+ def get_message_reaction(emoji):
+ """Fixture to return a mock message an reaction from the given `emoji`."""
+ message = helpers.MockMessage()
+ reaction = helpers.MockReaction(emoji=emoji, message=message)
+
+ return message, reaction
+
+ def test_reaction_check_for_valid_emoji_and_authors(self):
+ """Should return True if authors are identical or are a bot and a core dev, respectively."""
+ user_subtests = (
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMember(id=77),
+ "identical users",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "bot author and core-dev reactor",
+ ),
+ )
+
+ for emoji in self.syncer._REACTION_EMOJIS:
+ for author, user, msg in user_subtests:
+ with self.subTest(author=author, user=user, emoji=emoji, msg=msg):
+ message, reaction = self.get_message_reaction(emoji)
+ ret_val = self.syncer._reaction_check(author, message, reaction, user)
+
+ self.assertTrue(ret_val)
+
+ def test_reaction_check_for_invalid_reactions(self):
+ """Should return False for invalid reaction events."""
+ valid_emoji = self.syncer._REACTION_EMOJIS[0]
+ subtests = (
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "users are not identical",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43),
+ "reactor lacks the core-dev role",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ "reactor is a bot",
+ ),
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMessage(id=95),
+ helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)),
+ helpers.MockMember(id=77),
+ "messages are not identical",
+ ),
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction("InVaLiD"),
+ helpers.MockMember(id=77),
+ "emoji is invalid",
+ ),
+ )
+
+ for *args, msg in subtests:
+ kwargs = dict(zip(("author", "message", "reaction", "user"), args))
+ with self.subTest(**kwargs, msg=msg):
+ ret_val = self.syncer._reaction_check(*args)
+ self.assertFalse(ret_val)
+
+ @helpers.async_test
+ async def test_wait_for_confirmation(self):
+ """The message should always be edited and only return True if the emoji is a check mark."""
+ subtests = (
+ (constants.Emojis.check_mark, True, None),
+ ("InVaLiD", False, None),
+ (None, False, TimeoutError),
+ )
+
+ for emoji, ret_val, side_effect in subtests:
+ for bot in (True, False):
+ with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot):
+ # Set up mocks
+ message = helpers.MockMessage()
+ member = helpers.MockMember(bot=bot)
+
+ self.bot.wait_for.reset_mock()
+ self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None)
+ self.bot.wait_for.side_effect = side_effect
+
+ # Call the function
+ actual_return = await self.syncer._wait_for_confirmation(member, message)
+
+ # Perform assertions
+ self.bot.wait_for.assert_called_once()
+ self.assertIn("reaction_add", self.bot.wait_for.call_args[0])
+
+ message.edit.assert_called_once()
+ kwargs = message.edit.call_args[1]
+ self.assertIn("content", kwargs)
+
+ # Core devs should only be mentioned if the author is a bot.
+ if bot:
+ self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+ else:
+ self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+
+ self.assertIs(actual_return, ret_val)
+
+
+class SyncerSyncTests(unittest.TestCase):
+ """Tests for main function orchestrating the sync."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
+ self.syncer = TestSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_respects_confirmation_result(self):
+ """The sync should abort if confirmation fails and continue if confirmed."""
+ mock_message = helpers.MockMessage()
+ subtests = (
+ (True, mock_message),
+ (False, None),
+ )
+
+ for confirmed, message in subtests:
+ with self.subTest(confirmed=confirmed):
+ self.syncer._sync.reset_mock()
+ self.syncer._get_diff.reset_mock()
+
+ diff = _Diff({1, 2, 3}, {4, 5}, None)
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = helpers.AsyncMock(
+ return_value=(confirmed, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+
+ if confirmed:
+ self.syncer._sync.assert_called_once_with(diff)
+ else:
+ self.syncer._sync.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_diff_size(self):
+ """The diff size should be correctly calculated."""
+ subtests = (
+ (6, _Diff({1, 2}, {3, 4}, {5, 6})),
+ (5, _Diff({1, 2, 3}, None, {4, 5})),
+ (0, _Diff(None, None, None)),
+ (0, _Diff(set(), set(), set())),
+ )
+
+ for size, diff in subtests:
+ with self.subTest(size=size, diff=diff):
+ self.syncer._get_diff.reset_mock()
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
+
+ @helpers.async_test
+ async def test_sync_message_edited(self):
+ """The message should be edited if one was sent, even if the sync has an API error."""
+ subtests = (
+ (None, None, False),
+ (helpers.MockMessage(), None, True),
+ (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True),
+ )
+
+ for message, side_effect, should_edit in subtests:
+ with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
+ self.syncer._sync.side_effect = side_effect
+ self.syncer._get_confirmation_result = helpers.AsyncMock(
+ return_value=(True, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ if should_edit:
+ message.edit.assert_called_once()
+ self.assertIn("content", message.edit.call_args[1])
+
+ @helpers.async_test
+ async def test_sync_confirmation_context_redirect(self):
+ """If ctx is given, a new message should be sent and author should be ctx's author."""
+ mock_member = helpers.MockMember()
+ subtests = (
+ (None, self.bot.user, None),
+ (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()),
+ )
+
+ for ctx, author, message in subtests:
+ with self.subTest(ctx=ctx, author=author, message=message):
+ if ctx is not None:
+ ctx.send.return_value = message
+
+ self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild, ctx)
+
+ if ctx is not None:
+ ctx.send.assert_called_once()
+
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author)
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ @helpers.async_test
+ async def test_confirmation_result_small_diff(self):
+ """Should always return True and the given message if the diff size is too small."""
+ author = helpers.MockMember()
+ expected_message = helpers.MockMessage()
+
+ for size in (3, 2):
+ with self.subTest(size=size):
+ self.syncer._send_prompt = helpers.AsyncMock()
+ self.syncer._wait_for_confirmation = helpers.AsyncMock()
+
+ coro = self.syncer._get_confirmation_result(size, author, expected_message)
+ result, actual_message = await coro
+
+ self.assertTrue(result)
+ self.assertEqual(actual_message, expected_message)
+ self.syncer._send_prompt.assert_not_called()
+ self.syncer._wait_for_confirmation.assert_not_called()
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ @helpers.async_test
+ async def test_confirmation_result_large_diff(self):
+ """Should return True if confirmed and False if _send_prompt fails or aborted."""
+ author = helpers.MockMember()
+ mock_message = helpers.MockMessage()
+
+ subtests = (
+ (True, mock_message, True, "confirmed"),
+ (False, None, False, "_send_prompt failed"),
+ (False, mock_message, False, "aborted"),
+ )
+
+ for expected_result, expected_message, confirmed, msg in subtests:
+ with self.subTest(msg=msg):
+ self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message)
+ self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed)
+
+ coro = self.syncer._get_confirmation_result(4, author)
+ actual_result, actual_message = await coro
+
+ self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None
+ self.assertIs(actual_result, expected_result)
+ self.assertEqual(actual_message, expected_message)
+
+ if expected_message:
+ self.syncer._wait_for_confirmation.assert_called_once_with(
+ author, expected_message
+ )
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
new file mode 100644
index 000000000..98c9afc0d
--- /dev/null
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -0,0 +1,395 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs import sync
+from bot.cogs.sync.syncers import Syncer
+from tests import helpers
+from tests.base import CommandTestCase
+
+
+class MockSyncer(helpers.CustomMockMixin, mock.MagicMock):
+ """
+ A MagicMock subclass to mock Syncer objects.
+
+ Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer`
+ instances. For more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=Syncer, **kwargs)
+
+
+class SyncExtensionTests(unittest.TestCase):
+ """Tests for the sync extension."""
+
+ @staticmethod
+ def test_extension_setup():
+ """The Sync cog should be added."""
+ bot = helpers.MockBot()
+ sync.setup(bot)
+ bot.add_cog.assert_called_once()
+
+
+class SyncCogTestCase(unittest.TestCase):
+ """Base class for Sync cog tests. Sets up patches for syncers."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ # These patch the type. When the type is called, a MockSyncer instanced is returned.
+ # MockSyncer is needed so that our custom AsyncMock is used.
+ # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.
+ self.role_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.RoleSyncer",
+ new=mock.MagicMock(return_value=MockSyncer())
+ )
+ self.user_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.UserSyncer",
+ new=mock.MagicMock(return_value=MockSyncer())
+ )
+ self.RoleSyncer = self.role_syncer_patcher.start()
+ self.UserSyncer = self.user_syncer_patcher.start()
+
+ self.cog = sync.Sync(self.bot)
+
+ def tearDown(self):
+ self.role_syncer_patcher.stop()
+ self.user_syncer_patcher.stop()
+
+ @staticmethod
+ def response_error(status: int) -> ResponseCodeError:
+ """Fixture to return a ResponseCodeError with the given status code."""
+ response = mock.MagicMock()
+ response.status = status
+
+ return ResponseCodeError(response)
+
+
+class SyncCogTests(SyncCogTestCase):
+ """Tests for the Sync cog."""
+
+ @mock.patch.object(sync.Sync, "sync_guild")
+ def test_sync_cog_init(self, sync_guild):
+ """Should instantiate syncers and run a sync for the guild."""
+ # Reset because a Sync cog was already instantiated in setUp.
+ self.RoleSyncer.reset_mock()
+ self.UserSyncer.reset_mock()
+ self.bot.loop.create_task.reset_mock()
+
+ mock_sync_guild_coro = mock.MagicMock()
+ sync_guild.return_value = mock_sync_guild_coro
+
+ sync.Sync(self.bot)
+
+ self.RoleSyncer.assert_called_once_with(self.bot)
+ self.UserSyncer.assert_called_once_with(self.bot)
+ sync_guild.assert_called_once_with()
+ self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
+
+ @helpers.async_test
+ async def test_sync_cog_sync_guild(self):
+ """Roles and users should be synced only if a guild is successfully retrieved."""
+ for guild in (helpers.MockGuild(), None):
+ with self.subTest(guild=guild):
+ self.bot.reset_mock()
+ self.cog.role_syncer.reset_mock()
+ self.cog.user_syncer.reset_mock()
+
+ self.bot.get_guild = mock.MagicMock(return_value=guild)
+
+ await self.cog.sync_guild()
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(constants.Guild.id)
+
+ if guild is None:
+ self.cog.role_syncer.sync.assert_not_called()
+ self.cog.user_syncer.sync.assert_not_called()
+ else:
+ self.cog.role_syncer.sync.assert_called_once_with(guild)
+ self.cog.user_syncer.sync.assert_called_once_with(guild)
+
+ async def patch_user_helper(self, side_effect: BaseException) -> None:
+ """Helper to set a side effect for bot.api_client.patch and then assert it is called."""
+ self.bot.api_client.patch.reset_mock(side_effect=True)
+ self.bot.api_client.patch.side_effect = side_effect
+
+ user_id, updated_information = 5, {"key": 123}
+ await self.cog.patch_user(user_id, updated_information)
+
+ self.bot.api_client.patch.assert_called_once_with(
+ f"bot/users/{user_id}",
+ json=updated_information,
+ )
+
+ @helpers.async_test
+ async def test_sync_cog_patch_user(self):
+ """A PATCH request should be sent and 404 errors ignored."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ await self.patch_user_helper(side_effect)
+
+ @helpers.async_test
+ async def test_sync_cog_patch_user_non_404(self):
+ """A PATCH request should be sent and the error raised if it's not a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.patch_user_helper(self.response_error(500))
+
+
+class SyncCogListenerTests(SyncCogTestCase):
+ """Tests for the listeners of the Sync cog."""
+
+ def setUp(self):
+ super().setUp()
+ self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user)
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_create(self):
+ """A POST request should be sent with the new role's data."""
+ self.assertTrue(self.cog.on_guild_role_create.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ role = helpers.MockRole(**role_data)
+ await self.cog.on_guild_role_create(role)
+
+ self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data)
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_delete(self):
+ """A DELETE request should be sent."""
+ self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__)
+
+ role = helpers.MockRole(id=99)
+ await self.cog.on_guild_role_delete(role)
+
+ self.bot.api_client.delete.assert_called_once_with("bot/roles/99")
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_update(self):
+ """A PUT request should be sent if the colour, name, permissions, or position changes."""
+ self.assertTrue(self.cog.on_guild_role_update.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ subtests = (
+ (True, ("colour", "name", "permissions", "position")),
+ (False, ("hoist", "mentionable")),
+ )
+
+ for should_put, attributes in subtests:
+ for attribute in attributes:
+ with self.subTest(should_put=should_put, changed_attribute=attribute):
+ self.bot.api_client.put.reset_mock()
+
+ after_role_data = role_data.copy()
+ after_role_data[attribute] = 876
+
+ before_role = helpers.MockRole(**role_data)
+ after_role = helpers.MockRole(**after_role_data)
+
+ await self.cog.on_guild_role_update(before_role, after_role)
+
+ if should_put:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/roles/{after_role.id}",
+ json=after_role_data
+ )
+ else:
+ self.bot.api_client.put.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_remove(self):
+ """Member should patched to set in_guild as False."""
+ self.assertTrue(self.cog.on_member_remove.__cog_listener__)
+
+ member = helpers.MockMember()
+ await self.cog.on_member_remove(member)
+
+ self.cog.patch_user.assert_called_once_with(
+ member.id,
+ updated_information={"in_guild": False}
+ )
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_update_roles(self):
+ """Members should be patched if their roles have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ # Roles are intentionally unsorted.
+ before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)]
+ before_member = helpers.MockMember(roles=before_roles)
+ after_member = helpers.MockMember(roles=before_roles[1:])
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ data = {"roles": sorted(role.id for role in after_member.roles)}
+ self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data)
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_update_other(self):
+ """Members should not be patched if other attributes have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ subtests = (
+ ("activities", discord.Game("Pong"), discord.Game("Frogger")),
+ ("nick", "old nick", "new nick"),
+ ("status", discord.Status.online, discord.Status.offline),
+ )
+
+ for attribute, old_value, new_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ before_member = helpers.MockMember(**{attribute: old_value})
+ after_member = helpers.MockMember(**{attribute: new_value})
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ self.cog.patch_user.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_user_update(self):
+ """A user should be patched only if the name, discriminator, or avatar changes."""
+ self.assertTrue(self.cog.on_user_update.__cog_listener__)
+
+ before_data = {
+ "name": "old name",
+ "discriminator": "1234",
+ "avatar": "old avatar",
+ "bot": False,
+ }
+
+ subtests = (
+ (True, "name", "name", "new name", "new name"),
+ (True, "discriminator", "discriminator", "8765", 8765),
+ (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),
+ (False, "bot", "bot", True, True),
+ )
+
+ for should_patch, attribute, api_field, value, api_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ after_data = before_data.copy()
+ after_data[attribute] = value
+ before_user = helpers.MockUser(**before_data)
+ after_user = helpers.MockUser(**after_data)
+
+ await self.cog.on_user_update(before_user, after_user)
+
+ if should_patch:
+ self.cog.patch_user.assert_called_once()
+
+ # Don't care if *all* keys are present; only the changed one is required
+ call_args = self.cog.patch_user.call_args
+ self.assertEqual(call_args[0][0], after_user.id)
+ self.assertIn("updated_information", call_args[1])
+
+ updated_information = call_args[1]["updated_information"]
+ self.assertIn(api_field, updated_information)
+ self.assertEqual(updated_information[api_field], api_value)
+ else:
+ self.cog.patch_user.assert_not_called()
+
+ async def on_member_join_helper(self, side_effect: Exception) -> dict:
+ """
+ Helper to set `side_effect` for on_member_join and assert a PUT request was sent.
+
+ The request data for the mock member is returned. All exceptions will be re-raised.
+ """
+ member = helpers.MockMember(
+ discriminator="1234",
+ roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)],
+ )
+
+ data = {
+ "avatar_hash": member.avatar,
+ "discriminator": int(member.discriminator),
+ "id": member.id,
+ "in_guild": True,
+ "name": member.name,
+ "roles": sorted(role.id for role in member.roles)
+ }
+
+ self.bot.api_client.put.reset_mock(side_effect=True)
+ self.bot.api_client.put.side_effect = side_effect
+
+ try:
+ await self.cog.on_member_join(member)
+ except Exception:
+ raise
+ finally:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/users/{member.id}",
+ json=data
+ )
+
+ return data
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_join(self):
+ """Should PUT user's data or POST it if the user doesn't exist."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ self.bot.api_client.post.reset_mock()
+ data = await self.on_member_join_helper(side_effect)
+
+ if side_effect:
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=data)
+ else:
+ self.bot.api_client.post.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_join_non_404(self):
+ """ResponseCodeError should be re-raised if status code isn't a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.on_member_join_helper(self.response_error(500))
+
+ self.bot.api_client.post.assert_not_called()
+
+
+class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
+ """Tests for the commands in the Sync cog."""
+
+ @helpers.async_test
+ async def test_sync_roles_command(self):
+ """sync() should be called on the RoleSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_roles_command.callback(self.cog, ctx)
+
+ self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ @helpers.async_test
+ async def test_sync_users_command(self):
+ """sync() should be called on the UserSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_users_command.callback(self.cog, ctx)
+
+ self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ def test_commands_require_admin(self):
+ """The sync commands should only run if the author has the administrator permission."""
+ cmds = (
+ self.cog.sync_group,
+ self.cog.sync_roles_command,
+ self.cog.sync_users_command,
+ )
+
+ for cmd in cmds:
+ with self.subTest(cmd=cmd):
+ self.assertHasPermissionsCheck(cmd, {"administrator": True})
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
index 27ae27639..14fb2577a 100644
--- a/tests/bot/cogs/sync/test_roles.py
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -1,126 +1,165 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import Role, get_roles_for_sync
-
-
-class GetRolesForSyncTests(unittest.TestCase):
- """Tests constructing the roles to synchronize with the site."""
-
- def test_get_roles_for_sync_empty_return_for_equal_roles(self):
- """No roles should be synced when no diff is found."""
- api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), set(), set())
- )
-
- def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self):
- """Roles to be synced are returned when non-ID attributes differ."""
- api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), guild_roles, set())
- )
-
- def test_get_roles_only_returns_roles_that_require_update(self):
- """Roles that require an update should be returned as the second tuple element."""
- api_roles = {
- Role(id=41, name='old name', colour=33, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
- guild_roles = {
- Role(id=41, name='new name', colour=35, permissions=0x8, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_new_roles_in_first_tuple_element(self):
- """Newly created roles are returned as the first tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=2)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=2)},
- set(),
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_update_and_new_roles(self):
- """Newly created and updated roles should be returned together."""
- api_roles = {
- Role(id=41, name='old name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='new name', colour=40, permissions=0x16, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=3)},
- {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_delete(self):
- """Roles to be deleted should be returned as the third tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- set(),
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
-
- def test_get_roles_returns_roles_to_delete_update_and_new_roles(self):
- """When roles were added, updated, and removed, all of them are returned properly."""
- api_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- Role(id=71, name='to update', colour=99, permissions=0x9, position=3),
- }
- guild_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=81, name='to create', colour=99, permissions=0x9, position=4),
- Role(id=71, name='updated', colour=101, permissions=0x5, position=3),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)},
- {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)},
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
+import discord
+
+from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role
+from tests import helpers
+
+
+def fake_role(**kwargs):
+ """Fixture to return a dictionary representing a role with default values set."""
+ kwargs.setdefault("id", 9)
+ kwargs.setdefault("name", "fake role")
+ kwargs.setdefault("colour", 7)
+ kwargs.setdefault("permissions", 0)
+ kwargs.setdefault("position", 55)
+
+ return kwargs
+
+
+class RoleSyncerDiffTests(unittest.TestCase):
+ """Tests for determining differences between roles in the DB and roles in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*roles):
+ """Fixture to return a guild object with the given roles."""
+ guild = helpers.MockGuild()
+ guild.roles = []
+
+ for role in roles:
+ mock_role = helpers.MockRole(**role)
+ mock_role.colour = discord.Colour(role["colour"])
+ mock_role.permissions = discord.Permissions(role["permissions"])
+ guild.roles.append(mock_role)
+
+ return guild
+
+ @helpers.async_test
+ async def test_empty_diff_for_identical_roles(self):
+ """No differences should be found if the roles in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_updated_roles(self):
+ """Only updated roles should be added to the 'updated' set of the diff."""
+ updated_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]
+ guild = self.get_guild(updated_role, fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_Role(**updated_role)}, set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_roles(self):
+ """Only new roles should be added to the 'created' set of the diff."""
+ new_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role(), new_role)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new_role)}, set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_deleted_roles(self):
+ """Only deleted roles should be added to the 'deleted' set of the diff."""
+ deleted_role = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [fake_role(), deleted_role]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), {_Role(**deleted_role)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_updated_and_deleted_roles(self):
+ """When roles are added, updated, and removed, all of them are returned properly."""
+ new = fake_role(id=41, name="new")
+ updated = fake_role(id=71, name="updated")
+ deleted = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [
+ fake_role(),
+ fake_role(id=71, name="updated name"),
+ deleted,
+ ]
+ guild = self.get_guild(fake_role(), new, updated)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class RoleSyncerSyncTests(unittest.TestCase):
+ """Tests for the API requests that sync roles."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_created_roles(self):
+ """Only POST requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(role_tuples, set(), set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/roles", json=role) for role in roles]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(roles))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_updated_roles(self):
+ """Only PUT requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), role_tuples, set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_deleted_roles(self):
+ """Only DELETE requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), set(), role_tuples)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]
+ self.bot.api_client.delete.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.delete.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.put.assert_not_called()
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index ccaf67490..421bf6bb6 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -1,84 +1,169 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import User, get_users_for_sync
+from bot.cogs.sync.syncers import UserSyncer, _Diff, _User
+from tests import helpers
def fake_user(**kwargs):
- kwargs.setdefault('id', 43)
- kwargs.setdefault('name', 'bob the test man')
- kwargs.setdefault('discriminator', 1337)
- kwargs.setdefault('avatar_hash', None)
- kwargs.setdefault('roles', (666,))
- kwargs.setdefault('in_guild', True)
- return User(**kwargs)
-
-
-class GetUsersForSyncTests(unittest.TestCase):
- """Tests constructing the users to synchronize with the site."""
-
- def test_get_users_for_sync_returns_nothing_for_empty_params(self):
- """When no users are given, none are returned."""
- self.assertEqual(
- get_users_for_sync({}, {}),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_nothing_for_equal_users(self):
- """When no users are updated, none are returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self):
- """When a non-ID-field differs, the user to update is returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(name='new fancy name')}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(name='new fancy name')})
- )
-
- def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self):
- """When new users join the guild, they are returned as the first tuple element."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(), 63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, set())
- )
-
- def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self):
+ """Fixture to return a dictionary representing a user with default values set."""
+ kwargs.setdefault("id", 43)
+ kwargs.setdefault("name", "bob the test man")
+ kwargs.setdefault("discriminator", 1337)
+ kwargs.setdefault("avatar_hash", None)
+ kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("in_guild", True)
+
+ return kwargs
+
+
+class UserSyncerDiffTests(unittest.TestCase):
+ """Tests for determining differences between users in the DB and users in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*members):
+ """Fixture to return a guild object with the given members."""
+ guild = helpers.MockGuild()
+ guild.members = []
+
+ for member in members:
+ member = member.copy()
+ member["avatar"] = member.pop("avatar_hash")
+ del member["in_guild"]
+
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+
+ guild.members.append(mock_member)
+
+ return guild
+
+ @helpers.async_test
+ async def test_empty_diff_for_no_users(self):
+ """When no users are given, an empty diff should be returned."""
+ guild = self.get_guild()
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_empty_diff_for_identical_users(self):
+ """No differences should be found if the users in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_updated_users(self):
+ """Only updated users should be added to the 'updated' set of the diff."""
+ updated_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ guild = self.get_guild(updated_user, fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**updated_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_users(self):
+ """Only new users should be added to the 'created' set of the diff."""
+ new_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user(), new_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- api_users = {43: fake_user(), 63: fake_user(id=63)}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(id=63, in_guild=False)})
- )
-
- def test_get_users_for_sync_updates_and_creates_users_as_needed(self):
- """When one user left and another one was updated, both are returned."""
- api_users = {43: fake_user()}
- guild_users = {63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, {fake_user(in_guild=False)})
- )
-
- def test_get_users_for_sync_does_not_duplicate_update_users(self):
- """When the API knows a user the guild doesn't, nothing is performed."""
- api_users = {43: fake_user(in_guild=False)}
- guild_users = {}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_updated_and_leaving_users(self):
+ """When users are added, updated, and removed, all of them are returned properly."""
+ new_user = fake_user(id=99, name="new")
+ updated_user = fake_user(id=55, name="updated")
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ guild = self.get_guild(fake_user(), new_user, updated_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_empty_diff_for_db_users_not_in_guild(self):
+ """When the DB knows a user the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class UserSyncerSyncTests(unittest.TestCase):
+ """Tests for the API requests that sync users."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_created_users(self):
+ """Only POST requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(user_tuples, set(), None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/users", json=user) for user in users]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(users))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_updated_users(self):
+ """Only PUT requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(set(), user_tuples, None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(users))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index 7370b8471..7e6bfc748 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
asyncio.run(self.cog.fetch_webhook())
- self.bot.wait_until_ready.assert_called_once()
+ self.bot.wait_until_guild_available.assert_called_once()
self.bot.fetch_webhook.assert_called_once_with(1)
self.assertEqual(self.cog.webhook, "dummy webhook")
@@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
asyncio.run(self.cog.fetch_webhook())
- self.bot.wait_until_ready.assert_called_once()
+ self.bot.wait_until_guild_available.assert_called_once()
self.bot.fetch_webhook.assert_called_once_with(1)
self.assertEqual(len(log_watcher.records), 1)
diff --git a/tests/helpers.py b/tests/helpers.py
index 506fe9894..7ae7ed621 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -9,6 +9,7 @@ from typing import Iterable, Optional
import discord
from discord.ext.commands import Context
+from bot.api import APIClient
from bot.bot import Bot
@@ -182,9 +183,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
spec_set = role_instance
def __init__(self, **kwargs) -> None:
- default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1}
+ default_kwargs = {
+ 'id': next(self.discord_id),
+ 'name': 'role',
+ 'position': 1,
+ 'colour': discord.Colour(0xdeadbf),
+ 'permissions': discord.Permissions(),
+ }
super().__init__(**collections.ChainMap(kwargs, default_kwargs))
+ if isinstance(self.colour, int):
+ self.colour = discord.Colour(self.colour)
+
+ if isinstance(self.permissions, int):
+ self.permissions = discord.Permissions(self.permissions)
+
if 'mention' not in kwargs:
self.mention = f'&{self.name}'
@@ -241,6 +254,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
self.mention = f"@{self.name}"
+class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock APIClient objects.
+
+ Instances of this class will follow the specifications of `bot.api.APIClient` instances.
+ For more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=APIClient, **kwargs)
+
+
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
bot_instance.http_session = None
@@ -259,11 +284,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
-
- # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
- # and should therefore be awaited. (The documentation calls it a coroutine as well, which
- # is technically incorrect, since it's a regular def.)
- # self.wait_for = unittest.mock.AsyncMock()
+ self.api_client = MockAPIClient()
# Since calling `create_task` on our MockBot does not actually schedule the coroutine object
# as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
@@ -429,6 +450,8 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
user_iterator.__aiter__.return_value = _users
self.users.return_value = user_iterator
+ self.__str__.return_value = str(self.emoji)
+
webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock())