aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/sync/_syncers.py68
-rw-r--r--tests/bot/exts/backend/sync/test_users.py95
2 files changed, 135 insertions, 28 deletions
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py
index 3d4a09df3..ae7d5d893 100644
--- a/bot/exts/backend/sync/_syncers.py
+++ b/bot/exts/backend/sync/_syncers.py
@@ -2,6 +2,7 @@ import abc
import logging
import typing as t
from collections import namedtuple
+from urllib.parse import parse_qsl, urlparse
from discord import Guild
from discord.ext.commands import Context
@@ -134,7 +135,8 @@ class UserSyncer(Syncer):
async def _get_diff(self, guild: Guild) -> _Diff:
"""Return the difference of users between the cache of `guild` and the database."""
log.trace("Getting the diff for users.")
- users = await self.bot.api_client.get('bot/users')
+
+ users = await self._get_users()
# Pack DB roles and guild roles into one common, hashable format.
# They're hashable so that they're easily comparable with sets later.
@@ -161,9 +163,18 @@ class UserSyncer(Syncer):
for db_user in db_users.values():
guild_user = guild_users.get(db_user.id)
+
if guild_user is not None:
if db_user != guild_user:
- users_to_update.add(guild_user)
+ fields_to_none: dict = {}
+
+ for field in _User._fields:
+ # Set un-changed values to None except ID to speed up API PATCH method.
+ if getattr(db_user, field) == getattr(guild_user, field) and field != "id":
+ fields_to_none[field] = None
+
+ new_api_user = guild_user._replace(**fields_to_none)
+ users_to_update.add(new_api_user)
elif db_user.in_guild:
# The user is known in the DB but not the guild, and the
@@ -171,7 +182,13 @@ class UserSyncer(Syncer):
# This means that the user has left since the last sync.
# Update the `in_guild` attribute of the user on the site
# to signify that the user left.
- new_api_user = db_user._replace(in_guild=False)
+
+ # Set un-changed fields to None except ID as it is required by the API.
+ fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]}
+ new_api_user = db_user._replace(
+ in_guild=False,
+ **fields_to_none
+ )
users_to_update.add(new_api_user)
new_user_ids = set(guild_users.keys()) - set(db_users.keys())
@@ -183,12 +200,47 @@ class UserSyncer(Syncer):
return _Diff(users_to_create, users_to_update, None)
+ async def _get_users(self, endpoint: str = "bot/users", query_params: list = None) -> t.List[dict]:
+ """GET all users recursively."""
+ users = []
+ response: dict = await self.bot.api_client.get(endpoint, params=query_params)
+ users.extend(response["results"])
+
+ # The `response` is paginated, hence check if next page exists.
+ if (next_page_url := response["next"]) is not None:
+ next_endpoint, query_params = self.get_endpoint(next_page_url)
+ users.extend(await self._get_users(next_endpoint, query_params))
+ return users
+
+ @staticmethod
+ def get_endpoint(url: str) -> t.Tuple[str, t.List[tuple]]:
+ """Extract the API endpoint and query params from a URL."""
+ url = urlparse(url)
+
+ # Do not include starting `/` for endpoint.
+ endpoint = url.path[1:]
+
+ # Query params.
+ params = parse_qsl(url.query)
+
+ return endpoint, params
+
+ @staticmethod
+ def patch_dict(user: _User) -> t.Dict[str, t.Union[int, str, tuple, bool]]:
+ """Convert namedtuple to dict by omitting None values."""
+ user_dict = {}
+ for field in user._fields:
+ if (value := getattr(user, field)) is not None:
+ user_dict[field] = value
+ return user_dict
+
async def _sync(self, diff: _Diff) -> None:
"""Synchronise the database with the user cache of `guild`."""
log.trace("Syncing created users...")
- for user in diff.created:
- await self.bot.api_client.post('bot/users', json=user._asdict())
-
+ if diff.created:
+ created = [user._asdict() for user in diff.created]
+ await self.bot.api_client.post("bot/users", json=created)
log.trace("Syncing updated users...")
- for user in diff.updated:
- await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict())
+ if diff.updated:
+ updated = [self.patch_dict(user) for user in diff.updated]
+ await self.bot.api_client.patch("bot/users/bulk_patch", json=updated)
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index c0a1da35c..c3a486743 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,5 +1,4 @@
import unittest
-from unittest import mock
from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User
from tests import helpers
@@ -16,6 +15,16 @@ def fake_user(**kwargs):
return kwargs
+def fake_none_user(**kwargs):
+ kwargs.setdefault("id", None)
+ kwargs.setdefault("name", None)
+ kwargs.setdefault("discriminator", None)
+ kwargs.setdefault("roles", None)
+ kwargs.setdefault("in_guild", None)
+
+ return kwargs
+
+
class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between users in the DB and users in the Guild cache."""
@@ -42,6 +51,12 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": []
+ }
guild = self.get_guild()
actual_diff = await self.syncer._get_diff(guild)
@@ -51,7 +66,12 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
async def test_empty_diff_for_identical_users(self):
"""No differences should be found if the users in the guild and DB are identical."""
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user())
actual_diff = await self.syncer._get_diff(guild)
@@ -62,12 +82,18 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
async def test_diff_for_updated_users(self):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
-
- self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ updated_user_none = fake_none_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": [fake_user(id=99, name="old"), fake_user()]
+ }
guild = self.get_guild(updated_user, fake_user())
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**updated_user)}, None)
+ expected_diff = (set(), {_User(**updated_user_none)}, None)
self.assertEqual(actual_diff, expected_diff)
@@ -75,7 +101,12 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Only new users should be added to the 'created' set of the diff."""
new_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user(), new_user)
actual_diff = await self.syncer._get_diff(guild)
@@ -85,33 +116,58 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- leaving_user = fake_user(id=63, in_guild=False)
-
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ leaving_user_none = fake_none_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": [fake_user(), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user())
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**leaving_user)}, None)
+ expected_diff = (set(), {_User(**leaving_user_none)}, None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_updated_and_leaving_users(self):
"""When users are added, updated, and removed, all of them are returned properly."""
new_user = fake_user(id=99, name="new")
+
updated_user = fake_user(id=55, name="updated")
- leaving_user = fake_user(id=63, in_guild=False)
+ updated_user_none = fake_none_user(id=55, name="updated")
+
+ leaving_user_none = fake_none_user(id=63, in_guild=False)
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": [fake_user(), fake_user(id=55), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user(), new_user, updated_user)
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+ expected_diff = (
+ {_User(**new_user)},
+ {
+ _User(**updated_user_none),
+ _User(**leaving_user_none)
+ },
+ None
+ )
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_db_users_not_in_guild(self):
"""When the DB knows a user the guild doesn't, no difference is found."""
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next": None,
+ "previous": None,
+ "results": [fake_user(), fake_user(id=63, in_guild=False)]
+ }
guild = self.get_guild(fake_user())
actual_diff = await self.syncer._get_diff(guild)
@@ -135,9 +191,9 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
diff = _Diff(user_tuples, set(), None)
await self.syncer._sync(diff)
- calls = [mock.call("bot/users", json=user) for user in users]
- self.bot.api_client.post.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.post.call_count, len(users))
+ # Convert namedtuples to dicts as done in self.syncer._sync method.
+ created = [user._asdict() for user in diff.created]
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=created)
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
@@ -150,9 +206,8 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
diff = _Diff(set(), user_tuples, None)
await self.syncer._sync(diff)
- calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
- self.bot.api_client.put.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.put.call_count, len(users))
+ updated = [self.syncer.patch_dict(user) for user in diff.updated]
+ self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=updated)
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()