diff options
author | 2020-10-07 23:02:39 +0530 | |
---|---|---|
committer | 2020-10-07 23:02:39 +0530 | |
commit | 7e7a801366e2bf8f1190fae91f93729b33f32895 (patch) | |
tree | a8f55b9e96693ca4a7173cf5d00804226363b114 | |
parent | Merge remote-tracking branch 'upstream/master' into smart_syncing_users (diff) |
improve code efficiency and use updated API changes to pagination
-rw-r--r-- | bot/exts/backend/sync/_syncers.py | 146 |
1 files changed, 48 insertions, 98 deletions
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index ae7d5d893..70887a217 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,7 +2,6 @@ import abc import logging import typing as t from collections import namedtuple -from urllib.parse import parse_qsl, urlparse from discord import Guild from discord.ext.commands import Context @@ -15,7 +14,6 @@ log = logging.getLogger(__name__) # These objects are declared as namedtuples because tuples are hashable, # something that we make use of when diffing site roles against guild roles. _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) _Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) @@ -42,11 +40,7 @@ class Syncer(abc.ABC): raise NotImplementedError # pragma: no cover async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If `ctx` is given, send a message with the results. - """ + """If `ctx` is given, send a message with the results.""" log.info(f"Starting {self.name} syncer.") if ctx: @@ -136,111 +130,67 @@ class UserSyncer(Syncer): """Return the difference of users between the cache of `guild` and the database.""" log.trace("Getting the diff for users.") - users = await self._get_users() + users_to_create = [] + users_to_update = [] + seen_guild_users = set() - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } + async for db_user in self._get_users(): + updated_fields = {} - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - - if guild_user is not None: - if db_user != guild_user: - fields_to_none: dict = {} - - for field in _User._fields: - # Set un-changed values to None except ID to speed up API PATCH method. - if getattr(db_user, field) == getattr(guild_user, field) and field != "id": - fields_to_none[field] = None - - new_api_user = guild_user._replace(**fields_to_none) - users_to_update.add(new_api_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - - # Set un-changed fields to None except ID as it is required by the API. - fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]} - new_api_user = db_user._replace( - in_guild=False, - **fields_to_none - ) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) + def maybe_update(db_field: str, guild_value: t.Union[str, int]) -> None: + if db_user[db_field] != guild_value: + updated_fields[db_field] = guild_value - return _Diff(users_to_create, users_to_update, None) + if guild_user := guild.get_member(db_user["id"]): + seen_guild_users.add(guild_user.id) + + maybe_update("name", guild_user.name) + maybe_update("discriminator", int(guild_user.discriminator)) + maybe_update("in_guild", True) - async def _get_users(self, endpoint: str = "bot/users", query_params: list = None) -> t.List[dict]: - """GET all users recursively.""" - users = [] - response: dict = await self.bot.api_client.get(endpoint, params=query_params) - users.extend(response["results"]) + guild_roles = [role.id for role in guild_user.roles] + if set(db_user["roles"]) != set(guild_roles): + updated_fields["roles"] = guild_roles - # The `response` is paginated, hence check if next page exists. - if (next_page_url := response["next"]) is not None: - next_endpoint, query_params = self.get_endpoint(next_page_url) - users.extend(await self._get_users(next_endpoint, query_params)) - return users + elif db_user["in_guild"]: + updated_fields["in_guild"] = False - @staticmethod - def get_endpoint(url: str) -> t.Tuple[str, t.List[tuple]]: - """Extract the API endpoint and query params from a URL.""" - url = urlparse(url) + if updated_fields and updated_fields not in users_to_update: + updated_fields["id"] = db_user["id"] + users_to_update.append(updated_fields) - # Do not include starting `/` for endpoint. - endpoint = url.path[1:] + for member in guild.members: + if member.id not in seen_guild_users: + new_user = { + "id": member.id, + "name": member.name, + "discriminator": int(member.discriminator), + "roles": [role.id for role in member.roles], + "in_guild": True + } + if new_user not in users_to_create: + users_to_create.append(new_user) - # Query params. - params = parse_qsl(url.query) + return _Diff(users_to_create, users_to_update, None) - return endpoint, params + async def _get_users(self) -> t.AsyncIterable: + """GET users from database.""" + query_params = { + "page": 1 + } + while query_params["page"]: + res = await self.bot.api_client.get("bot/users", params=query_params) + for user in res["results"]: + yield user - @staticmethod - def patch_dict(user: _User) -> t.Dict[str, t.Union[int, str, tuple, bool]]: - """Convert namedtuple to dict by omitting None values.""" - user_dict = {} - for field in user._fields: - if (value := getattr(user, field)) is not None: - user_dict[field] = value - return user_dict + query_params["page"] = res["next_page_no"] async def _sync(self, diff: _Diff) -> None: """Synchronise the database with the user cache of `guild`.""" log.trace("Syncing created users...") if diff.created: - created = [user._asdict() for user in diff.created] - await self.bot.api_client.post("bot/users", json=created) + await self.bot.api_client.post("bot/users", json=diff.created) + log.trace("Syncing updated users...") if diff.updated: - updated = [self.patch_dict(user) for user in diff.updated] - await self.bot.api_client.patch("bot/users/bulk_patch", json=updated) + await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated) |