diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/sync/cog.py | 39 | ||||
| -rw-r--r-- | bot/cogs/sync/syncers.py | 108 | 
2 files changed, 54 insertions, 93 deletions
| diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 1fd39b544..66ffbabf9 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -1,7 +1,7 @@  import logging -from typing import Any, Dict, Optional +from typing import Any, Dict -from discord import Guild, Member, Role, User +from discord import Member, Role, User  from discord.ext import commands  from discord.ext.commands import Cog, Context @@ -18,8 +18,8 @@ class Sync(Cog):      def __init__(self, bot: Bot) -> None:          self.bot = bot -        self.role_syncer = syncers.RoleSyncer(self.bot.api_client) -        self.user_syncer = syncers.UserSyncer(self.bot.api_client) +        self.role_syncer = syncers.RoleSyncer(self.bot) +        self.user_syncer = syncers.UserSyncer(self.bot)          self.bot.loop.create_task(self.sync_guild()) @@ -32,32 +32,7 @@ class Sync(Cog):              return          for syncer in (self.role_syncer, self.user_syncer): -            await self.sync(syncer, guild) - -    @staticmethod -    async def sync(syncer: syncers.Syncer, guild: Guild, ctx: Optional[Context] = None) -> None: -        """Run `syncer` using the cache of the given `guild`.""" -        log.info(f"Starting {syncer.name} syncer.") -        if ctx: -            message = await ctx.send(f"📊 Synchronising {syncer.name}s.") - -        diff = await syncer.sync(guild, ctx) -        if not diff: -            return  # Sync was aborted. - -        totals = zip(("created", "updated", "deleted"), diff) -        results = ", ".join(f"{name} `{len(total)}`" for name, total in totals if total is not None) - -        if results: -            log.info(f"{syncer.name} syncer finished: {results}.") -            if ctx: -                await message.edit( -                    content=f":ok_hand: Synchronisation of {syncer.name}s complete: {results}" -                ) -        else: -            log.warning(f"{syncer.name} syncer aborted!") -            if ctx: -                await message.edit(content=f":x: Synchronisation of {syncer.name}s aborted!") +            await syncer.sync(guild)      async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None:          """Send a PATCH request to partially update a user in the database.""" @@ -186,10 +161,10 @@ class Sync(Cog):      @commands.has_permissions(administrator=True)      async def sync_roles_command(self, ctx: Context) -> None:          """Manually synchronise the guild's roles with the roles on the site.""" -        await self.sync(self.role_syncer, ctx.guild, ctx) +        await self.role_syncer.sync(ctx.guild, ctx)      @sync_group.command(name='users')      @commands.has_permissions(administrator=True)      async def sync_users_command(self, ctx: Context) -> None:          """Manually synchronise the guild's users with the users on the site.""" -        await self.sync(self.user_syncer, ctx.guild, ctx) +        await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 7608c6870..7cc518348 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -3,7 +3,7 @@ import logging  import typing as t  from collections import namedtuple -from discord import Guild, HTTPException +from discord import Guild, HTTPException, Message  from discord.ext.commands import Context  from bot import constants @@ -22,9 +22,9 @@ _T = t.TypeVar("_T")  class Diff(t.NamedTuple, t.Generic[_T]):      """The differences between the Discord cache and the contents of the database.""" -    created: t.Optional[t.Set[_T]] = None -    updated: t.Optional[t.Set[_T]] = None -    deleted: t.Optional[t.Set[_T]] = None +    created: t.Set[_T] = {} +    updated: t.Set[_T] = {} +    deleted: t.Set[_T] = {}  class Syncer(abc.ABC, t.Generic[_T]): @@ -42,18 +42,22 @@ class Syncer(abc.ABC, t.Generic[_T]):          """The name of the syncer; used in output messages and logging."""          raise NotImplementedError -    async def _confirm(self, ctx: t.Optional[Context] = None) -> bool: +    async def _confirm(self, message: t.Optional[Message] = None) -> bool:          """          Send a prompt to confirm or abort a sync using reactions and return True if confirmed. -        If no context is given, the prompt is sent to the dev-core channel and mentions the core -        developers role. +        If a message is given, it is edited to display the prompt and reactions. Otherwise, a new +        message is sent to the dev-core channel and mentions the core developers role.          """          allowed_emoji = (constants.Emojis.check_mark, constants.Emojis.cross_mark) +        msg_content = ( +            f'Possible cache issue while syncing {self.name}s. ' +            f'Found no {self.name}s or more than {self.MAX_DIFF} {self.name}s were changed. ' +            f'React to confirm or abort the sync.' +        )          # Send to core developers if it's an automatic sync. -        if not ctx: -            mention = f'<@&{constants.Roles.core_developer}>' +        if not message:              channel = self.bot.get_channel(constants.Channels.devcore)              if not channel: @@ -65,24 +69,20 @@ class Syncer(abc.ABC, t.Generic[_T]):                          f"aborting {self.name} sync."                      )                      return False -        else: -            mention = ctx.author.mention -            channel = ctx.channel -        message = await channel.send( -            f'{mention} Possible cache issue while syncing {self.name}s. ' -            f'Found no {self.name}s or more than {self.MAX_DIFF} {self.name}s were changed. ' -            f'React to confirm or abort the sync.' -        ) +            message = await channel.send(f"<@&{constants.Roles.core_developer}> {msg_content}") +        else: +            message = await message.edit(content=f"{message.author.mention} {msg_content}")          # Add the initial reactions.          for emoji in allowed_emoji:              await message.add_reaction(emoji)          def check(_reaction, user):  # noqa: TYP +            # Skip author check for auto syncs              return (                  _reaction.message.id == message.id -                and True if not ctx else user == ctx.author  # Skip author check for auto syncs +                and True if message.author.bot else user == message.author                  and str(_reaction.emoji) in allowed_emoji              ) @@ -98,10 +98,11 @@ class Syncer(abc.ABC, t.Generic[_T]):              pass          finally:              if str(reaction) == constants.Emojis.check_mark: -                await channel.send(f':ok_hand: {self.name} sync will proceed.') +                await message.edit(content=f':ok_hand: {self.name} sync will proceed.')                  return True              else: -                await channel.send(f':x: {self.name} sync aborted!') +                log.warning(f"{self.name} syncer aborted!") +                await message.edit(content=f':x: {self.name} sync aborted!')                  return False      @abc.abstractmethod @@ -110,23 +111,36 @@ class Syncer(abc.ABC, t.Generic[_T]):          raise NotImplementedError      @abc.abstractmethod -    async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> t.Optional[Diff[_T]]: +    async def _sync(self, diff: Diff[_T]) -> None: +        """Perform the API calls for synchronisation.""" +        raise NotImplementedError + +    async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None:          """ -        Synchronise the database with the cache of `guild` and return the synced difference. +        Synchronise the database with the cache of `guild`.          If the differences between the cache and the database are greater than `MAX_DIFF`, then          a confirmation prompt will be sent to the dev-core channel. The confirmation can be          optionally redirect to `ctx` instead. - -        If the sync is not confirmed, None is returned.          """ +        log.info(f"Starting {self.name} syncer.") +        if ctx: +            message = await ctx.send(f"📊 Synchronising {self.name}s.") +          diff = await self._get_diff(guild) -        confirmed = await self._confirm(ctx) +        total = sum(map(len, diff)) -        if not confirmed: -            return None -        else: -            return diff +        if total > self.MAX_DIFF and not await self._confirm(ctx): +            return  # Sync aborted. + +        await self._sync(diff) + +        results = ", ".join(f"{name} `{len(total)}`" for name, total in diff._asdict().items()) +        log.info(f"{self.name} syncer finished: {results}.") +        if ctx: +            await message.edit( +                content=f":ok_hand: Synchronisation of {self.name}s complete: {results}" +            )  class RoleSyncer(Syncer[Role]): @@ -165,20 +179,8 @@ class RoleSyncer(Syncer[Role]):          return Diff(roles_to_create, roles_to_update, roles_to_delete) -    async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> t.Optional[Diff[Role]]: -        """ -        Synchronise the database with the role cache of `guild` and return the synced difference. - -        If the differences between the cache and the database are greater than `MAX_DIFF`, then -        a confirmation prompt will be sent to the dev-core channel. The confirmation can be -        optionally redirect to `ctx` instead. - -        If the sync is not confirmed, None is returned. -        """ -        diff = await super().sync(guild, ctx) -        if diff is None: -            return None - +    async def _sync(self, diff: Diff[Role]) -> None: +        """Synchronise the database with the role cache of `guild`."""          for role in diff.created:              await self.bot.api_client.post('bot/roles', json={**role._asdict()}) @@ -188,8 +190,6 @@ class RoleSyncer(Syncer[Role]):          for role in diff.deleted:              await self.bot.api_client.delete(f'bot/roles/{role.id}') -        return diff -  class UserSyncer(Syncer[User]):      """Synchronise the database with users in the cache.""" @@ -248,24 +248,10 @@ class UserSyncer(Syncer[User]):          return Diff(users_to_create, users_to_update) -    async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> t.Optional[Diff[_T]]: -        """ -        Synchronise the database with the user cache of `guild` and return the synced difference. - -        If the differences between the cache and the database are greater than `MAX_DIFF`, then -        a confirmation prompt will be sent to the dev-core channel. The confirmation can be -        optionally redirect to `ctx` instead. - -        If the sync is not confirmed, None is returned. -        """ -        diff = await super().sync(guild, ctx) -        if diff is None: -            return None - +    async def _sync(self, diff: Diff[User]) -> None: +        """Synchronise the database with the user cache of `guild`."""          for user in diff.created:              await self.bot.api_client.post('bot/users', json={**user._asdict()})          for user in diff.updated:              await self.bot.api_client.put(f'bot/users/{user.id}', json={**user._asdict()}) - -        return diff | 
