aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/sync/_syncers.py70
1 files changed, 61 insertions, 9 deletions
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py
index f7ba811bc..512efaa3d 100644
--- a/bot/exts/backend/sync/_syncers.py
+++ b/bot/exts/backend/sync/_syncers.py
@@ -4,6 +4,7 @@ import logging
import typing as t
from collections import namedtuple
from functools import partial
+from urllib.parse import parse_qsl, urlparse
import discord
from discord import Guild, HTTPException, Member, Message, Reaction, User
@@ -287,7 +288,8 @@ class UserSyncer(Syncer):
async def _get_diff(self, guild: Guild) -> _Diff:
"""Return the difference of users between the cache of `guild` and the database."""
log.trace("Getting the diff for users.")
- users = await self.bot.api_client.get('bot/users')
+
+ users = await self._get_users()
# Pack DB roles and guild roles into one common, hashable format.
# They're hashable so that they're easily comparable with sets later.
@@ -314,9 +316,18 @@ class UserSyncer(Syncer):
for db_user in db_users.values():
guild_user = guild_users.get(db_user.id)
+
if guild_user is not None:
if db_user != guild_user:
- users_to_update.add(guild_user)
+ fields_to_none: dict = {}
+
+ for field in _User._fields:
+ # Set un-changed values to None except ID to speed up API PATCH method.
+ if getattr(db_user, field) == getattr(guild_user, field) and field != "id":
+ fields_to_none[field] = None
+
+ new_api_user = guild_user._replace(**fields_to_none)
+ users_to_update.add(new_api_user)
elif db_user.in_guild:
# The user is known in the DB but not the guild, and the
@@ -324,7 +335,13 @@ class UserSyncer(Syncer):
# This means that the user has left since the last sync.
# Update the `in_guild` attribute of the user on the site
# to signify that the user left.
- new_api_user = db_user._replace(in_guild=False)
+
+ # Set un-changed fields to None except ID as it is required by the API.
+ fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]}
+ new_api_user = db_user._replace(
+ in_guild=False,
+ **fields_to_none
+ )
users_to_update.add(new_api_user)
new_user_ids = set(guild_users.keys()) - set(db_users.keys())
@@ -336,12 +353,47 @@ class UserSyncer(Syncer):
return _Diff(users_to_create, users_to_update, None)
+ async def _get_users(self, endpoint: str = "bot/users", query_params: dict = None) -> t.List[dict]:
+ """GET all users recursively."""
+ users: list = []
+ response: dict = await self.bot.api_client.get(endpoint, params=query_params)
+ users.extend(response["results"])
+
+ # The `response` is paginated, hence check if next page exists.
+ if (next_page_url := response["next"]) is not None:
+ next_endpoint, query_params = self.get_endpoint(next_page_url)
+ users.extend(await self._get_users(next_endpoint, query_params))
+
+ return users
+
+ @staticmethod
+ def get_endpoint(url: str) -> tuple:
+ """Extract the API endpoint and query params from a URL."""
+ url = urlparse(url)
+
+ # Do not include starting `/` for endpoint.
+ endpoint = url.path[1:]
+
+ # Query params.
+ params = parse_qsl(url.query)
+
+ return endpoint, params
+
+ @staticmethod
+ def patch_dict(user: _User) -> dict:
+ """Convert namedtuple to dict by omitting None values."""
+ user_dict: dict = {}
+ for field in user._fields:
+ if (value := getattr(user, field)) is not None:
+ user_dict[field] = value
+ return user_dict
+
async def _sync(self, diff: _Diff) -> None:
"""Synchronise the database with the user cache of `guild`."""
log.trace("Syncing created users...")
- for user in diff.created:
- await self.bot.api_client.post('bot/users', json=user._asdict())
-
- log.trace("Syncing updated users...")
- for user in diff.updated:
- await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict())
+ if diff.created:
+ created: list = [user._asdict() for user in diff.created]
+ await self.bot.api_client.post("bot/users", json=created)
+ if diff.updated:
+ updated: list = [self.patch_dict(user) for user in diff.updated]
+ await self.bot.api_client.patch("bot/users/bulk_patch", json=updated)