diff options
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 146 | 
1 files changed, 48 insertions, 98 deletions
| diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index ae7d5d893..70887a217 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,7 +2,6 @@ import abc  import logging  import typing as t  from collections import namedtuple -from urllib.parse import parse_qsl, urlparse  from discord import Guild  from discord.ext.commands import Context @@ -15,7 +14,6 @@ log = logging.getLogger(__name__)  # These objects are declared as namedtuples because tuples are hashable,  # something that we make use of when diffing site roles against guild roles.  _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild'))  _Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) @@ -42,11 +40,7 @@ class Syncer(abc.ABC):          raise NotImplementedError  # pragma: no cover      async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: -        """ -        Synchronise the database with the cache of `guild`. - -        If `ctx` is given, send a message with the results. -        """ +        """If `ctx` is given, send a message with the results."""          log.info(f"Starting {self.name} syncer.")          if ctx: @@ -136,111 +130,67 @@ class UserSyncer(Syncer):          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") -        users = await self._get_users() +        users_to_create = [] +        users_to_update = [] +        seen_guild_users = set() -        # Pack DB roles and guild roles into one common, hashable format. -        # They're hashable so that they're easily comparable with sets later. -        db_users = { -            user_dict['id']: _User( -                roles=tuple(sorted(user_dict.pop('roles'))), -                **user_dict -            ) -            for user_dict in users -        } -        guild_users = { -            member.id: _User( -                id=member.id, -                name=member.name, -                discriminator=int(member.discriminator), -                roles=tuple(sorted(role.id for role in member.roles)), -                in_guild=True -            ) -            for member in guild.members -        } +        async for db_user in self._get_users(): +            updated_fields = {} -        users_to_create = set() -        users_to_update = set() - -        for db_user in db_users.values(): -            guild_user = guild_users.get(db_user.id) - -            if guild_user is not None: -                if db_user != guild_user: -                    fields_to_none: dict = {} - -                    for field in _User._fields: -                        # Set un-changed values to None except ID to speed up API PATCH method. -                        if getattr(db_user, field) == getattr(guild_user, field) and field != "id": -                            fields_to_none[field] = None - -                    new_api_user = guild_user._replace(**fields_to_none) -                    users_to_update.add(new_api_user) - -            elif db_user.in_guild: -                # The user is known in the DB but not the guild, and the -                # DB currently specifies that the user is a member of the guild. -                # This means that the user has left since the last sync. -                # Update the `in_guild` attribute of the user on the site -                # to signify that the user left. - -                # Set un-changed fields to None except ID as it is required by the API. -                fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]} -                new_api_user = db_user._replace( -                    in_guild=False, -                    **fields_to_none -                ) -                users_to_update.add(new_api_user) - -        new_user_ids = set(guild_users.keys()) - set(db_users.keys()) -        for user_id in new_user_ids: -            # The user is known on the guild but not on the API. This means -            # that the user has joined since the last sync. Create it. -            new_user = guild_users[user_id] -            users_to_create.add(new_user) +            def maybe_update(db_field: str, guild_value: t.Union[str, int]) -> None: +                if db_user[db_field] != guild_value: +                    updated_fields[db_field] = guild_value -        return _Diff(users_to_create, users_to_update, None) +            if guild_user := guild.get_member(db_user["id"]): +                seen_guild_users.add(guild_user.id) + +                maybe_update("name", guild_user.name) +                maybe_update("discriminator", int(guild_user.discriminator)) +                maybe_update("in_guild", True) -    async def _get_users(self, endpoint: str = "bot/users", query_params: list = None) -> t.List[dict]: -        """GET all users recursively.""" -        users = [] -        response: dict = await self.bot.api_client.get(endpoint, params=query_params) -        users.extend(response["results"]) +                guild_roles = [role.id for role in guild_user.roles] +                if set(db_user["roles"]) != set(guild_roles): +                    updated_fields["roles"] = guild_roles -        # The `response` is paginated, hence check if next page exists. -        if (next_page_url := response["next"]) is not None: -            next_endpoint, query_params = self.get_endpoint(next_page_url) -            users.extend(await self._get_users(next_endpoint, query_params)) -        return users +            elif db_user["in_guild"]: +                updated_fields["in_guild"] = False -    @staticmethod -    def get_endpoint(url: str) -> t.Tuple[str, t.List[tuple]]: -        """Extract the API endpoint and query params from a URL.""" -        url = urlparse(url) +            if updated_fields and updated_fields not in users_to_update: +                updated_fields["id"] = db_user["id"] +                users_to_update.append(updated_fields) -        # Do not include starting `/` for endpoint. -        endpoint = url.path[1:] +        for member in guild.members: +            if member.id not in seen_guild_users: +                new_user = { +                    "id": member.id, +                    "name": member.name, +                    "discriminator": int(member.discriminator), +                    "roles": [role.id for role in member.roles], +                    "in_guild": True +                } +                if new_user not in users_to_create: +                    users_to_create.append(new_user) -        # Query params. -        params = parse_qsl(url.query) +        return _Diff(users_to_create, users_to_update, None) -        return endpoint, params +    async def _get_users(self) -> t.AsyncIterable: +        """GET users from database.""" +        query_params = { +            "page": 1 +        } +        while query_params["page"]: +            res = await self.bot.api_client.get("bot/users", params=query_params) +            for user in res["results"]: +                yield user -    @staticmethod -    def patch_dict(user: _User) -> t.Dict[str, t.Union[int, str, tuple, bool]]: -        """Convert namedtuple to dict by omitting None values.""" -        user_dict = {} -        for field in user._fields: -            if (value := getattr(user, field)) is not None: -                user_dict[field] = value -        return user_dict +            query_params["page"] = res["next_page_no"]      async def _sync(self, diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...")          if diff.created: -            created = [user._asdict() for user in diff.created] -            await self.bot.api_client.post("bot/users", json=created) +            await self.bot.api_client.post("bot/users", json=diff.created) +          log.trace("Syncing updated users...")          if diff.updated: -            updated = [self.patch_dict(user) for user in diff.updated] -            await self.bot.api_client.patch("bot/users/bulk_patch", json=updated) +            await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated) | 
