diff options
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 70 | 
1 files changed, 61 insertions, 9 deletions
| diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index f7ba811bc..512efaa3d 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -4,6 +4,7 @@ import logging  import typing as t  from collections import namedtuple  from functools import partial +from urllib.parse import parse_qsl, urlparse  import discord  from discord import Guild, HTTPException, Member, Message, Reaction, User @@ -287,7 +288,8 @@ class UserSyncer(Syncer):      async def _get_diff(self, guild: Guild) -> _Diff:          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") -        users = await self.bot.api_client.get('bot/users') + +        users = await self._get_users()          # Pack DB roles and guild roles into one common, hashable format.          # They're hashable so that they're easily comparable with sets later. @@ -314,9 +316,18 @@ class UserSyncer(Syncer):          for db_user in db_users.values():              guild_user = guild_users.get(db_user.id) +              if guild_user is not None:                  if db_user != guild_user: -                    users_to_update.add(guild_user) +                    fields_to_none: dict = {} + +                    for field in _User._fields: +                        # Set un-changed values to None except ID to speed up API PATCH method. +                        if getattr(db_user, field) == getattr(guild_user, field) and field != "id": +                            fields_to_none[field] = None + +                    new_api_user = guild_user._replace(**fields_to_none) +                    users_to_update.add(new_api_user)              elif db_user.in_guild:                  # The user is known in the DB but not the guild, and the @@ -324,7 +335,13 @@ class UserSyncer(Syncer):                  # This means that the user has left since the last sync.                  # Update the `in_guild` attribute of the user on the site                  # to signify that the user left. -                new_api_user = db_user._replace(in_guild=False) + +                # Set un-changed fields to None except ID as it is required by the API. +                fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]} +                new_api_user = db_user._replace( +                    in_guild=False, +                    **fields_to_none +                )                  users_to_update.add(new_api_user)          new_user_ids = set(guild_users.keys()) - set(db_users.keys()) @@ -336,12 +353,47 @@ class UserSyncer(Syncer):          return _Diff(users_to_create, users_to_update, None) +    async def _get_users(self, endpoint: str = "bot/users", query_params: dict = None) -> t.List[dict]: +        """GET all users recursively.""" +        users: list = [] +        response: dict = await self.bot.api_client.get(endpoint, params=query_params) +        users.extend(response["results"]) + +        # The `response` is paginated, hence check if next page exists. +        if (next_page_url := response["next"]) is not None: +            next_endpoint, query_params = self.get_endpoint(next_page_url) +            users.extend(await self._get_users(next_endpoint, query_params)) + +        return users + +    @staticmethod +    def get_endpoint(url: str) -> tuple: +        """Extract the API endpoint and query params from a URL.""" +        url = urlparse(url) + +        # Do not include starting `/` for endpoint. +        endpoint = url.path[1:] + +        # Query params. +        params = parse_qsl(url.query) + +        return endpoint, params + +    @staticmethod +    def patch_dict(user: _User) -> dict: +        """Convert namedtuple to dict by omitting None values.""" +        user_dict: dict = {} +        for field in user._fields: +            if (value := getattr(user, field)) is not None: +                user_dict[field] = value +        return user_dict +      async def _sync(self, diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...") -        for user in diff.created: -            await self.bot.api_client.post('bot/users', json=user._asdict()) - -        log.trace("Syncing updated users...") -        for user in diff.updated: -            await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) +        if diff.created: +            created: list = [user._asdict() for user in diff.created] +            await self.bot.api_client.post("bot/users", json=created) +        if diff.updated: +            updated: list = [self.patch_dict(user) for user in diff.updated] +            await self.bot.api_client.patch("bot/users/bulk_patch", json=updated) | 
