diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 69 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 95 | 
2 files changed, 136 insertions, 28 deletions
| diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 3d4a09df3..759af96d7 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,6 +2,8 @@ import abc  import logging  import typing as t  from collections import namedtuple +from functools import partial +from urllib.parse import parse_qsl, urlparse  from discord import Guild  from discord.ext.commands import Context @@ -134,7 +136,8 @@ class UserSyncer(Syncer):      async def _get_diff(self, guild: Guild) -> _Diff:          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") -        users = await self.bot.api_client.get('bot/users') + +        users = await self._get_users()          # Pack DB roles and guild roles into one common, hashable format.          # They're hashable so that they're easily comparable with sets later. @@ -161,9 +164,18 @@ class UserSyncer(Syncer):          for db_user in db_users.values():              guild_user = guild_users.get(db_user.id) +              if guild_user is not None:                  if db_user != guild_user: -                    users_to_update.add(guild_user) +                    fields_to_none: dict = {} + +                    for field in _User._fields: +                        # Set un-changed values to None except ID to speed up API PATCH method. +                        if getattr(db_user, field) == getattr(guild_user, field) and field != "id": +                            fields_to_none[field] = None + +                    new_api_user = guild_user._replace(**fields_to_none) +                    users_to_update.add(new_api_user)              elif db_user.in_guild:                  # The user is known in the DB but not the guild, and the @@ -171,7 +183,13 @@ class UserSyncer(Syncer):                  # This means that the user has left since the last sync.                  # Update the `in_guild` attribute of the user on the site                  # to signify that the user left. -                new_api_user = db_user._replace(in_guild=False) + +                # Set un-changed fields to None except ID as it is required by the API. +                fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]} +                new_api_user = db_user._replace( +                    in_guild=False, +                    **fields_to_none +                )                  users_to_update.add(new_api_user)          new_user_ids = set(guild_users.keys()) - set(db_users.keys()) @@ -183,12 +201,47 @@ class UserSyncer(Syncer):          return _Diff(users_to_create, users_to_update, None) +    async def _get_users(self, endpoint: str = "bot/users", query_params: list = None) -> t.List[dict]: +        """GET all users recursively.""" +        users = [] +        response: dict = await self.bot.api_client.get(endpoint, params=query_params) +        users.extend(response["results"]) + +        # The `response` is paginated, hence check if next page exists. +        if (next_page_url := response["next"]) is not None: +            next_endpoint, query_params = self.get_endpoint(next_page_url) +            users.extend(await self._get_users(next_endpoint, query_params)) +        return users + +    @staticmethod +    def get_endpoint(url: str) -> t.Tuple[str, t.List[tuple]]: +        """Extract the API endpoint and query params from a URL.""" +        url = urlparse(url) + +        # Do not include starting `/` for endpoint. +        endpoint = url.path[1:] + +        # Query params. +        params = parse_qsl(url.query) + +        return endpoint, params + +    @staticmethod +    def patch_dict(user: _User) -> t.Dict[str, t.Union[int, str, tuple, bool]]: +        """Convert namedtuple to dict by omitting None values.""" +        user_dict = {} +        for field in user._fields: +            if (value := getattr(user, field)) is not None: +                user_dict[field] = value +        return user_dict +      async def _sync(self, diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...") -        for user in diff.created: -            await self.bot.api_client.post('bot/users', json=user._asdict()) - +        if diff.created: +            created = [user._asdict() for user in diff.created] +            await self.bot.api_client.post("bot/users", json=created)          log.trace("Syncing updated users...") -        for user in diff.updated: -            await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) +        if diff.updated: +            updated = [self.patch_dict(user) for user in diff.updated] +            await self.bot.api_client.patch("bot/users/bulk_patch", json=updated) diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index c0a1da35c..c3a486743 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,5 +1,4 @@  import unittest -from unittest import mock  from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User  from tests import helpers @@ -16,6 +15,16 @@ def fake_user(**kwargs):      return kwargs +def fake_none_user(**kwargs): +    kwargs.setdefault("id", None) +    kwargs.setdefault("name", None) +    kwargs.setdefault("discriminator", None) +    kwargs.setdefault("roles", None) +    kwargs.setdefault("in_guild", None) + +    return kwargs + +  class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between users in the DB and users in the Guild cache.""" @@ -42,6 +51,12 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      async def test_empty_diff_for_no_users(self):          """When no users are given, an empty diff should be returned.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [] +        }          guild = self.get_guild()          actual_diff = await self.syncer._get_diff(guild) @@ -51,7 +66,12 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      async def test_empty_diff_for_identical_users(self):          """No differences should be found if the users in the guild and DB are identical.""" -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user())          actual_diff = await self.syncer._get_diff(guild) @@ -62,12 +82,18 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      async def test_diff_for_updated_users(self):          """Only updated users should be added to the 'updated' set of the diff."""          updated_user = fake_user(id=99, name="new") - -        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        updated_user_none = fake_none_user(id=99, name="new") + +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [fake_user(id=99, name="old"), fake_user()] +        }          guild = self.get_guild(updated_user, fake_user())          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**updated_user)}, None) +        expected_diff = (set(), {_User(**updated_user_none)}, None)          self.assertEqual(actual_diff, expected_diff) @@ -75,7 +101,12 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          """Only new users should be added to the 'created' set of the diff."""          new_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user(), new_user)          actual_diff = await self.syncer._get_diff(guild) @@ -85,33 +116,58 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        leaving_user = fake_user(id=63, in_guild=False) - -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        leaving_user_none = fake_none_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [fake_user(), fake_user(id=63)] +        }          guild = self.get_guild(fake_user())          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**leaving_user)}, None) +        expected_diff = (set(), {_User(**leaving_user_none)}, None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_updated_and_leaving_users(self):          """When users are added, updated, and removed, all of them are returned properly."""          new_user = fake_user(id=99, name="new") +          updated_user = fake_user(id=55, name="updated") -        leaving_user = fake_user(id=63, in_guild=False) +        updated_user_none = fake_none_user(id=55, name="updated") + +        leaving_user_none = fake_none_user(id=63, in_guild=False) -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [fake_user(), fake_user(id=55), fake_user(id=63)] +        }          guild = self.get_guild(fake_user(), new_user, updated_user)          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) +        expected_diff = ( +            {_User(**new_user)}, +            { +                _User(**updated_user_none), +                _User(**leaving_user_none) +            }, +            None +        )          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_db_users_not_in_guild(self):          """When the DB knows a user the guild doesn't, no difference is found.""" -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next": None, +            "previous": None, +            "results": [fake_user(), fake_user(id=63, in_guild=False)] +        }          guild = self.get_guild(fake_user())          actual_diff = await self.syncer._get_diff(guild) @@ -135,9 +191,9 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          diff = _Diff(user_tuples, set(), None)          await self.syncer._sync(diff) -        calls = [mock.call("bot/users", json=user) for user in users] -        self.bot.api_client.post.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.post.call_count, len(users)) +        # Convert namedtuples to dicts as done in self.syncer._sync method. +        created = [user._asdict() for user in diff.created] +        self.bot.api_client.post.assert_called_once_with("bot/users", json=created)          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() @@ -150,9 +206,8 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          diff = _Diff(set(), user_tuples, None)          await self.syncer._sync(diff) -        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] -        self.bot.api_client.put.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.put.call_count, len(users)) +        updated = [self.syncer.patch_dict(user) for user in diff.updated] +        self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=updated)          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() | 
