diff options
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 70 |
1 files changed, 61 insertions, 9 deletions
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index f7ba811bc..512efaa3d 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -4,6 +4,7 @@ import logging import typing as t from collections import namedtuple from functools import partial +from urllib.parse import parse_qsl, urlparse import discord from discord import Guild, HTTPException, Member, Message, Reaction, User @@ -287,7 +288,8 @@ class UserSyncer(Syncer): async def _get_diff(self, guild: Guild) -> _Diff: """Return the difference of users between the cache of `guild` and the database.""" log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') + + users = await self._get_users() # Pack DB roles and guild roles into one common, hashable format. # They're hashable so that they're easily comparable with sets later. @@ -314,9 +316,18 @@ class UserSyncer(Syncer): for db_user in db_users.values(): guild_user = guild_users.get(db_user.id) + if guild_user is not None: if db_user != guild_user: - users_to_update.add(guild_user) + fields_to_none: dict = {} + + for field in _User._fields: + # Set un-changed values to None except ID to speed up API PATCH method. + if getattr(db_user, field) == getattr(guild_user, field) and field != "id": + fields_to_none[field] = None + + new_api_user = guild_user._replace(**fields_to_none) + users_to_update.add(new_api_user) elif db_user.in_guild: # The user is known in the DB but not the guild, and the @@ -324,7 +335,13 @@ class UserSyncer(Syncer): # This means that the user has left since the last sync. # Update the `in_guild` attribute of the user on the site # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) + + # Set un-changed fields to None except ID as it is required by the API. + fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]} + new_api_user = db_user._replace( + in_guild=False, + **fields_to_none + ) users_to_update.add(new_api_user) new_user_ids = set(guild_users.keys()) - set(db_users.keys()) @@ -336,12 +353,47 @@ class UserSyncer(Syncer): return _Diff(users_to_create, users_to_update, None) + async def _get_users(self, endpoint: str = "bot/users", query_params: dict = None) -> t.List[dict]: + """GET all users recursively.""" + users: list = [] + response: dict = await self.bot.api_client.get(endpoint, params=query_params) + users.extend(response["results"]) + + # The `response` is paginated, hence check if next page exists. + if (next_page_url := response["next"]) is not None: + next_endpoint, query_params = self.get_endpoint(next_page_url) + users.extend(await self._get_users(next_endpoint, query_params)) + + return users + + @staticmethod + def get_endpoint(url: str) -> tuple: + """Extract the API endpoint and query params from a URL.""" + url = urlparse(url) + + # Do not include starting `/` for endpoint. + endpoint = url.path[1:] + + # Query params. + params = parse_qsl(url.query) + + return endpoint, params + + @staticmethod + def patch_dict(user: _User) -> dict: + """Convert namedtuple to dict by omitting None values.""" + user_dict: dict = {} + for field in user._fields: + if (value := getattr(user, field)) is not None: + user_dict[field] = value + return user_dict + async def _sync(self, diff: _Diff) -> None: """Synchronise the database with the user cache of `guild`.""" log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) + if diff.created: + created: list = [user._asdict() for user in diff.created] + await self.bot.api_client.post("bot/users", json=created) + if diff.updated: + updated: list = [self.patch_dict(user) for user in diff.updated] + await self.bot.api_client.patch("bot/users/bulk_patch", json=updated) |