diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/sync/syncers.py | 30 | 
1 files changed, 29 insertions, 1 deletions
| diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index f7ba811bc..156c32a15 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -4,6 +4,7 @@ import logging  import typing as t  from collections import namedtuple  from functools import partial +from urllib.parse import parse_qsl, urlparse  import discord  from discord import Guild, HTTPException, Member, Message, Reaction, User @@ -287,7 +288,8 @@ class UserSyncer(Syncer):      async def _get_diff(self, guild: Guild) -> _Diff:          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") -        users = await self.bot.api_client.get('bot/users') + +        users = await self._get_users()          # Pack DB roles and guild roles into one common, hashable format.          # They're hashable so that they're easily comparable with sets later. @@ -336,6 +338,32 @@ class UserSyncer(Syncer):          return _Diff(users_to_create, users_to_update, None) +    async def _get_users(self, endpoint: str = "bot/users", query_params: dict = None) -> t.List[dict]: +        """GET all users recursively.""" +        users: list = [] +        response: dict = await self.bot.api_client.get(endpoint, params=query_params) +        users.extend(response["results"]) + +        # The `response` is paginated, hence check if next page exists. +        if (next_page_url := response["next"]) is not None: +            next_endpoint, query_params = self.get_endpoint(next_page_url) +            users.extend(await self._get_users(next_endpoint, query_params)) + +        return users + +    @staticmethod +    def get_endpoint(url: str) -> tuple: +        """Extract the API endpoint and query params from a URL.""" +        url = urlparse(url) + +        # Do not include starting `/` for endpoint. +        endpoint = url.path[1:] + +        # Query params. +        params = parse_qsl(url.query) + +        return endpoint, params +      async def _sync(self, diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...") | 
