aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/sync/_syncers.py146
-rw-r--r--tests/bot/exts/backend/sync/test_users.py117
2 files changed, 107 insertions, 156 deletions
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py
index ae7d5d893..70887a217 100644
--- a/bot/exts/backend/sync/_syncers.py
+++ b/bot/exts/backend/sync/_syncers.py
@@ -2,7 +2,6 @@ import abc
import logging
import typing as t
from collections import namedtuple
-from urllib.parse import parse_qsl, urlparse
from discord import Guild
from discord.ext.commands import Context
@@ -15,7 +14,6 @@ log = logging.getLogger(__name__)
# These objects are declared as namedtuples because tuples are hashable,
# something that we make use of when diffing site roles against guild roles.
_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
-_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild'))
_Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
@@ -42,11 +40,7 @@ class Syncer(abc.ABC):
raise NotImplementedError # pragma: no cover
async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None:
- """
- Synchronise the database with the cache of `guild`.
-
- If `ctx` is given, send a message with the results.
- """
+ """If `ctx` is given, send a message with the results."""
log.info(f"Starting {self.name} syncer.")
if ctx:
@@ -136,111 +130,67 @@ class UserSyncer(Syncer):
"""Return the difference of users between the cache of `guild` and the database."""
log.trace("Getting the diff for users.")
- users = await self._get_users()
+ users_to_create = []
+ users_to_update = []
+ seen_guild_users = set()
- # Pack DB roles and guild roles into one common, hashable format.
- # They're hashable so that they're easily comparable with sets later.
- db_users = {
- user_dict['id']: _User(
- roles=tuple(sorted(user_dict.pop('roles'))),
- **user_dict
- )
- for user_dict in users
- }
- guild_users = {
- member.id: _User(
- id=member.id,
- name=member.name,
- discriminator=int(member.discriminator),
- roles=tuple(sorted(role.id for role in member.roles)),
- in_guild=True
- )
- for member in guild.members
- }
+ async for db_user in self._get_users():
+ updated_fields = {}
- users_to_create = set()
- users_to_update = set()
-
- for db_user in db_users.values():
- guild_user = guild_users.get(db_user.id)
-
- if guild_user is not None:
- if db_user != guild_user:
- fields_to_none: dict = {}
-
- for field in _User._fields:
- # Set un-changed values to None except ID to speed up API PATCH method.
- if getattr(db_user, field) == getattr(guild_user, field) and field != "id":
- fields_to_none[field] = None
-
- new_api_user = guild_user._replace(**fields_to_none)
- users_to_update.add(new_api_user)
-
- elif db_user.in_guild:
- # The user is known in the DB but not the guild, and the
- # DB currently specifies that the user is a member of the guild.
- # This means that the user has left since the last sync.
- # Update the `in_guild` attribute of the user on the site
- # to signify that the user left.
-
- # Set un-changed fields to None except ID as it is required by the API.
- fields_to_none: dict = {field: None for field in db_user._fields if field not in ["id", "in_guild"]}
- new_api_user = db_user._replace(
- in_guild=False,
- **fields_to_none
- )
- users_to_update.add(new_api_user)
-
- new_user_ids = set(guild_users.keys()) - set(db_users.keys())
- for user_id in new_user_ids:
- # The user is known on the guild but not on the API. This means
- # that the user has joined since the last sync. Create it.
- new_user = guild_users[user_id]
- users_to_create.add(new_user)
+ def maybe_update(db_field: str, guild_value: t.Union[str, int]) -> None:
+ if db_user[db_field] != guild_value:
+ updated_fields[db_field] = guild_value
- return _Diff(users_to_create, users_to_update, None)
+ if guild_user := guild.get_member(db_user["id"]):
+ seen_guild_users.add(guild_user.id)
+
+ maybe_update("name", guild_user.name)
+ maybe_update("discriminator", int(guild_user.discriminator))
+ maybe_update("in_guild", True)
- async def _get_users(self, endpoint: str = "bot/users", query_params: list = None) -> t.List[dict]:
- """GET all users recursively."""
- users = []
- response: dict = await self.bot.api_client.get(endpoint, params=query_params)
- users.extend(response["results"])
+ guild_roles = [role.id for role in guild_user.roles]
+ if set(db_user["roles"]) != set(guild_roles):
+ updated_fields["roles"] = guild_roles
- # The `response` is paginated, hence check if next page exists.
- if (next_page_url := response["next"]) is not None:
- next_endpoint, query_params = self.get_endpoint(next_page_url)
- users.extend(await self._get_users(next_endpoint, query_params))
- return users
+ elif db_user["in_guild"]:
+ updated_fields["in_guild"] = False
- @staticmethod
- def get_endpoint(url: str) -> t.Tuple[str, t.List[tuple]]:
- """Extract the API endpoint and query params from a URL."""
- url = urlparse(url)
+ if updated_fields and updated_fields not in users_to_update:
+ updated_fields["id"] = db_user["id"]
+ users_to_update.append(updated_fields)
- # Do not include starting `/` for endpoint.
- endpoint = url.path[1:]
+ for member in guild.members:
+ if member.id not in seen_guild_users:
+ new_user = {
+ "id": member.id,
+ "name": member.name,
+ "discriminator": int(member.discriminator),
+ "roles": [role.id for role in member.roles],
+ "in_guild": True
+ }
+ if new_user not in users_to_create:
+ users_to_create.append(new_user)
- # Query params.
- params = parse_qsl(url.query)
+ return _Diff(users_to_create, users_to_update, None)
- return endpoint, params
+ async def _get_users(self) -> t.AsyncIterable:
+ """GET users from database."""
+ query_params = {
+ "page": 1
+ }
+ while query_params["page"]:
+ res = await self.bot.api_client.get("bot/users", params=query_params)
+ for user in res["results"]:
+ yield user
- @staticmethod
- def patch_dict(user: _User) -> t.Dict[str, t.Union[int, str, tuple, bool]]:
- """Convert namedtuple to dict by omitting None values."""
- user_dict = {}
- for field in user._fields:
- if (value := getattr(user, field)) is not None:
- user_dict[field] = value
- return user_dict
+ query_params["page"] = res["next_page_no"]
async def _sync(self, diff: _Diff) -> None:
"""Synchronise the database with the user cache of `guild`."""
log.trace("Syncing created users...")
if diff.created:
- created = [user._asdict() for user in diff.created]
- await self.bot.api_client.post("bot/users", json=created)
+ await self.bot.api_client.post("bot/users", json=diff.created)
+
log.trace("Syncing updated users...")
if diff.updated:
- updated = [self.patch_dict(user) for user in diff.updated]
- await self.bot.api_client.patch("bot/users/bulk_patch", json=updated)
+ await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated)
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index c3a486743..9f380a15d 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,6 +1,6 @@
import unittest
-from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User
+from bot.exts.backend.sync._syncers import UserSyncer, _Diff
from tests import helpers
@@ -9,22 +9,12 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("roles", [666])
kwargs.setdefault("in_guild", True)
return kwargs
-def fake_none_user(**kwargs):
- kwargs.setdefault("id", None)
- kwargs.setdefault("name", None)
- kwargs.setdefault("discriminator", None)
- kwargs.setdefault("roles", None)
- kwargs.setdefault("in_guild", None)
-
- return kwargs
-
-
class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between users in the DB and users in the Guild cache."""
@@ -49,18 +39,26 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
return guild
+ @staticmethod
+ def get_mock_member(member: dict):
+ member = member.copy()
+ del member["in_guild"]
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+ return mock_member
+
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": []
}
guild = self.get_guild()
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -68,66 +66,75 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""No differences should be found if the users in the guild and DB are identical."""
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": [fake_user()]
}
guild = self.get_guild(fake_user())
+ guild.get_member.return_value = self.get_mock_member(fake_user())
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_updated_users(self):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
- updated_user_none = fake_none_user(id=99, name="new")
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": [fake_user(id=99, name="old"), fake_user()]
}
guild = self.get_guild(updated_user, fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(updated_user),
+ self.get_mock_member(fake_user())
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**updated_user_none)}, None)
+ expected_diff = ([], [{"id": 99, "name": "new"}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_users(self):
- """Only new users should be added to the 'created' set of the diff."""
+ """Only new users should be added to the 'created' list of the diff."""
new_user = fake_user(id=99, name="new")
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": [fake_user()]
}
guild = self.get_guild(fake_user(), new_user)
-
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(new_user)
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, set(), None)
+ expected_diff = ([new_user], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- leaving_user_none = fake_none_user(id=63, in_guild=False)
-
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": [fake_user(), fake_user(id=63)]
}
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**leaving_user_none)}, None)
+ expected_diff = ([], [{"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
@@ -136,42 +143,41 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
new_user = fake_user(id=99, name="new")
updated_user = fake_user(id=55, name="updated")
- updated_user_none = fake_none_user(id=55, name="updated")
-
- leaving_user_none = fake_none_user(id=63, in_guild=False)
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": [fake_user(), fake_user(id=55), fake_user(id=63)]
}
guild = self.get_guild(fake_user(), new_user, updated_user)
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(updated_user),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (
- {_User(**new_user)},
- {
- _User(**updated_user_none),
- _User(**leaving_user_none)
- },
- None
- )
+ expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_db_users_not_in_guild(self):
- """When the DB knows a user the guild doesn't, no difference is found."""
+ """When the DB knows a user, but the guild doesn't, no difference is found."""
self.bot.api_client.get.return_value = {
"count": 3,
- "next": None,
- "previous": None,
+ "next_page_no": None,
+ "previous_page_no": None,
"results": [fake_user(), fake_user(id=63, in_guild=False)]
}
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -187,13 +193,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(user_tuples, set(), None)
+ diff = _Diff(users, [], None)
await self.syncer._sync(diff)
- # Convert namedtuples to dicts as done in self.syncer._sync method.
- created = [user._asdict() for user in diff.created]
- self.bot.api_client.post.assert_called_once_with("bot/users", json=created)
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
@@ -202,12 +205,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only PUT requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(set(), user_tuples, None)
+ diff = _Diff([], users, None)
await self.syncer._sync(diff)
- updated = [self.syncer.patch_dict(user) for user in diff.updated]
- self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=updated)
+ self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()