diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/backend/sync/_cog.py | 4 | ||||
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 42 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_base.py | 12 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py | 20 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_roles.py | 14 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 15 | 
6 files changed, 62 insertions, 45 deletions
diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index 6e85e2b7d..b71ed3e69 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -18,8 +18,8 @@ class Sync(Cog):      def __init__(self, bot: Bot) -> None:          self.bot = bot -        self.role_syncer = _syncers.RoleSyncer(self.bot) -        self.user_syncer = _syncers.UserSyncer(self.bot) +        self.role_syncer = _syncers.RoleSyncer() +        self.user_syncer = _syncers.UserSyncer()          self.bot.loop.create_task(self.sync_guild()) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 38468c2b1..bdd76806b 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -6,8 +6,8 @@ from collections import namedtuple  from discord import Guild  from discord.ext.commands import Context +import bot  from bot.api import ResponseCodeError -from bot.bot import Bot  log = logging.getLogger(__name__) @@ -20,22 +20,21 @@ _Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))  class Syncer(abc.ABC):      """Base class for synchronising the database with objects in the Discord cache.""" -    def __init__(self, bot: Bot) -> None: -        self.bot = bot -      @property      @abc.abstractmethod      def name(self) -> str:          """The name of the syncer; used in output messages and logging."""          raise NotImplementedError  # pragma: no cover +    @staticmethod      @abc.abstractmethod -    async def _get_diff(self, guild: Guild) -> _Diff: +    async def _get_diff(guild: Guild) -> _Diff:          """Return the difference between the cache of `guild` and the database."""          raise NotImplementedError  # pragma: no cover +    @staticmethod      @abc.abstractmethod -    async def _sync(self, diff: _Diff) -> None: +    async def _sync(diff: _Diff) -> None:          """Perform the API calls for synchronisation."""          raise NotImplementedError  # pragma: no cover @@ -78,10 +77,11 @@ class RoleSyncer(Syncer):      name = "role" -    async def _get_diff(self, guild: Guild) -> _Diff: +    @staticmethod +    async def _get_diff(guild: Guild) -> _Diff:          """Return the difference of roles between the cache of `guild` and the database."""          log.trace("Getting the diff for roles.") -        roles = await self.bot.api_client.get('bot/roles') +        roles = await bot.instance.api_client.get('bot/roles')          # Pack DB roles and guild roles into one common, hashable format.          # They're hashable so that they're easily comparable with sets later. @@ -110,19 +110,20 @@ class RoleSyncer(Syncer):          return _Diff(roles_to_create, roles_to_update, roles_to_delete) -    async def _sync(self, diff: _Diff) -> None: +    @staticmethod +    async def _sync(diff: _Diff) -> None:          """Synchronise the database with the role cache of `guild`."""          log.trace("Syncing created roles...")          for role in diff.created: -            await self.bot.api_client.post('bot/roles', json=role._asdict()) +            await bot.instance.api_client.post('bot/roles', json=role._asdict())          log.trace("Syncing updated roles...")          for role in diff.updated: -            await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) +            await bot.instance.api_client.put(f'bot/roles/{role.id}', json=role._asdict())          log.trace("Syncing deleted roles...")          for role in diff.deleted: -            await self.bot.api_client.delete(f'bot/roles/{role.id}') +            await bot.instance.api_client.delete(f'bot/roles/{role.id}')  class UserSyncer(Syncer): @@ -130,7 +131,8 @@ class UserSyncer(Syncer):      name = "user" -    async def _get_diff(self, guild: Guild) -> _Diff: +    @staticmethod +    async def _get_diff(guild: Guild) -> _Diff:          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") @@ -138,7 +140,7 @@ class UserSyncer(Syncer):          users_to_update = []          seen_guild_users = set() -        async for db_user in self._get_users(): +        async for db_user in UserSyncer._get_users():              # Store user fields which are to be updated.              updated_fields = {} @@ -185,24 +187,26 @@ class UserSyncer(Syncer):          return _Diff(users_to_create, users_to_update, None) -    async def _get_users(self) -> t.AsyncIterable: +    @staticmethod +    async def _get_users() -> t.AsyncIterable:          """GET users from database."""          query_params = {              "page": 1          }          while query_params["page"]: -            res = await self.bot.api_client.get("bot/users", params=query_params) +            res = await bot.instance.api_client.get("bot/users", params=query_params)              for user in res["results"]:                  yield user              query_params["page"] = res["next_page_no"] -    async def _sync(self, diff: _Diff) -> None: +    @staticmethod +    async def _sync(diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...")          if diff.created: -            await self.bot.api_client.post("bot/users", json=diff.created) +            await bot.instance.api_client.post("bot/users", json=diff.created)          log.trace("Syncing updated users...")          if diff.updated: -            await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated) +            await bot.instance.api_client.patch("bot/users/bulk_patch", json=diff.updated) diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 4953550f9..157d42452 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -18,21 +18,21 @@ class TestSyncer(Syncer):  class SyncerBaseTests(unittest.TestCase):      """Tests for the syncer base class.""" -    def setUp(self): -        self.bot = helpers.MockBot() -      def test_instantiation_fails_without_abstract_methods(self):          """The class must have abstract methods implemented."""          with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): -            Syncer(self.bot) +            Syncer()  class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for main function orchestrating the sync."""      def setUp(self): -        self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) -        self.syncer = TestSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot(user=helpers.MockMember(bot=True))) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) + +        self.syncer = TestSyncer()          self.guild = helpers.MockGuild()          # Make sure `_get_diff` returns a MagicMock, not an AsyncMock diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 063a82754..1e1883558 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -29,24 +29,24 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = helpers.MockBot() -        self.role_syncer_patcher = mock.patch( +        role_syncer_patcher = mock.patch(              "bot.exts.backend.sync._syncers.RoleSyncer",              autospec=Syncer,              spec_set=True          ) -        self.user_syncer_patcher = mock.patch( +        user_syncer_patcher = mock.patch(              "bot.exts.backend.sync._syncers.UserSyncer",              autospec=Syncer,              spec_set=True          ) -        self.RoleSyncer = self.role_syncer_patcher.start() -        self.UserSyncer = self.user_syncer_patcher.start() -        self.cog = Sync(self.bot) +        self.RoleSyncer = role_syncer_patcher.start() +        self.UserSyncer = user_syncer_patcher.start() -    def tearDown(self): -        self.role_syncer_patcher.stop() -        self.user_syncer_patcher.stop() +        self.addCleanup(role_syncer_patcher.stop) +        self.addCleanup(user_syncer_patcher.stop) + +        self.cog = Sync(self.bot)      @staticmethod      def response_error(status: int) -> ResponseCodeError: @@ -73,8 +73,8 @@ class SyncCogTests(SyncCogTestCase):          Sync(self.bot) -        self.RoleSyncer.assert_called_once_with(self.bot) -        self.UserSyncer.assert_called_once_with(self.bot) +        self.RoleSyncer.assert_called_once_with() +        self.UserSyncer.assert_called_once_with()          sync_guild.assert_called_once_with()          self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 7b9f40cad..fb63a4ae0 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -22,8 +22,11 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between roles in the DB and roles in the Guild cache."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = RoleSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) + +        self.syncer = RoleSyncer()      @staticmethod      def get_guild(*roles): @@ -108,8 +111,11 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync roles."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = RoleSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) + +        self.syncer = RoleSyncer()      async def test_sync_created_roles(self):          """Only POST requests should be made with the correct payload.""" diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 9f380a15d..9f28d0162 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,4 +1,5 @@  import unittest +from unittest import mock  from bot.exts.backend.sync._syncers import UserSyncer, _Diff  from tests import helpers @@ -19,8 +20,11 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between users in the DB and users in the Guild cache."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = UserSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) + +        self.syncer = UserSyncer()      @staticmethod      def get_guild(*members): @@ -186,8 +190,11 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync users."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = UserSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) + +        self.syncer = UserSyncer()      async def test_sync_created_users(self):          """Only POST requests should be made with the correct payload."""  |