diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/api.py | 45 | ||||
| -rw-r--r-- | bot/bot.py | 54 | ||||
| -rw-r--r-- | bot/cogs/antispam.py | 2 | ||||
| -rw-r--r-- | bot/cogs/defcon.py | 2 | ||||
| -rw-r--r-- | bot/cogs/doc.py | 2 | ||||
| -rw-r--r-- | bot/cogs/duck_pond.py | 2 | ||||
| -rw-r--r-- | bot/cogs/logging.py | 2 | ||||
| -rw-r--r-- | bot/cogs/moderation/scheduler.py | 2 | ||||
| -rw-r--r-- | bot/cogs/moderation/superstarify.py | 3 | ||||
| -rw-r--r-- | bot/cogs/off_topic_names.py | 2 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 4 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 54 | ||||
| -rw-r--r-- | bot/cogs/sync/cog.py | 100 | ||||
| -rw-r--r-- | bot/cogs/sync/syncers.py | 551 | ||||
| -rw-r--r-- | bot/cogs/verification.py | 2 | ||||
| -rw-r--r-- | bot/cogs/watchchannels/watchchannel.py | 2 | ||||
| -rw-r--r-- | bot/constants.py | 13 | ||||
| -rw-r--r-- | config-default.yml | 11 | ||||
| -rw-r--r-- | tests/base.py | 34 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_base.py | 412 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 395 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_roles.py | 287 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_users.py | 241 | ||||
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 4 | ||||
| -rw-r--r-- | tests/helpers.py | 29 | 
25 files changed, 1704 insertions, 551 deletions
| diff --git a/bot/api.py b/bot/api.py index fb126b384..e59916114 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,6 +32,11 @@ class ResponseCodeError(ValueError):  class APIClient:      """Django Site API wrapper.""" +    # These are class attributes so they can be seen when being mocked for tests. +    # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details. +    session: Optional[aiohttp.ClientSession] = None +    loop: asyncio.AbstractEventLoop = None +      def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):          auth_headers = {              'Authorization': f"Token {Keys.site_api}" @@ -42,7 +47,7 @@ class APIClient:          else:              kwargs['headers'] = auth_headers -        self.session: Optional[aiohttp.ClientSession] = None +        self.session = None          self.loop = loop          self._ready = asyncio.Event(loop=loop) @@ -85,43 +90,35 @@ class APIClient:                  response_text = await response.text()                  raise ResponseCodeError(response=response, response_text=response_text) -    async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: -        """Site API GET.""" +    async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: +        """Send an HTTP request to the site API and return the JSON response."""          await self._ready.wait() -        async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: +        async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json() -    async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: -        """Site API PATCH.""" -        await self._ready.wait() +    async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: +        """Site API GET.""" +        return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) -        async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: -            await self.maybe_raise_for_status(resp, raise_for_status) -            return await resp.json() +    async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: +        """Site API PATCH.""" +        return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) -    async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: +    async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:          """Site API POST.""" -        await self._ready.wait() - -        async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: -            await self.maybe_raise_for_status(resp, raise_for_status) -            return await resp.json() +        return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) -    async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: +    async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:          """Site API PUT.""" -        await self._ready.wait() - -        async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: -            await self.maybe_raise_for_status(resp, raise_for_status) -            return await resp.json() +        return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) -    async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: +    async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]:          """Site API DELETE."""          await self._ready.wait() -        async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: +        async with self.session.delete(self._url_for(endpoint), **kwargs) as resp:              if resp.status == 204:                  return None diff --git a/bot/bot.py b/bot/bot.py index cecee7b68..19b9035c4 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,11 +1,14 @@ +import asyncio  import logging  import socket  from typing import Optional  import aiohttp +import discord  from discord.ext import commands  from bot import api +from bot import constants  log = logging.getLogger('bot') @@ -17,15 +20,17 @@ class Bot(commands.Bot):          # Use asyncio for DNS resolution instead of threads so threads aren't spammed.          # Use AF_INET as its socket family to prevent HTTPS related problems both locally          # and in production. -        self.connector = aiohttp.TCPConnector( +        self._connector = aiohttp.TCPConnector(              resolver=aiohttp.AsyncResolver(),              family=socket.AF_INET,          ) -        super().__init__(*args, connector=self.connector, **kwargs) +        super().__init__(*args, connector=self._connector, **kwargs) + +        self._guild_available = asyncio.Event()          self.http_session: Optional[aiohttp.ClientSession] = None -        self.api_client = api.APIClient(loop=self.loop, connector=self.connector) +        self.api_client = api.APIClient(loop=self.loop, connector=self._connector)      def add_cog(self, cog: commands.Cog) -> None:          """Adds a "cog" to the bot and logs the operation.""" @@ -46,6 +51,47 @@ class Bot(commands.Bot):      async def start(self, *args, **kwargs) -> None:          """Open an aiohttp session before logging in and connecting to Discord.""" -        self.http_session = aiohttp.ClientSession(connector=self.connector) +        self.http_session = aiohttp.ClientSession(connector=self._connector)          await super().start(*args, **kwargs) + +    async def on_guild_available(self, guild: discord.Guild) -> None: +        """ +        Set the internal guild available event when constants.Guild.id becomes available. + +        If the cache appears to still be empty (no members, no channels, or no roles), the event +        will not be set. +        """ +        if guild.id != constants.Guild.id: +            return + +        if not guild.roles or not guild.members or not guild.channels: +            msg = "Guild available event was dispatched but the cache appears to still be empty!" +            log.warning(msg) + +            try: +                webhook = await self.fetch_webhook(constants.Webhooks.dev_log) +            except discord.HTTPException as e: +                log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") +            else: +                await webhook.send(f"<@&{constants.Roles.admin}> {msg}") + +            return + +        self._guild_available.set() + +    async def on_guild_unavailable(self, guild: discord.Guild) -> None: +        """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" +        if guild.id != constants.Guild.id: +            return + +        self._guild_available.clear() + +    async def wait_until_guild_available(self) -> None: +        """ +        Wait until the constants.Guild.id guild is available (and the cache is ready). + +        The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE +        gateway event before giving up and thus not populating the cache for unavailable guilds. +        """ +        await self._guild_available.wait() diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index f67ef6f05..baa6b9459 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -123,7 +123,7 @@ class AntiSpam(Cog):      async def alert_on_validation_error(self) -> None:          """Unloads the cog and alerts admins if configuration validation failed.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          if self.validation_errors:              body = "**The following errors were encountered:**\n"              body += "\n".join(f"- {error}" for error in self.validation_errors.values()) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index a0d8fedd5..20961e0a2 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -59,7 +59,7 @@ class Defcon(Cog):      async def sync_settings(self) -> None:          """On cog load, try to synchronize DEFCON settings to the API.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          self.channel = await self.bot.fetch_channel(Channels.defcon)          try: diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 6e7c00b6a..204cffb37 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -157,7 +157,7 @@ class Doc(commands.Cog):      async def init_refresh_inventory(self) -> None:          """Refresh documentation inventory on cog initialization.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          await self.refresh_inventory()      async def update_single( diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 345d2856c..1f84a0609 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -22,7 +22,7 @@ class DuckPond(Cog):      async def fetch_webhook(self) -> None:          """Fetches the webhook object, so we can post to it.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          try:              self.webhook = await self.bot.fetch_webhook(self.webhook_id) diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index d1b7dcab3..dbd76672f 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -20,7 +20,7 @@ class Logging(Cog):      async def startup_greeting(self) -> None:          """Announce our presence to the configured devlog channel.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          log.info("Bot connected!")          embed = Embed(description="Connected!") diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index c0de0e4da..3c5185468 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -38,7 +38,7 @@ class InfractionScheduler(Scheduler):      async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None:          """Schedule expiration for previous infractions.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 050c847ac..c41874a95 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -109,7 +109,8 @@ class Superstarify(InfractionScheduler, Cog):          ctx: Context,          member: Member,          duration: Expiry, -        reason: str = None +        *, +        reason: str = None,      ) -> None:          """          Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index bf777ea5a..81511f99d 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -88,7 +88,7 @@ class OffTopicNames(Cog):      async def init_offtopic_updater(self) -> None:          """Start off-topic channel updating event loop if it hasn't already started.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          if self.updater_task is None:              coro = update_names(self.bot)              self.updater_task = self.bot.loop.create_task(coro) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index aa487f18e..4f6584aba 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -48,7 +48,7 @@ class Reddit(Cog):      async def init_reddit_ready(self) -> None:          """Sets the reddit webhook when the cog is loaded.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          if not self.webhook:              self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) @@ -208,7 +208,7 @@ class Reddit(Cog):          await asyncio.sleep(seconds_until) -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          if not self.webhook:              await self.bot.fetch_webhook(Webhooks.reddit) diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index f3e516158..041791056 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -7,11 +7,12 @@ from datetime import datetime, timedelta  from operator import itemgetter  import discord +from dateutil.parser import isoparse  from dateutil.relativedelta import relativedelta  from discord.ext.commands import Cog, Context, group  from bot.bot import Bot -from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES +from bot.constants import Guild, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES  from bot.converters import Duration  from bot.pagination import LinePaginator  from bot.utils.checks import without_role_check @@ -20,7 +21,7 @@ from bot.utils.time import humanize_delta, wait_until  log = logging.getLogger(__name__) -WHITELISTED_CHANNELS = (Channels.bot,) +WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5 @@ -35,7 +36,7 @@ class Reminders(Scheduler, Cog):      async def reschedule_reminders(self) -> None:          """Get all current reminders from the API and reschedule them.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          response = await self.bot.api_client.get(              'bot/reminders',              params={'active': 'true'} @@ -49,13 +50,12 @@ class Reminders(Scheduler, Cog):              if not is_valid:                  continue -            remind_at = datetime.fromisoformat(reminder['expiration'][:-1]) +            remind_at = isoparse(reminder['expiration']).replace(tzinfo=None)              # If the reminder is already overdue ...              if remind_at < now:                  late = relativedelta(now, remind_at)                  await self.send_reminder(reminder, late) -              else:                  self.schedule_task(loop, reminder["id"], reminder) @@ -79,18 +79,31 @@ class Reminders(Scheduler, Cog):          return is_valid, user, channel      @staticmethod -    async def _send_confirmation(ctx: Context, on_success: str) -> None: +    async def _send_confirmation( +        ctx: Context, +        on_success: str, +        reminder_id: str, +        delivery_dt: t.Optional[datetime], +    ) -> None:          """Send an embed confirming the reminder change was made successfully."""          embed = discord.Embed()          embed.colour = discord.Colour.green()          embed.title = random.choice(POSITIVE_REPLIES)          embed.description = on_success + +        footer_str = f"ID: {reminder_id}" +        if delivery_dt: +            # Reminder deletion will have a `None` `delivery_dt` +            footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + +        embed.set_footer(text=footer_str) +          await ctx.send(embed=embed)      async def _scheduled_task(self, reminder: dict) -> None:          """A coroutine which sends the reminder once the time is reached, and cancels the running task."""          reminder_id = reminder["id"] -        reminder_datetime = datetime.fromisoformat(reminder['expiration'][:-1]) +        reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None)          # Send the reminder message once the desired duration has passed          await wait_until(reminder_datetime) @@ -203,11 +216,14 @@ class Reminders(Scheduler, Cog):          )          now = datetime.utcnow() - timedelta(seconds=1) +        humanized_delta = humanize_delta(relativedelta(expiration, now))          # Confirm to the user that it worked.          await self._send_confirmation(              ctx, -            on_success=f"Your reminder will arrive in {humanize_delta(relativedelta(expiration, now))}!" +            on_success=f"Your reminder will arrive in {humanized_delta}!", +            reminder_id=reminder["id"], +            delivery_dt=expiration,          )          loop = asyncio.get_event_loop() @@ -237,7 +253,7 @@ class Reminders(Scheduler, Cog):          for content, remind_at, id_ in reminders:              # Parse and humanize the time, make it pretty :D -            remind_datetime = datetime.fromisoformat(remind_at[:-1]) +            remind_datetime = isoparse(remind_at).replace(tzinfo=None)              time = humanize_delta(relativedelta(remind_datetime, now))              text = textwrap.dedent(f""" @@ -286,7 +302,10 @@ class Reminders(Scheduler, Cog):          # Send a confirmation message to the channel          await self._send_confirmation( -            ctx, on_success="That reminder has been edited successfully!" +            ctx, +            on_success="That reminder has been edited successfully!", +            reminder_id=id_, +            delivery_dt=expiration,          )          await self._reschedule_reminder(reminder) @@ -300,18 +319,27 @@ class Reminders(Scheduler, Cog):              json={'content': content}          ) +        # Parse the reminder expiration back into a datetime for the confirmation message +        expiration = isoparse(reminder['expiration']).replace(tzinfo=None) +          # Send a confirmation message to the channel          await self._send_confirmation( -            ctx, on_success="That reminder has been edited successfully!" +            ctx, +            on_success="That reminder has been edited successfully!", +            reminder_id=id_, +            delivery_dt=expiration,          )          await self._reschedule_reminder(reminder) -    @remind_group.command("delete", aliases=("remove",)) +    @remind_group.command("delete", aliases=("remove", "cancel"))      async def delete_reminder(self, ctx: Context, id_: int) -> None:          """Delete one of your active reminders."""          await self._delete_reminder(id_)          await self._send_confirmation( -            ctx, on_success="That reminder has been deleted successfully!" +            ctx, +            on_success="That reminder has been deleted successfully!", +            reminder_id=id_, +            delivery_dt=None,          ) diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 4e6ed156b..5708be3f4 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -1,7 +1,7 @@  import logging -from typing import Callable, Dict, Iterable, Union +from typing import Any, Dict -from discord import Guild, Member, Role, User +from discord import Member, Role, User  from discord.ext import commands  from discord.ext.commands import Cog, Context @@ -16,45 +16,28 @@ log = logging.getLogger(__name__)  class Sync(Cog):      """Captures relevant events and sends them to the site.""" -    # The server to synchronize events on. -    # Note that setting this wrongly will result in things getting deleted -    # that possibly shouldn't be. -    SYNC_SERVER_ID = constants.Guild.id - -    # An iterable of callables that are called when the bot is ready. -    ON_READY_SYNCERS: Iterable[Callable[[Bot, Guild], None]] = ( -        syncers.sync_roles, -        syncers.sync_users -    ) -      def __init__(self, bot: Bot) -> None:          self.bot = bot +        self.role_syncer = syncers.RoleSyncer(self.bot) +        self.user_syncer = syncers.UserSyncer(self.bot)          self.bot.loop.create_task(self.sync_guild())      async def sync_guild(self) -> None:          """Syncs the roles/users of the guild with the database.""" -        await self.bot.wait_until_ready() -        guild = self.bot.get_guild(self.SYNC_SERVER_ID) -        if guild is not None: -            for syncer in self.ON_READY_SYNCERS: -                syncer_name = syncer.__name__[5:]  # drop off `sync_` -                log.info("Starting `%s` syncer.", syncer_name) -                total_created, total_updated, total_deleted = await syncer(self.bot, guild) -                if total_deleted is None: -                    log.info( -                        f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`." -                    ) -                else: -                    log.info( -                        f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`, " -                        f"deleted `{total_deleted}`." -                    ) - -    async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None: +        await self.bot.wait_until_guild_available() + +        guild = self.bot.get_guild(constants.Guild.id) +        if guild is None: +            return + +        for syncer in (self.role_syncer, self.user_syncer): +            await syncer.sync(guild) + +    async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None:          """Send a PATCH request to partially update a user in the database."""          try: -            await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information) +            await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information)          except ResponseCodeError as e:              if e.response.status != 404:                  raise @@ -82,12 +65,14 @@ class Sync(Cog):      @Cog.listener()      async def on_guild_role_update(self, before: Role, after: Role) -> None:          """Syncs role with the database if any of the stored attributes were updated.""" -        if ( -                before.name != after.name -                or before.colour != after.colour -                or before.permissions != after.permissions -                or before.position != after.position -        ): +        was_updated = ( +            before.name != after.name +            or before.colour != after.colour +            or before.permissions != after.permissions +            or before.position != after.position +        ) + +        if was_updated:              await self.bot.api_client.put(                  f'bot/roles/{after.id}',                  json={ @@ -137,18 +122,8 @@ class Sync(Cog):      @Cog.listener()      async def on_member_remove(self, member: Member) -> None: -        """Updates the user information when a member leaves the guild.""" -        await self.bot.api_client.put( -            f'bot/users/{member.id}', -            json={ -                'avatar_hash': member.avatar, -                'discriminator': int(member.discriminator), -                'id': member.id, -                'in_guild': False, -                'name': member.name, -                'roles': sorted(role.id for role in member.roles) -            } -        ) +        """Set the in_guild field to False when a member leaves the guild.""" +        await self.patch_user(member.id, updated_information={"in_guild": False})      @Cog.listener()      async def on_member_update(self, before: Member, after: Member) -> None: @@ -160,7 +135,8 @@ class Sync(Cog):      @Cog.listener()      async def on_user_update(self, before: User, after: User) -> None:          """Update the user information in the database if a relevant change is detected.""" -        if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")): +        attrs = ("name", "discriminator", "avatar") +        if any(getattr(before, attr) != getattr(after, attr) for attr in attrs):              updated_information = {                  "name": after.name,                  "discriminator": int(after.discriminator), @@ -176,25 +152,11 @@ class Sync(Cog):      @sync_group.command(name='roles')      @commands.has_permissions(administrator=True)      async def sync_roles_command(self, ctx: Context) -> None: -        """Manually synchronize the guild's roles with the roles on the site.""" -        initial_response = await ctx.send("📊 Synchronizing roles.") -        total_created, total_updated, total_deleted = await syncers.sync_roles(self.bot, ctx.guild) -        await initial_response.edit( -            content=( -                f"👌 Role synchronization complete, created **{total_created}** " -                f", updated **{total_created}** roles, and deleted **{total_deleted}** roles." -            ) -        ) +        """Manually synchronise the guild's roles with the roles on the site.""" +        await self.role_syncer.sync(ctx.guild, ctx)      @sync_group.command(name='users')      @commands.has_permissions(administrator=True)      async def sync_users_command(self, ctx: Context) -> None: -        """Manually synchronize the guild's users with the users on the site.""" -        initial_response = await ctx.send("📊 Synchronizing users.") -        total_created, total_updated, total_deleted = await syncers.sync_users(self.bot, ctx.guild) -        await initial_response.edit( -            content=( -                f"👌 User synchronization complete, created **{total_created}** " -                f"and updated **{total_created}** users." -            ) -        ) +        """Manually synchronise the guild's users with the users on the site.""" +        await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 14cf51383..6715ad6fb 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -1,235 +1,342 @@ +import abc +import logging +import typing as t  from collections import namedtuple -from typing import Dict, Set, Tuple +from functools import partial -from discord import Guild +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context +from bot import constants +from bot.api import ResponseCodeError  from bot.bot import Bot +log = logging.getLogger(__name__) +  # These objects are declared as namedtuples because tuples are hashable,  # something that we make use of when diffing site roles against guild roles. -Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild')) - - -def get_roles_for_sync( -        guild_roles: Set[Role], api_roles: Set[Role] -) -> Tuple[Set[Role], Set[Role], Set[Role]]: -    """ -    Determine which roles should be created or updated on the site. - -    Arguments: -        guild_roles (Set[Role]): -            Roles that were found on the guild at startup. - -        api_roles (Set[Role]): -            Roles that were retrieved from the API at startup. - -    Returns: -        Tuple[Set[Role], Set[Role]. Set[Role]]: -            A tuple with three elements. The first element represents -            roles to be created on the site, meaning that they were -            present on the cached guild but not on the API. The second -            element represents roles to be updated, meaning they were -            present on both the cached guild and the API but non-ID -            fields have changed inbetween. The third represents roles -            to be deleted on the site, meaning the roles are present on -            the API but not in the cached guild. -    """ -    guild_role_ids = {role.id for role in guild_roles} -    api_role_ids = {role.id for role in api_roles} -    new_role_ids = guild_role_ids - api_role_ids -    deleted_role_ids = api_role_ids - guild_role_ids - -    # New roles are those which are on the cached guild but not on the -    # API guild, going by the role ID. We need to send them in for creation. -    roles_to_create = {role for role in guild_roles if role.id in new_role_ids} -    roles_to_update = guild_roles - api_roles - roles_to_create -    roles_to_delete = {role for role in api_roles if role.id in deleted_role_ids} -    return roles_to_create, roles_to_update, roles_to_delete - - -async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]: -    """ -    Synchronize roles found on the given `guild` with the ones on the API. - -    Arguments: -        bot (bot.bot.Bot): -            The bot instance that we're running with. - -        guild (discord.Guild): -            The guild instance from the bot's cache -            to synchronize roles with. - -    Returns: -        Tuple[int, int, int]: -            A tuple with three integers representing how many roles were created -            (element `0`) , how many roles were updated (element `1`), and how many -            roles were deleted (element `2`) on the API. -    """ -    roles = await bot.api_client.get('bot/roles') - -    # Pack API roles and guild roles into one common format, -    # which is also hashable. We need hashability to be able -    # to compare these easily later using sets. -    api_roles = {Role(**role_dict) for role_dict in roles} -    guild_roles = { -        Role( -            id=role.id, name=role.name, -            colour=role.colour.value, permissions=role.permissions.value, -            position=role.position, -        ) -        for role in guild.roles -    } -    roles_to_create, roles_to_update, roles_to_delete = get_roles_for_sync(guild_roles, api_roles) - -    for role in roles_to_create: -        await bot.api_client.post( -            'bot/roles', -            json={ -                'id': role.id, -                'name': role.name, -                'colour': role.colour, -                'permissions': role.permissions, -                'position': role.position, -            } -        ) +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) -    for role in roles_to_update: -        await bot.api_client.put( -            f'bot/roles/{role.id}', -            json={ -                'id': role.id, -                'name': role.name, -                'colour': role.colour, -                'permissions': role.permissions, -                'position': role.position, -            } -        ) -    for role in roles_to_delete: -        await bot.api_client.delete(f'bot/roles/{role.id}') - -    return len(roles_to_create), len(roles_to_update), len(roles_to_delete) - - -def get_users_for_sync( -        guild_users: Dict[int, User], api_users: Dict[int, User] -) -> Tuple[Set[User], Set[User]]: -    """ -    Determine which users should be created or updated on the website. - -    Arguments: -        guild_users (Dict[int, User]): -            A mapping of user IDs to user data, populated from the -            guild cached on the running bot instance. - -        api_users (Dict[int, User]): -            A mapping of user IDs to user data, populated from the API's -            current inventory of all users. - -    Returns: -        Tuple[Set[User], Set[User]]: -            Two user sets as a tuple. The first element represents users -            to be created on the website, these are users that are present -            in the cached guild data but not in the API at all, going by -            their ID. The second element represents users to update. It is -            populated by users which are present on both the API and the -            guild, but where the attribute of a user on the API is not -            equal to the attribute of the user on the guild. -    """ -    users_to_create = set() -    users_to_update = set() - -    for api_user in api_users.values(): -        guild_user = guild_users.get(api_user.id) -        if guild_user is not None: -            if api_user != guild_user: -                users_to_update.add(guild_user) - -        elif api_user.in_guild: -            # The user is known on the API but not the guild, and the -            # API currently specifies that the user is a member of the guild. -            # This means that the user has left since the last sync. -            # Update the `in_guild` attribute of the user on the site -            # to signify that the user left. -            new_api_user = api_user._replace(in_guild=False) -            users_to_update.add(new_api_user) - -    new_user_ids = set(guild_users.keys()) - set(api_users.keys()) -    for user_id in new_user_ids: -        # The user is known on the guild but not on the API. This means -        # that the user has joined since the last sync. Create it. -        new_user = guild_users[user_id] -        users_to_create.add(new_user) - -    return users_to_create, users_to_update - - -async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]: -    """ -    Synchronize users found in the given `guild` with the ones in the API. - -    Arguments: -        bot (bot.bot.Bot): -            The bot instance that we're running with. - -        guild (discord.Guild): -            The guild instance from the bot's cache -            to synchronize roles with. - -    Returns: -        Tuple[int, int, None]: -            A tuple with two integers, representing how many users were created -            (element `0`) and how many users were updated (element `1`), and `None` -            to indicate that a user sync never deletes entries from the API. -    """ -    current_users = await bot.api_client.get('bot/users') - -    # Pack API users and guild users into one common format, -    # which is also hashable. We need hashability to be able -    # to compare these easily later using sets. -    api_users = { -        user_dict['id']: User( -            roles=tuple(sorted(user_dict.pop('roles'))), -            **user_dict -        ) -        for user_dict in current_users -    } -    guild_users = { -        member.id: User( -            id=member.id, name=member.name, -            discriminator=int(member.discriminator), avatar_hash=member.avatar, -            roles=tuple(sorted(role.id for role in member.roles)), in_guild=True -        ) -        for member in guild.members -    } - -    users_to_create, users_to_update = get_users_for_sync(guild_users, api_users) - -    for user in users_to_create: -        await bot.api_client.post( -            'bot/users', -            json={ -                'avatar_hash': user.avatar_hash, -                'discriminator': user.discriminator, -                'id': user.id, -                'in_guild': user.in_guild, -                'name': user.name, -                'roles': list(user.roles) -            } +class Syncer(abc.ABC): +    """Base class for synchronising the database with objects in the Discord cache.""" + +    _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developer}> " +    _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + +    def __init__(self, bot: Bot) -> None: +        self.bot = bot + +    @property +    @abc.abstractmethod +    def name(self) -> str: +        """The name of the syncer; used in output messages and logging.""" +        raise NotImplementedError  # pragma: no cover + +    async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: +        """ +        Send a prompt to confirm or abort a sync using reactions and return the sent message. + +        If a message is given, it is edited to display the prompt and reactions. Otherwise, a new +        message is sent to the dev-core channel and mentions the core developers role. If the +        channel cannot be retrieved, return None. +        """ +        log.trace(f"Sending {self.name} sync confirmation prompt.") + +        msg_content = ( +            f'Possible cache issue while syncing {self.name}s. ' +            f'More than {constants.Sync.max_diff} {self.name}s were changed. ' +            f'React to confirm or abort the sync.'          ) -    for user in users_to_update: -        await bot.api_client.put( -            f'bot/users/{user.id}', -            json={ -                'avatar_hash': user.avatar_hash, -                'discriminator': user.discriminator, -                'id': user.id, -                'in_guild': user.in_guild, -                'name': user.name, -                'roles': list(user.roles) -            } +        # Send to core developers if it's an automatic sync. +        if not message: +            log.trace("Message not provided for confirmation; creating a new one in dev-core.") +            channel = self.bot.get_channel(constants.Channels.devcore) + +            if not channel: +                log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") +                try: +                    channel = await self.bot.fetch_channel(constants.Channels.devcore) +                except HTTPException: +                    log.exception( +                        f"Failed to fetch channel for sending sync confirmation prompt; " +                        f"aborting {self.name} sync." +                    ) +                    return None + +            message = await channel.send(f"{self._CORE_DEV_MENTION}{msg_content}") +        else: +            await message.edit(content=msg_content) + +        # Add the initial reactions. +        log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") +        for emoji in self._REACTION_EMOJIS: +            await message.add_reaction(emoji) + +        return message + +    def _reaction_check( +        self, +        author: Member, +        message: Message, +        reaction: Reaction, +        user: t.Union[Member, User] +    ) -> bool: +        """ +        Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + +        If the `author` of the prompt is a bot, then a reaction by any core developer will be +        considered valid. Otherwise, the author of the reaction (`user`) will have to be the +        `author` of the prompt. +        """ +        # For automatic syncs, check for the core dev role instead of an exact author +        has_role = any(constants.Roles.core_developer == role.id for role in user.roles) +        return ( +            reaction.message.id == message.id +            and not user.bot +            and (has_role if author.bot else user == author) +            and str(reaction.emoji) in self._REACTION_EMOJIS          ) -    return len(users_to_create), len(users_to_update), None +    async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: +        """ +        Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + +        Uses the `_reaction_check` function to determine if a reaction is valid. + +        If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. +        To acknowledge the reaction (or lack thereof), `message` will be edited. +        """ +        # Preserve the core-dev role mention in the message edits so users aren't confused about +        # where notifications came from. +        mention = self._CORE_DEV_MENTION if author.bot else "" + +        reaction = None +        try: +            log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") +            reaction, _ = await self.bot.wait_for( +                'reaction_add', +                check=partial(self._reaction_check, author, message), +                timeout=constants.Sync.confirm_timeout +            ) +        except TimeoutError: +            # reaction will remain none thus sync will be aborted in the finally block below. +            log.debug(f"The {self.name} syncer confirmation prompt timed out.") +        finally: +            if str(reaction) == constants.Emojis.check_mark: +                log.trace(f"The {self.name} syncer was confirmed.") +                await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') +                return True +            else: +                log.warning(f"The {self.name} syncer was aborted or timed out!") +                await message.edit( +                    content=f':warning: {mention}{self.name} sync aborted or timed out!' +                ) +                return False + +    @abc.abstractmethod +    async def _get_diff(self, guild: Guild) -> _Diff: +        """Return the difference between the cache of `guild` and the database.""" +        raise NotImplementedError  # pragma: no cover + +    @abc.abstractmethod +    async def _sync(self, diff: _Diff) -> None: +        """Perform the API calls for synchronisation.""" +        raise NotImplementedError  # pragma: no cover + +    async def _get_confirmation_result( +        self, +        diff_size: int, +        author: Member, +        message: t.Optional[Message] = None +    ) -> t.Tuple[bool, t.Optional[Message]]: +        """ +        Prompt for confirmation and return a tuple of the result and the prompt message. + +        `diff_size` is the size of the diff of the sync. If it is greater than +        `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the +        sync and the `message` is an extant message to edit to display the prompt. + +        If confirmed or no confirmation was needed, the result is True. The returned message will +        either be the given `message` or a new one which was created when sending the prompt. +        """ +        log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") +        if diff_size > constants.Sync.max_diff: +            message = await self._send_prompt(message) +            if not message: +                return False, None  # Couldn't get channel. + +            confirmed = await self._wait_for_confirmation(author, message) +            if not confirmed: +                return False, message  # Sync aborted. + +        return True, message + +    async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: +        """ +        Synchronise the database with the cache of `guild`. + +        If the differences between the cache and the database are greater than +        `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core +        channel. The confirmation can be optionally redirect to `ctx` instead. +        """ +        log.info(f"Starting {self.name} syncer.") + +        message = None +        author = self.bot.user +        if ctx: +            message = await ctx.send(f"📊 Synchronising {self.name}s.") +            author = ctx.author + +        diff = await self._get_diff(guild) +        diff_dict = diff._asdict()  # Ugly method for transforming the NamedTuple into a dict +        totals = {k: len(v) for k, v in diff_dict.items() if v is not None} +        diff_size = sum(totals.values()) + +        confirmed, message = await self._get_confirmation_result(diff_size, author, message) +        if not confirmed: +            return + +        # Preserve the core-dev role mention in the message edits so users aren't confused about +        # where notifications came from. +        mention = self._CORE_DEV_MENTION if author.bot else "" + +        try: +            await self._sync(diff) +        except ResponseCodeError as e: +            log.exception(f"{self.name} syncer failed!") + +            # Don't show response text because it's probably some really long HTML. +            results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" +            content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" +        else: +            results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) +            log.info(f"{self.name} syncer finished: {results}.") +            content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + +        if message: +            await message.edit(content=content) + + +class RoleSyncer(Syncer): +    """Synchronise the database with roles in the cache.""" + +    name = "role" + +    async def _get_diff(self, guild: Guild) -> _Diff: +        """Return the difference of roles between the cache of `guild` and the database.""" +        log.trace("Getting the diff for roles.") +        roles = await self.bot.api_client.get('bot/roles') + +        # Pack DB roles and guild roles into one common, hashable format. +        # They're hashable so that they're easily comparable with sets later. +        db_roles = {_Role(**role_dict) for role_dict in roles} +        guild_roles = { +            _Role( +                id=role.id, +                name=role.name, +                colour=role.colour.value, +                permissions=role.permissions.value, +                position=role.position, +            ) +            for role in guild.roles +        } + +        guild_role_ids = {role.id for role in guild_roles} +        api_role_ids = {role.id for role in db_roles} +        new_role_ids = guild_role_ids - api_role_ids +        deleted_role_ids = api_role_ids - guild_role_ids + +        # New roles are those which are on the cached guild but not on the +        # DB guild, going by the role ID. We need to send them in for creation. +        roles_to_create = {role for role in guild_roles if role.id in new_role_ids} +        roles_to_update = guild_roles - db_roles - roles_to_create +        roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + +        return _Diff(roles_to_create, roles_to_update, roles_to_delete) + +    async def _sync(self, diff: _Diff) -> None: +        """Synchronise the database with the role cache of `guild`.""" +        log.trace("Syncing created roles...") +        for role in diff.created: +            await self.bot.api_client.post('bot/roles', json=role._asdict()) + +        log.trace("Syncing updated roles...") +        for role in diff.updated: +            await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + +        log.trace("Syncing deleted roles...") +        for role in diff.deleted: +            await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): +    """Synchronise the database with users in the cache.""" + +    name = "user" + +    async def _get_diff(self, guild: Guild) -> _Diff: +        """Return the difference of users between the cache of `guild` and the database.""" +        log.trace("Getting the diff for users.") +        users = await self.bot.api_client.get('bot/users') + +        # Pack DB roles and guild roles into one common, hashable format. +        # They're hashable so that they're easily comparable with sets later. +        db_users = { +            user_dict['id']: _User( +                roles=tuple(sorted(user_dict.pop('roles'))), +                **user_dict +            ) +            for user_dict in users +        } +        guild_users = { +            member.id: _User( +                id=member.id, +                name=member.name, +                discriminator=int(member.discriminator), +                avatar_hash=member.avatar, +                roles=tuple(sorted(role.id for role in member.roles)), +                in_guild=True +            ) +            for member in guild.members +        } + +        users_to_create = set() +        users_to_update = set() + +        for db_user in db_users.values(): +            guild_user = guild_users.get(db_user.id) +            if guild_user is not None: +                if db_user != guild_user: +                    users_to_update.add(guild_user) + +            elif db_user.in_guild: +                # The user is known in the DB but not the guild, and the +                # DB currently specifies that the user is a member of the guild. +                # This means that the user has left since the last sync. +                # Update the `in_guild` attribute of the user on the site +                # to signify that the user left. +                new_api_user = db_user._replace(in_guild=False) +                users_to_update.add(new_api_user) + +        new_user_ids = set(guild_users.keys()) - set(db_users.keys()) +        for user_id in new_user_ids: +            # The user is known on the guild but not on the API. This means +            # that the user has joined since the last sync. Create it. +            new_user = guild_users[user_id] +            users_to_create.add(new_user) + +        return _Diff(users_to_create, users_to_update, None) + +    async def _sync(self, diff: _Diff) -> None: +        """Synchronise the database with the user cache of `guild`.""" +        log.trace("Syncing created users...") +        for user in diff.created: +            await self.bot.api_client.post('bot/users', json=user._asdict()) + +        log.trace("Syncing updated users...") +        for user in diff.updated: +            await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 582237374..e3c396863 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -219,7 +219,7 @@ class Verification(Cog):      @periodic_ping.before_loop      async def before_ping(self) -> None:          """Only start the loop when the bot is ready.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()      def cog_unload(self) -> None:          """Cancel the periodic ping task when the cog is unloaded.""" diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index eb787b083..3667a80e8 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -91,7 +91,7 @@ class WatchChannel(metaclass=CogABCMeta):      async def start_watchchannel(self) -> None:          """Starts the watch channel by getting the channel, webhook, and user cache ready.""" -        await self.bot.wait_until_ready() +        await self.bot.wait_until_guild_available()          try:              self.channel = await self.bot.fetch_channel(self.destination) diff --git a/bot/constants.py b/bot/constants.py index a4c65a1f8..9bc331dc4 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -263,6 +263,7 @@ class Emojis(metaclass=YAMLGetter):      new: str      pencil: str      cross_mark: str +    check_mark: str      ducky_yellow: int      ducky_blurple: int @@ -365,6 +366,8 @@ class Channels(metaclass=YAMLGetter):      bot: int      checkpoint_test: int      defcon: int +    devcontrib: int +    devcore: int      devlog: int      devtest: int      esoteric: int @@ -404,6 +407,7 @@ class Webhooks(metaclass=YAMLGetter):      big_brother: int      reddit: int      duck_pond: int +    dev_log: int  class Roles(metaclass=YAMLGetter): @@ -432,7 +436,7 @@ class Guild(metaclass=YAMLGetter):      id: int      ignored: List[int]      staff_channels: List[int] - +    reminder_whitelist: List[int]  class Keys(metaclass=YAMLGetter):      section = "keys" @@ -537,6 +541,13 @@ class RedirectOutput(metaclass=YAMLGetter):      delete_delay: int +class Sync(metaclass=YAMLGetter): +    section = 'sync' + +    confirm_timeout: int +    max_diff: int + +  class Event(Enum):      """      Event names. This does not include every event (for example, raw diff --git a/config-default.yml b/config-default.yml index ba6ea2742..f70fe3c34 100644 --- a/config-default.yml +++ b/config-default.yml @@ -35,6 +35,7 @@ style:          pencil:     "\u270F"          new:        "\U0001F195"          cross_mark: "\u274C" +        check_mark: "\u2705"          ducky_yellow:   &DUCKY_YELLOW   574951975574175744          ducky_blurple:  &DUCKY_BLURPLE  574951975310065675 @@ -119,9 +120,11 @@ guild:          announcements:                    354619224620138496          attachment_log:    &ATTCH_LOG     649243850006855680          big_brother_logs:  &BBLOGS        468507907357409333 -        bot:                              267659945086812160 +        bot:               &BOT_CMD       267659945086812160          checkpoint_test:                  422077681434099723          defcon:            &DEFCON        464469101889454091 +        devcontrib:        &DEV_CONTRIB   635950537262759947 +        devcore:                          411200599653351425          devlog:            &DEVLOG        622895325144940554          devtest:           &DEVTEST       414574275865870337          esoteric:                         470884583684964352 @@ -156,6 +159,7 @@ guild:      staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON]      ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG, *ADMINS_VOICE, *STAFF_VOICE, *ATTCH_LOG] +    reminder_whitelist: [*BOT_CMD, *DEV_CONTRIB]      roles:          admin:             &ADMIN_ROLE      267628507062992896 @@ -178,6 +182,7 @@ guild:          big_brother:                        569133704568373283          reddit:                             635408384794951680          duck_pond:                          637821475327311927 +        dev_log:                            680501655111729222  filter: @@ -430,6 +435,10 @@ redirect_output:      delete_invocation: true      delete_delay: 15 +sync: +    confirm_timeout: 300 +    max_diff: 10 +  duck_pond:      threshold: 5      custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA] diff --git a/tests/base.py b/tests/base.py index 029a249ed..88693f382 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@  import logging  import unittest  from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers  class _CaptureLogHandler(logging.Handler): @@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase):              standard_message = self._truncateMessage(base_message, record_message)              msg = self._formatMessage(msg, standard_message)              self.fail(msg) + + +class CommandTestCase(unittest.TestCase): +    """TestCase with additional assertions that are useful for testing Discord commands.""" + +    @helpers.async_test +    async def assertHasPermissionsCheck( +        self, +        cmd: commands.Command, +        permissions: Dict[str, bool], +    ) -> None: +        """ +        Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + +        Every permission in `permissions` is expected to be reported as missing. In other words, do +        not include permissions which should not raise an exception along with those which should. +        """ +        # Invert permission values because it's more intuitive to pass to this assertion the same +        # permissions as those given to the check decorator. +        permissions = {k: not v for k, v in permissions.items()} + +        ctx = helpers.MockContext() +        ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + +        with self.assertRaises(commands.MissingPermissions) as cm: +            await cmd.can_run(ctx) + +        self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..e6a6f9688 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,412 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): +    """Syncer subclass with mocks for abstract methods for testing purposes.""" + +    name = "test" +    _get_diff = helpers.AsyncMock() +    _sync = helpers.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): +    """Tests for the syncer base class.""" + +    def setUp(self): +        self.bot = helpers.MockBot() + +    def test_instantiation_fails_without_abstract_methods(self): +        """The class must have abstract methods implemented.""" +        with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): +            Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.TestCase): +    """Tests for sending the sync confirmation prompt.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = TestSyncer(self.bot) + +    def mock_get_channel(self): +        """Fixture to return a mock channel and message for when `get_channel` is used.""" +        self.bot.reset_mock() + +        mock_channel = helpers.MockTextChannel() +        mock_message = helpers.MockMessage() + +        mock_channel.send.return_value = mock_message +        self.bot.get_channel.return_value = mock_channel + +        return mock_channel, mock_message + +    def mock_fetch_channel(self): +        """Fixture to return a mock channel and message for when `fetch_channel` is used.""" +        self.bot.reset_mock() + +        mock_channel = helpers.MockTextChannel() +        mock_message = helpers.MockMessage() + +        self.bot.get_channel.return_value = None +        mock_channel.send.return_value = mock_message +        self.bot.fetch_channel.return_value = mock_channel + +        return mock_channel, mock_message + +    @helpers.async_test +    async def test_send_prompt_edits_and_returns_message(self): +        """The given message should be edited to display the prompt and then should be returned.""" +        msg = helpers.MockMessage() +        ret_val = await self.syncer._send_prompt(msg) + +        msg.edit.assert_called_once() +        self.assertIn("content", msg.edit.call_args[1]) +        self.assertEqual(ret_val, msg) + +    @helpers.async_test +    async def test_send_prompt_gets_dev_core_channel(self): +        """The dev-core channel should be retrieved if an extant message isn't given.""" +        subtests = ( +            (self.bot.get_channel, self.mock_get_channel), +            (self.bot.fetch_channel, self.mock_fetch_channel), +        ) + +        for method, mock_ in subtests: +            with self.subTest(method=method, msg=mock_.__name__): +                mock_() +                await self.syncer._send_prompt() + +                method.assert_called_once_with(constants.Channels.devcore) + +    @helpers.async_test +    async def test_send_prompt_returns_None_if_channel_fetch_fails(self): +        """None should be returned if there's an HTTPException when fetching the channel.""" +        self.bot.get_channel.return_value = None +        self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + +        ret_val = await self.syncer._send_prompt() + +        self.assertIsNone(ret_val) + +    @helpers.async_test +    async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): +        """A new message mentioning core devs should be sent and returned if message isn't given.""" +        for mock_ in (self.mock_get_channel, self.mock_fetch_channel): +            with self.subTest(msg=mock_.__name__): +                mock_channel, mock_message = mock_() +                ret_val = await self.syncer._send_prompt() + +                mock_channel.send.assert_called_once() +                self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) +                self.assertEqual(ret_val, mock_message) + +    @helpers.async_test +    async def test_send_prompt_adds_reactions(self): +        """The message should have reactions for confirmation added.""" +        extant_message = helpers.MockMessage() +        subtests = ( +            (extant_message, lambda: (None, extant_message)), +            (None, self.mock_get_channel), +            (None, self.mock_fetch_channel), +        ) + +        for message_arg, mock_ in subtests: +            subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + +            with self.subTest(msg=subtest_msg): +                _, mock_message = mock_() +                await self.syncer._send_prompt(message_arg) + +                calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] +                mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.TestCase): +    """Tests for waiting for a sync confirmation reaction on the prompt.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = TestSyncer(self.bot) +        self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer) + +    @staticmethod +    def get_message_reaction(emoji): +        """Fixture to return a mock message an reaction from the given `emoji`.""" +        message = helpers.MockMessage() +        reaction = helpers.MockReaction(emoji=emoji, message=message) + +        return message, reaction + +    def test_reaction_check_for_valid_emoji_and_authors(self): +        """Should return True if authors are identical or are a bot and a core dev, respectively.""" +        user_subtests = ( +            ( +                helpers.MockMember(id=77), +                helpers.MockMember(id=77), +                "identical users", +            ), +            ( +                helpers.MockMember(id=77, bot=True), +                helpers.MockMember(id=43, roles=[self.core_dev_role]), +                "bot author and core-dev reactor", +            ), +        ) + +        for emoji in self.syncer._REACTION_EMOJIS: +            for author, user, msg in user_subtests: +                with self.subTest(author=author, user=user, emoji=emoji, msg=msg): +                    message, reaction = self.get_message_reaction(emoji) +                    ret_val = self.syncer._reaction_check(author, message, reaction, user) + +                    self.assertTrue(ret_val) + +    def test_reaction_check_for_invalid_reactions(self): +        """Should return False for invalid reaction events.""" +        valid_emoji = self.syncer._REACTION_EMOJIS[0] +        subtests = ( +            ( +                helpers.MockMember(id=77), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=43, roles=[self.core_dev_role]), +                "users are not identical", +            ), +            ( +                helpers.MockMember(id=77, bot=True), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=43), +                "reactor lacks the core-dev role", +            ), +            ( +                helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), +                "reactor is a bot", +            ), +            ( +                helpers.MockMember(id=77), +                helpers.MockMessage(id=95), +                helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), +                helpers.MockMember(id=77), +                "messages are not identical", +            ), +            ( +                helpers.MockMember(id=77), +                *self.get_message_reaction("InVaLiD"), +                helpers.MockMember(id=77), +                "emoji is invalid", +            ), +        ) + +        for *args, msg in subtests: +            kwargs = dict(zip(("author", "message", "reaction", "user"), args)) +            with self.subTest(**kwargs, msg=msg): +                ret_val = self.syncer._reaction_check(*args) +                self.assertFalse(ret_val) + +    @helpers.async_test +    async def test_wait_for_confirmation(self): +        """The message should always be edited and only return True if the emoji is a check mark.""" +        subtests = ( +            (constants.Emojis.check_mark, True, None), +            ("InVaLiD", False, None), +            (None, False, TimeoutError), +        ) + +        for emoji, ret_val, side_effect in subtests: +            for bot in (True, False): +                with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): +                    # Set up mocks +                    message = helpers.MockMessage() +                    member = helpers.MockMember(bot=bot) + +                    self.bot.wait_for.reset_mock() +                    self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) +                    self.bot.wait_for.side_effect = side_effect + +                    # Call the function +                    actual_return = await self.syncer._wait_for_confirmation(member, message) + +                    # Perform assertions +                    self.bot.wait_for.assert_called_once() +                    self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + +                    message.edit.assert_called_once() +                    kwargs = message.edit.call_args[1] +                    self.assertIn("content", kwargs) + +                    # Core devs should only be mentioned if the author is a bot. +                    if bot: +                        self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) +                    else: +                        self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + +                    self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.TestCase): +    """Tests for main function orchestrating the sync.""" + +    def setUp(self): +        self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) +        self.syncer = TestSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_respects_confirmation_result(self): +        """The sync should abort if confirmation fails and continue if confirmed.""" +        mock_message = helpers.MockMessage() +        subtests = ( +            (True, mock_message), +            (False, None), +        ) + +        for confirmed, message in subtests: +            with self.subTest(confirmed=confirmed): +                self.syncer._sync.reset_mock() +                self.syncer._get_diff.reset_mock() + +                diff = _Diff({1, 2, 3}, {4, 5}, None) +                self.syncer._get_diff.return_value = diff +                self.syncer._get_confirmation_result = helpers.AsyncMock( +                    return_value=(confirmed, message) +                ) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                self.syncer._get_diff.assert_called_once_with(guild) +                self.syncer._get_confirmation_result.assert_called_once() + +                if confirmed: +                    self.syncer._sync.assert_called_once_with(diff) +                else: +                    self.syncer._sync.assert_not_called() + +    @helpers.async_test +    async def test_sync_diff_size(self): +        """The diff size should be correctly calculated.""" +        subtests = ( +            (6, _Diff({1, 2}, {3, 4}, {5, 6})), +            (5, _Diff({1, 2, 3}, None, {4, 5})), +            (0, _Diff(None, None, None)), +            (0, _Diff(set(), set(), set())), +        ) + +        for size, diff in subtests: +            with self.subTest(size=size, diff=diff): +                self.syncer._get_diff.reset_mock() +                self.syncer._get_diff.return_value = diff +                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                self.syncer._get_diff.assert_called_once_with(guild) +                self.syncer._get_confirmation_result.assert_called_once() +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + +    @helpers.async_test +    async def test_sync_message_edited(self): +        """The message should be edited if one was sent, even if the sync has an API error.""" +        subtests = ( +            (None, None, False), +            (helpers.MockMessage(), None, True), +            (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), +        ) + +        for message, side_effect, should_edit in subtests: +            with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): +                self.syncer._sync.side_effect = side_effect +                self.syncer._get_confirmation_result = helpers.AsyncMock( +                    return_value=(True, message) +                ) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                if should_edit: +                    message.edit.assert_called_once() +                    self.assertIn("content", message.edit.call_args[1]) + +    @helpers.async_test +    async def test_sync_confirmation_context_redirect(self): +        """If ctx is given, a new message should be sent and author should be ctx's author.""" +        mock_member = helpers.MockMember() +        subtests = ( +            (None, self.bot.user, None), +            (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), +        ) + +        for ctx, author, message in subtests: +            with self.subTest(ctx=ctx, author=author, message=message): +                if ctx is not None: +                    ctx.send.return_value = message + +                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild, ctx) + +                if ctx is not None: +                    ctx.send.assert_called_once() + +                self.syncer._get_confirmation_result.assert_called_once() +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + +    @mock.patch.object(constants.Sync, "max_diff", new=3) +    @helpers.async_test +    async def test_confirmation_result_small_diff(self): +        """Should always return True and the given message if the diff size is too small.""" +        author = helpers.MockMember() +        expected_message = helpers.MockMessage() + +        for size in (3, 2): +            with self.subTest(size=size): +                self.syncer._send_prompt = helpers.AsyncMock() +                self.syncer._wait_for_confirmation = helpers.AsyncMock() + +                coro = self.syncer._get_confirmation_result(size, author, expected_message) +                result, actual_message = await coro + +                self.assertTrue(result) +                self.assertEqual(actual_message, expected_message) +                self.syncer._send_prompt.assert_not_called() +                self.syncer._wait_for_confirmation.assert_not_called() + +    @mock.patch.object(constants.Sync, "max_diff", new=3) +    @helpers.async_test +    async def test_confirmation_result_large_diff(self): +        """Should return True if confirmed and False if _send_prompt fails or aborted.""" +        author = helpers.MockMember() +        mock_message = helpers.MockMessage() + +        subtests = ( +            (True, mock_message, True, "confirmed"), +            (False, None, False, "_send_prompt failed"), +            (False, mock_message, False, "aborted"), +        ) + +        for expected_result, expected_message, confirmed, msg in subtests: +            with self.subTest(msg=msg): +                self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) +                self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + +                coro = self.syncer._get_confirmation_result(4, author) +                actual_result, actual_message = await coro + +                self.syncer._send_prompt.assert_called_once_with(None)  # message defaults to None +                self.assertIs(actual_result, expected_result) +                self.assertEqual(actual_message, expected_message) + +                if expected_message: +                    self.syncer._wait_for_confirmation.assert_called_once_with( +                        author, expected_message +                    ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..98c9afc0d --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,395 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): +    """ +    A MagicMock subclass to mock Syncer objects. + +    Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` +    instances. For more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=Syncer, **kwargs) + + +class SyncExtensionTests(unittest.TestCase): +    """Tests for the sync extension.""" + +    @staticmethod +    def test_extension_setup(): +        """The Sync cog should be added.""" +        bot = helpers.MockBot() +        sync.setup(bot) +        bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.TestCase): +    """Base class for Sync cog tests. Sets up patches for syncers.""" + +    def setUp(self): +        self.bot = helpers.MockBot() + +        # These patch the type. When the type is called, a MockSyncer instanced is returned. +        # MockSyncer is needed so that our custom AsyncMock is used. +        # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. +        self.role_syncer_patcher = mock.patch( +            "bot.cogs.sync.syncers.RoleSyncer", +            new=mock.MagicMock(return_value=MockSyncer()) +        ) +        self.user_syncer_patcher = mock.patch( +            "bot.cogs.sync.syncers.UserSyncer", +            new=mock.MagicMock(return_value=MockSyncer()) +        ) +        self.RoleSyncer = self.role_syncer_patcher.start() +        self.UserSyncer = self.user_syncer_patcher.start() + +        self.cog = sync.Sync(self.bot) + +    def tearDown(self): +        self.role_syncer_patcher.stop() +        self.user_syncer_patcher.stop() + +    @staticmethod +    def response_error(status: int) -> ResponseCodeError: +        """Fixture to return a ResponseCodeError with the given status code.""" +        response = mock.MagicMock() +        response.status = status + +        return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): +    """Tests for the Sync cog.""" + +    @mock.patch.object(sync.Sync, "sync_guild") +    def test_sync_cog_init(self, sync_guild): +        """Should instantiate syncers and run a sync for the guild.""" +        # Reset because a Sync cog was already instantiated in setUp. +        self.RoleSyncer.reset_mock() +        self.UserSyncer.reset_mock() +        self.bot.loop.create_task.reset_mock() + +        mock_sync_guild_coro = mock.MagicMock() +        sync_guild.return_value = mock_sync_guild_coro + +        sync.Sync(self.bot) + +        self.RoleSyncer.assert_called_once_with(self.bot) +        self.UserSyncer.assert_called_once_with(self.bot) +        sync_guild.assert_called_once_with() +        self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + +    @helpers.async_test +    async def test_sync_cog_sync_guild(self): +        """Roles and users should be synced only if a guild is successfully retrieved.""" +        for guild in (helpers.MockGuild(), None): +            with self.subTest(guild=guild): +                self.bot.reset_mock() +                self.cog.role_syncer.reset_mock() +                self.cog.user_syncer.reset_mock() + +                self.bot.get_guild = mock.MagicMock(return_value=guild) + +                await self.cog.sync_guild() + +                self.bot.wait_until_guild_available.assert_called_once() +                self.bot.get_guild.assert_called_once_with(constants.Guild.id) + +                if guild is None: +                    self.cog.role_syncer.sync.assert_not_called() +                    self.cog.user_syncer.sync.assert_not_called() +                else: +                    self.cog.role_syncer.sync.assert_called_once_with(guild) +                    self.cog.user_syncer.sync.assert_called_once_with(guild) + +    async def patch_user_helper(self, side_effect: BaseException) -> None: +        """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" +        self.bot.api_client.patch.reset_mock(side_effect=True) +        self.bot.api_client.patch.side_effect = side_effect + +        user_id, updated_information = 5, {"key": 123} +        await self.cog.patch_user(user_id, updated_information) + +        self.bot.api_client.patch.assert_called_once_with( +            f"bot/users/{user_id}", +            json=updated_information, +        ) + +    @helpers.async_test +    async def test_sync_cog_patch_user(self): +        """A PATCH request should be sent and 404 errors ignored.""" +        for side_effect in (None, self.response_error(404)): +            with self.subTest(side_effect=side_effect): +                await self.patch_user_helper(side_effect) + +    @helpers.async_test +    async def test_sync_cog_patch_user_non_404(self): +        """A PATCH request should be sent and the error raised if it's not a 404.""" +        with self.assertRaises(ResponseCodeError): +            await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): +    """Tests for the listeners of the Sync cog.""" + +    def setUp(self): +        super().setUp() +        self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_create(self): +        """A POST request should be sent with the new role's data.""" +        self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + +        role_data = { +            "colour": 49, +            "id": 777, +            "name": "rolename", +            "permissions": 8, +            "position": 23, +        } +        role = helpers.MockRole(**role_data) +        await self.cog.on_guild_role_create(role) + +        self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_delete(self): +        """A DELETE request should be sent.""" +        self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + +        role = helpers.MockRole(id=99) +        await self.cog.on_guild_role_delete(role) + +        self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_update(self): +        """A PUT request should be sent if the colour, name, permissions, or position changes.""" +        self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + +        role_data = { +            "colour": 49, +            "id": 777, +            "name": "rolename", +            "permissions": 8, +            "position": 23, +        } +        subtests = ( +            (True, ("colour", "name", "permissions", "position")), +            (False, ("hoist", "mentionable")), +        ) + +        for should_put, attributes in subtests: +            for attribute in attributes: +                with self.subTest(should_put=should_put, changed_attribute=attribute): +                    self.bot.api_client.put.reset_mock() + +                    after_role_data = role_data.copy() +                    after_role_data[attribute] = 876 + +                    before_role = helpers.MockRole(**role_data) +                    after_role = helpers.MockRole(**after_role_data) + +                    await self.cog.on_guild_role_update(before_role, after_role) + +                    if should_put: +                        self.bot.api_client.put.assert_called_once_with( +                            f"bot/roles/{after_role.id}", +                            json=after_role_data +                        ) +                    else: +                        self.bot.api_client.put.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_member_remove(self): +        """Member should patched to set in_guild as False.""" +        self.assertTrue(self.cog.on_member_remove.__cog_listener__) + +        member = helpers.MockMember() +        await self.cog.on_member_remove(member) + +        self.cog.patch_user.assert_called_once_with( +            member.id, +            updated_information={"in_guild": False} +        ) + +    @helpers.async_test +    async def test_sync_cog_on_member_update_roles(self): +        """Members should be patched if their roles have changed.""" +        self.assertTrue(self.cog.on_member_update.__cog_listener__) + +        # Roles are intentionally unsorted. +        before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] +        before_member = helpers.MockMember(roles=before_roles) +        after_member = helpers.MockMember(roles=before_roles[1:]) + +        await self.cog.on_member_update(before_member, after_member) + +        data = {"roles": sorted(role.id for role in after_member.roles)} +        self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + +    @helpers.async_test +    async def test_sync_cog_on_member_update_other(self): +        """Members should not be patched if other attributes have changed.""" +        self.assertTrue(self.cog.on_member_update.__cog_listener__) + +        subtests = ( +            ("activities", discord.Game("Pong"), discord.Game("Frogger")), +            ("nick", "old nick", "new nick"), +            ("status", discord.Status.online, discord.Status.offline), +        ) + +        for attribute, old_value, new_value in subtests: +            with self.subTest(attribute=attribute): +                self.cog.patch_user.reset_mock() + +                before_member = helpers.MockMember(**{attribute: old_value}) +                after_member = helpers.MockMember(**{attribute: new_value}) + +                await self.cog.on_member_update(before_member, after_member) + +                self.cog.patch_user.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_user_update(self): +        """A user should be patched only if the name, discriminator, or avatar changes.""" +        self.assertTrue(self.cog.on_user_update.__cog_listener__) + +        before_data = { +            "name": "old name", +            "discriminator": "1234", +            "avatar": "old avatar", +            "bot": False, +        } + +        subtests = ( +            (True, "name", "name", "new name", "new name"), +            (True, "discriminator", "discriminator", "8765", 8765), +            (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), +            (False, "bot", "bot", True, True), +        ) + +        for should_patch, attribute, api_field, value, api_value in subtests: +            with self.subTest(attribute=attribute): +                self.cog.patch_user.reset_mock() + +                after_data = before_data.copy() +                after_data[attribute] = value +                before_user = helpers.MockUser(**before_data) +                after_user = helpers.MockUser(**after_data) + +                await self.cog.on_user_update(before_user, after_user) + +                if should_patch: +                    self.cog.patch_user.assert_called_once() + +                    # Don't care if *all* keys are present; only the changed one is required +                    call_args = self.cog.patch_user.call_args +                    self.assertEqual(call_args[0][0], after_user.id) +                    self.assertIn("updated_information", call_args[1]) + +                    updated_information = call_args[1]["updated_information"] +                    self.assertIn(api_field, updated_information) +                    self.assertEqual(updated_information[api_field], api_value) +                else: +                    self.cog.patch_user.assert_not_called() + +    async def on_member_join_helper(self, side_effect: Exception) -> dict: +        """ +        Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + +        The request data for the mock member is returned. All exceptions will be re-raised. +        """ +        member = helpers.MockMember( +            discriminator="1234", +            roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], +        ) + +        data = { +            "avatar_hash": member.avatar, +            "discriminator": int(member.discriminator), +            "id": member.id, +            "in_guild": True, +            "name": member.name, +            "roles": sorted(role.id for role in member.roles) +        } + +        self.bot.api_client.put.reset_mock(side_effect=True) +        self.bot.api_client.put.side_effect = side_effect + +        try: +            await self.cog.on_member_join(member) +        except Exception: +            raise +        finally: +            self.bot.api_client.put.assert_called_once_with( +                f"bot/users/{member.id}", +                json=data +            ) + +        return data + +    @helpers.async_test +    async def test_sync_cog_on_member_join(self): +        """Should PUT user's data or POST it if the user doesn't exist.""" +        for side_effect in (None, self.response_error(404)): +            with self.subTest(side_effect=side_effect): +                self.bot.api_client.post.reset_mock() +                data = await self.on_member_join_helper(side_effect) + +                if side_effect: +                    self.bot.api_client.post.assert_called_once_with("bot/users", json=data) +                else: +                    self.bot.api_client.post.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_member_join_non_404(self): +        """ResponseCodeError should be re-raised if status code isn't a 404.""" +        with self.assertRaises(ResponseCodeError): +            await self.on_member_join_helper(self.response_error(500)) + +        self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): +    """Tests for the commands in the Sync cog.""" + +    @helpers.async_test +    async def test_sync_roles_command(self): +        """sync() should be called on the RoleSyncer.""" +        ctx = helpers.MockContext() +        await self.cog.sync_roles_command.callback(self.cog, ctx) + +        self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + +    @helpers.async_test +    async def test_sync_users_command(self): +        """sync() should be called on the UserSyncer.""" +        ctx = helpers.MockContext() +        await self.cog.sync_users_command.callback(self.cog, ctx) + +        self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + +    def test_commands_require_admin(self): +        """The sync commands should only run if the author has the administrator permission.""" +        cmds = ( +            self.cog.sync_group, +            self.cog.sync_roles_command, +            self.cog.sync_users_command, +        ) + +        for cmd in cmds: +            with self.subTest(cmd=cmd): +                self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..14fb2577a 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,165 @@  import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): -    """Tests constructing the roles to synchronize with the site.""" - -    def test_get_roles_for_sync_empty_return_for_equal_roles(self): -        """No roles should be synced when no diff is found.""" -        api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} -        guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            (set(), set(), set()) -        ) - -    def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): -        """Roles to be synced are returned when non-ID attributes differ.""" -        api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} -        guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            (set(), guild_roles, set()) -        ) - -    def test_get_roles_only_returns_roles_that_require_update(self): -        """Roles that require an update should be returned as the second tuple element.""" -        api_roles = { -            Role(id=41, name='old name', colour=33, permissions=0x8, position=1), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } -        guild_roles = { -            Role(id=41, name='new name', colour=35, permissions=0x8, position=2), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                set(), -                {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, -                set(), -            ) -        ) - -    def test_get_roles_returns_new_roles_in_first_tuple_element(self): -        """Newly created roles are returned as the first tuple element.""" -        api_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        } -        guild_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -            Role(id=53, name='other role', colour=55, permissions=0, position=2) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, -                set(), -                set(), -            ) -        ) - -    def test_get_roles_returns_roles_to_update_and_new_roles(self): -        """Newly created and updated roles should be returned together.""" -        api_roles = { -            Role(id=41, name='old name', colour=35, permissions=0x8, position=1), -        } -        guild_roles = { -            Role(id=41, name='new name', colour=40, permissions=0x16, position=2), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, -                {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, -                set(), -            ) -        ) - -    def test_get_roles_returns_roles_to_delete(self): -        """Roles to be deleted should be returned as the third tuple element.""" -        api_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -        } -        guild_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                set(), -                set(), -                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -            ) -        ) - -    def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): -        """When roles were added, updated, and removed, all of them are returned properly.""" -        api_roles = { -            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -            Role(id=71, name='to update', colour=99, permissions=0x9, position=3), -        } -        guild_roles = { -            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -            Role(id=81, name='to create', colour=99, permissions=0x9, position=4), -            Role(id=71, name='updated', colour=101, permissions=0x5, position=3), -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, -                {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, -                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -            ) -        ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): +    """Fixture to return a dictionary representing a role with default values set.""" +    kwargs.setdefault("id", 9) +    kwargs.setdefault("name", "fake role") +    kwargs.setdefault("colour", 7) +    kwargs.setdefault("permissions", 0) +    kwargs.setdefault("position", 55) + +    return kwargs + + +class RoleSyncerDiffTests(unittest.TestCase): +    """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = RoleSyncer(self.bot) + +    @staticmethod +    def get_guild(*roles): +        """Fixture to return a guild object with the given roles.""" +        guild = helpers.MockGuild() +        guild.roles = [] + +        for role in roles: +            mock_role = helpers.MockRole(**role) +            mock_role.colour = discord.Colour(role["colour"]) +            mock_role.permissions = discord.Permissions(role["permissions"]) +            guild.roles.append(mock_role) + +        return guild + +    @helpers.async_test +    async def test_empty_diff_for_identical_roles(self): +        """No differences should be found if the roles in the guild and DB are identical.""" +        self.bot.api_client.get.return_value = [fake_role()] +        guild = self.get_guild(fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_updated_roles(self): +        """Only updated roles should be added to the 'updated' set of the diff.""" +        updated_role = fake_role(id=41, name="new") + +        self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] +        guild = self.get_guild(updated_role, fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_Role(**updated_role)}, set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_roles(self): +        """Only new roles should be added to the 'created' set of the diff.""" +        new_role = fake_role(id=41, name="new") + +        self.bot.api_client.get.return_value = [fake_role()] +        guild = self.get_guild(fake_role(), new_role) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_Role(**new_role)}, set(), set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_deleted_roles(self): +        """Only deleted roles should be added to the 'deleted' set of the diff.""" +        deleted_role = fake_role(id=61, name="deleted") + +        self.bot.api_client.get.return_value = [fake_role(), deleted_role] +        guild = self.get_guild(fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), {_Role(**deleted_role)}) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_updated_and_deleted_roles(self): +        """When roles are added, updated, and removed, all of them are returned properly.""" +        new = fake_role(id=41, name="new") +        updated = fake_role(id=71, name="updated") +        deleted = fake_role(id=61, name="deleted") + +        self.bot.api_client.get.return_value = [ +            fake_role(), +            fake_role(id=71, name="updated name"), +            deleted, +        ] +        guild = self.get_guild(fake_role(), new, updated) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + +        self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.TestCase): +    """Tests for the API requests that sync roles.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = RoleSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_created_roles(self): +        """Only POST requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(role_tuples, set(), set()) +        await self.syncer._sync(diff) + +        calls = [mock.call("bot/roles", json=role) for role in roles] +        self.bot.api_client.post.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + +        self.bot.api_client.put.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_updated_roles(self): +        """Only PUT requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(set(), role_tuples, set()) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] +        self.bot.api_client.put.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_deleted_roles(self): +        """Only DELETE requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(set(), set(), role_tuples) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] +        self.bot.api_client.delete.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..421bf6bb6 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,169 @@  import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers  def fake_user(**kwargs): -    kwargs.setdefault('id', 43) -    kwargs.setdefault('name', 'bob the test man') -    kwargs.setdefault('discriminator', 1337) -    kwargs.setdefault('avatar_hash', None) -    kwargs.setdefault('roles', (666,)) -    kwargs.setdefault('in_guild', True) -    return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): -    """Tests constructing the users to synchronize with the site.""" - -    def test_get_users_for_sync_returns_nothing_for_empty_params(self): -        """When no users are given, none are returned.""" -        self.assertEqual( -            get_users_for_sync({}, {}), -            (set(), set()) -        ) - -    def test_get_users_for_sync_returns_nothing_for_equal_users(self): -        """When no users are updated, none are returned.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user()} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), set()) -        ) - -    def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): -        """When a non-ID-field differs, the user to update is returned.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user(name='new fancy name')} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), {fake_user(name='new fancy name')}) -        ) - -    def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): -        """When new users join the guild, they are returned as the first tuple element.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user(), 63: fake_user(id=63)} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            ({fake_user(id=63)}, set()) -        ) - -    def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): +    """Fixture to return a dictionary representing a user with default values set.""" +    kwargs.setdefault("id", 43) +    kwargs.setdefault("name", "bob the test man") +    kwargs.setdefault("discriminator", 1337) +    kwargs.setdefault("avatar_hash", None) +    kwargs.setdefault("roles", (666,)) +    kwargs.setdefault("in_guild", True) + +    return kwargs + + +class UserSyncerDiffTests(unittest.TestCase): +    """Tests for determining differences between users in the DB and users in the Guild cache.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = UserSyncer(self.bot) + +    @staticmethod +    def get_guild(*members): +        """Fixture to return a guild object with the given members.""" +        guild = helpers.MockGuild() +        guild.members = [] + +        for member in members: +            member = member.copy() +            member["avatar"] = member.pop("avatar_hash") +            del member["in_guild"] + +            mock_member = helpers.MockMember(**member) +            mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + +            guild.members.append(mock_member) + +        return guild + +    @helpers.async_test +    async def test_empty_diff_for_no_users(self): +        """When no users are given, an empty diff should be returned.""" +        guild = self.get_guild() + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_empty_diff_for_identical_users(self): +        """No differences should be found if the users in the guild and DB are identical.""" +        self.bot.api_client.get.return_value = [fake_user()] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_updated_users(self): +        """Only updated users should be added to the 'updated' set of the diff.""" +        updated_user = fake_user(id=99, name="new") + +        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        guild = self.get_guild(updated_user, fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_User(**updated_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_users(self): +        """Only new users should be added to the 'created' set of the diff.""" +        new_user = fake_user(id=99, name="new") + +        self.bot.api_client.get.return_value = [fake_user()] +        guild = self.get_guild(fake_user(), new_user) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_User(**new_user)}, set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        api_users = {43: fake_user(), 63: fake_user(id=63)} -        guild_users = {43: fake_user()} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), {fake_user(id=63, in_guild=False)}) -        ) - -    def test_get_users_for_sync_updates_and_creates_users_as_needed(self): -        """When one user left and another one was updated, both are returned.""" -        api_users = {43: fake_user()} -        guild_users = {63: fake_user(id=63)} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            ({fake_user(id=63)}, {fake_user(in_guild=False)}) -        ) - -    def test_get_users_for_sync_does_not_duplicate_update_users(self): -        """When the API knows a user the guild doesn't, nothing is performed.""" -        api_users = {43: fake_user(in_guild=False)} -        guild_users = {} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), set()) -        ) +        leaving_user = fake_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_User(**leaving_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_updated_and_leaving_users(self): +        """When users are added, updated, and removed, all of them are returned properly.""" +        new_user = fake_user(id=99, name="new") +        updated_user = fake_user(id=55, name="updated") +        leaving_user = fake_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        guild = self.get_guild(fake_user(), new_user, updated_user) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_empty_diff_for_db_users_not_in_guild(self): +        """When the DB knows a user the guild doesn't, no difference is found.""" +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.TestCase): +    """Tests for the API requests that sync users.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = UserSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_created_users(self): +        """Only POST requests should be made with the correct payload.""" +        users = [fake_user(id=111), fake_user(id=222)] + +        user_tuples = {_User(**user) for user in users} +        diff = _Diff(user_tuples, set(), None) +        await self.syncer._sync(diff) + +        calls = [mock.call("bot/users", json=user) for user in users] +        self.bot.api_client.post.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.post.call_count, len(users)) + +        self.bot.api_client.put.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_updated_users(self): +        """Only PUT requests should be made with the correct payload.""" +        users = [fake_user(id=111), fake_user(id=222)] + +        user_tuples = {_User(**user) for user in users} +        diff = _Diff(set(), user_tuples, None) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] +        self.bot.api_client.put.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.put.call_count, len(users)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..5b0a3b8c3 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase):          asyncio.run(self.cog.fetch_webhook()) -        self.bot.wait_until_ready.assert_called_once() +        self.bot.wait_until_guild_available.assert_called_once()          self.bot.fetch_webhook.assert_called_once_with(1)          self.assertEqual(self.cog.webhook, "dummy webhook") @@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase):          with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:              asyncio.run(self.cog.fetch_webhook()) -        self.bot.wait_until_ready.assert_called_once() +        self.bot.wait_until_guild_available.assert_called_once()          self.bot.fetch_webhook.assert_called_once_with(1)          self.assertEqual(len(log_watcher.records), 1) diff --git a/tests/helpers.py b/tests/helpers.py index 5df796c23..9d9dd5da6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,6 +12,7 @@ from typing import Any, Iterable, Optional  import discord  from discord.ext.commands import Context +from bot.api import APIClient  from bot.bot import Bot @@ -269,9 +270,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      information, see the `MockGuild` docstring.      """      def __init__(self, **kwargs) -> None: -        default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} +        default_kwargs = { +            'id': next(self.discord_id), +            'name': 'role', +            'position': 1, +            'colour': discord.Colour(0xdeadbf), +            'permissions': discord.Permissions(), +        }          super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) +        if isinstance(self.colour, int): +            self.colour = discord.Colour(self.colour) + +        if isinstance(self.permissions, int): +            self.permissions = discord.Permissions(self.permissions) +          if 'mention' not in kwargs:              self.mention = f'&{self.name}' @@ -324,6 +337,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):              self.mention = f"@{self.name}" +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock APIClient objects. + +    Instances of this class will follow the specifications of `bot.api.APIClient` instances. +    For more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=APIClient, **kwargs) + +  # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`  bot_instance = Bot(command_prefix=unittest.mock.MagicMock())  bot_instance.http_session = None @@ -340,6 +365,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      def __init__(self, **kwargs) -> None:          super().__init__(spec_set=bot_instance, **kwargs) +        self.api_client = MockAPIClient()          # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and          # and should therefore be awaited. (The documentation calls it a coroutine as well, which @@ -503,6 +529,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage())          self.users = AsyncIteratorMock(kwargs.get('users', [])) +        self.__str__.return_value = str(self.emoji)  webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) | 
