diff options
| author | 2020-09-21 19:21:13 +0530 | |
|---|---|---|
| committer | 2020-09-21 19:21:13 +0530 | |
| commit | 681027b61663bcdff5b174aa3e06f34b54f05349 (patch) | |
| tree | a091ea8db162315459fd00f0f4a7f33dfeda1762 | |
| parent | Merge pull request #1158 from python-discord/config-update (diff) | |
refactor code to GET users from site endpoint `bot/users` with pagination
Added method to recursively GET users if paginated and another method to parse URL and return endpoint and query parameters.
| -rw-r--r-- | bot/cogs/sync/syncers.py | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index f7ba811bc..156c32a15 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -4,6 +4,7 @@ import logging import typing as t from collections import namedtuple from functools import partial +from urllib.parse import parse_qsl, urlparse import discord from discord import Guild, HTTPException, Member, Message, Reaction, User @@ -287,7 +288,8 @@ class UserSyncer(Syncer): async def _get_diff(self, guild: Guild) -> _Diff: """Return the difference of users between the cache of `guild` and the database.""" log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') + + users = await self._get_users() # Pack DB roles and guild roles into one common, hashable format. # They're hashable so that they're easily comparable with sets later. @@ -336,6 +338,32 @@ class UserSyncer(Syncer): return _Diff(users_to_create, users_to_update, None) + async def _get_users(self, endpoint: str = "bot/users", query_params: dict = None) -> t.List[dict]: + """GET all users recursively.""" + users: list = [] + response: dict = await self.bot.api_client.get(endpoint, params=query_params) + users.extend(response["results"]) + + # The `response` is paginated, hence check if next page exists. + if (next_page_url := response["next"]) is not None: + next_endpoint, query_params = self.get_endpoint(next_page_url) + users.extend(await self._get_users(next_endpoint, query_params)) + + return users + + @staticmethod + def get_endpoint(url: str) -> tuple: + """Extract the API endpoint and query params from a URL.""" + url = urlparse(url) + + # Do not include starting `/` for endpoint. + endpoint = url.path[1:] + + # Query params. + params = parse_qsl(url.query) + + return endpoint, params + async def _sync(self, diff: _Diff) -> None: """Synchronise the database with the user cache of `guild`.""" log.trace("Syncing created users...") |