diff options
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 10 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 29 | 
2 files changed, 26 insertions, 13 deletions
| diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 2eb9f9971..c9f2d2da8 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -5,12 +5,15 @@ from collections import namedtuple  from discord import Guild  from discord.ext.commands import Context +from more_itertools import chunked  import bot  from bot.api import ResponseCodeError  log = logging.getLogger(__name__) +CHUNK_SIZE = 1000 +  # These objects are declared as namedtuples because tuples are hashable,  # something that we make use of when diffing site roles against guild roles.  _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) @@ -207,10 +210,13 @@ class UserSyncer(Syncer):      @staticmethod      async def _sync(diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`.""" +        # Using asyncio.gather would still consume too many resources on the site.          log.trace("Syncing created users...")          if diff.created: -            await bot.instance.api_client.post("bot/users", json=diff.created) +            for chunk in chunked(diff.created, CHUNK_SIZE): +                await bot.instance.api_client.post("bot/users", json=chunk)          log.trace("Syncing updated users...")          if diff.updated: -            await bot.instance.api_client.patch("bot/users/bulk_patch", json=diff.updated) +            for chunk in chunked(diff.updated, CHUNK_SIZE): +                await bot.instance.api_client.patch("bot/users/bulk_patch", json=chunk) diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 61673e1bb..27932be95 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -188,30 +188,37 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync users."""      def setUp(self): -        patcher = mock.patch("bot.instance", new=helpers.MockBot()) -        self.bot = patcher.start() -        self.addCleanup(patcher.stop) +        bot_patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = bot_patcher.start() +        self.addCleanup(bot_patcher.stop) + +        chunk_patcher = mock.patch("bot.exts.backend.sync._syncers.CHUNK_SIZE", 2) +        self.chunk_size = chunk_patcher.start() +        self.addCleanup(chunk_patcher.stop) + +        self.chunk_count = 2 +        self.users = [fake_user(id=i) for i in range(self.chunk_size * self.chunk_count)]      async def test_sync_created_users(self):          """Only POST requests should be made with the correct payload.""" -        users = [fake_user(id=111), fake_user(id=222)] - -        diff = _Diff(users, [], None) +        diff = _Diff(self.users, [], None)          await UserSyncer._sync(diff) -        self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created) +        self.bot.api_client.post.assert_any_call("bot/users", json=diff.created[:self.chunk_size]) +        self.bot.api_client.post.assert_any_call("bot/users", json=diff.created[self.chunk_size:]) +        self.assertEqual(self.bot.api_client.post.call_count, self.chunk_count)          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called()      async def test_sync_updated_users(self):          """Only PUT requests should be made with the correct payload.""" -        users = [fake_user(id=111), fake_user(id=222)] - -        diff = _Diff([], users, None) +        diff = _Diff([], self.users, None)          await UserSyncer._sync(diff) -        self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated) +        self.bot.api_client.patch.assert_any_call("bot/users/bulk_patch", json=diff.updated[:self.chunk_size]) +        self.bot.api_client.patch.assert_any_call("bot/users/bulk_patch", json=diff.updated[self.chunk_size:]) +        self.assertEqual(self.bot.api_client.patch.call_count, self.chunk_count)          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() | 
