diff options
-rw-r--r-- | bot/cogs/sync/cog.py | 32 | ||||
-rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 84 |
2 files changed, 93 insertions, 23 deletions
diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 7cc3726b2..5ace957e7 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -34,18 +34,22 @@ class Sync(Cog): for syncer in (self.role_syncer, self.user_syncer): await syncer.sync(guild) - async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None: + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: """Send a PATCH request to partially update a user in the database.""" try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information) + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) except ResponseCodeError as e: if e.response.status != 404: raise - log.warning("Unable to update user, got 404. Assuming race condition from join event.") + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") @Cog.listener() async def on_guild_role_create(self, role: Role) -> None: """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + await self.bot.api_client.post( 'bot/roles', json={ @@ -60,11 +64,17 @@ class Sync(Cog): @Cog.listener() async def on_guild_role_delete(self, role: Role) -> None: """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + await self.bot.api_client.delete(f'bot/roles/{role.id}') @Cog.listener() async def on_guild_role_update(self, before: Role, after: Role) -> None: """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + was_updated = ( before.name != after.name or before.colour != after.colour @@ -93,6 +103,9 @@ class Sync(Cog): previously left), it will update the user's information. If the user is not yet known by the database, the user is added. """ + if member.guild.id != constants.Guild.id: + return + packed = { 'discriminator': int(member.discriminator), 'id': member.id, @@ -122,14 +135,20 @@ class Sync(Cog): @Cog.listener() async def on_member_remove(self, member: Member) -> None: """Set the in_guild field to False when a member leaves the guild.""" - await self.patch_user(member.id, updated_information={"in_guild": False}) + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) @Cog.listener() async def on_member_update(self, before: Member, after: Member) -> None: """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + if before.roles != after.roles: updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, updated_information=updated_information) + await self.patch_user(after.id, json=updated_information) @Cog.listener() async def on_user_update(self, before: User, after: User) -> None: @@ -140,7 +159,8 @@ class Sync(Cog): "name": after.name, "discriminator": int(after.discriminator), } - await self.patch_user(after.id, updated_information=updated_information) + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) @commands.group(name='sync') @commands.has_permissions(administrator=True) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 14fd909c4..120bc991d 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -131,6 +131,15 @@ class SyncCogListenerTests(SyncCogTestCase): super().setUp() self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + self.guild_id_patcher = mock.patch("bot.cogs.sync.cog.constants.Guild.id", 5) + self.guild_id = self.guild_id_patcher.start() + + self.guild = helpers.MockGuild(id=self.guild_id) + self.other_guild = helpers.MockGuild(id=0) + + def tearDown(self): + self.guild_id_patcher.stop() + async def test_sync_cog_on_guild_role_create(self): """A POST request should be sent with the new role's data.""" self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -142,20 +151,32 @@ class SyncCogListenerTests(SyncCogTestCase): "permissions": 8, "position": 23, } - role = helpers.MockRole(**role_data) + role = helpers.MockRole(**role_data, guild=self.guild) await self.cog.on_guild_role_create(role) self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + async def test_sync_cog_on_guild_role_create_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_create(role) + self.bot.api_client.post.assert_not_awaited() + async def test_sync_cog_on_guild_role_delete(self): """A DELETE request should be sent.""" self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - role = helpers.MockRole(id=99) + role = helpers.MockRole(id=99, guild=self.guild) await self.cog.on_guild_role_delete(role) self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_delete(role) + self.bot.api_client.delete.assert_not_awaited() + async def test_sync_cog_on_guild_role_update(self): """A PUT request should be sent if the colour, name, permissions, or position changes.""" self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -180,8 +201,8 @@ class SyncCogListenerTests(SyncCogTestCase): after_role_data = role_data.copy() after_role_data[attribute] = 876 - before_role = helpers.MockRole(**role_data) - after_role = helpers.MockRole(**after_role_data) + before_role = helpers.MockRole(**role_data, guild=self.guild) + after_role = helpers.MockRole(**after_role_data, guild=self.guild) await self.cog.on_guild_role_update(before_role, after_role) @@ -193,31 +214,43 @@ class SyncCogListenerTests(SyncCogTestCase): else: self.bot.api_client.put.assert_not_called() + async def test_sync_cog_on_guild_role_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_update(role, role) + self.bot.api_client.put.assert_not_awaited() + async def test_sync_cog_on_member_remove(self): - """Member should patched to set in_guild as False.""" + """Member should be patched to set in_guild as False.""" self.assertTrue(self.cog.on_member_remove.__cog_listener__) - member = helpers.MockMember() + member = helpers.MockMember(guild=self.guild) await self.cog.on_member_remove(member) self.cog.patch_user.assert_called_once_with( member.id, - updated_information={"in_guild": False} + json={"in_guild": False} ) + async def test_sync_cog_on_member_remove_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_remove(member) + self.cog.patch_user.assert_not_awaited() + async def test_sync_cog_on_member_update_roles(self): """Members should be patched if their roles have changed.""" self.assertTrue(self.cog.on_member_update.__cog_listener__) # Roles are intentionally unsorted. before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles) - after_member = helpers.MockMember(roles=before_roles[1:]) + before_member = helpers.MockMember(roles=before_roles, guild=self.guild) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) await self.cog.on_member_update(before_member, after_member) data = {"roles": sorted(role.id for role in after_member.roles)} - self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + self.cog.patch_user.assert_called_once_with(after_member.id, json=data) async def test_sync_cog_on_member_update_other(self): """Members should not be patched if other attributes have changed.""" @@ -233,13 +266,19 @@ class SyncCogListenerTests(SyncCogTestCase): with self.subTest(attribute=attribute): self.cog.patch_user.reset_mock() - before_member = helpers.MockMember(**{attribute: old_value}) - after_member = helpers.MockMember(**{attribute: new_value}) + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) await self.cog.on_member_update(before_member, after_member) self.cog.patch_user.assert_not_called() + async def test_sync_cog_on_member_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_update(member, member) + self.cog.patch_user.assert_not_awaited() + async def test_sync_cog_on_user_update(self): """A user should be patched only if the name, discriminator, or avatar changes.""" self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -272,12 +311,15 @@ class SyncCogListenerTests(SyncCogTestCase): # Don't care if *all* keys are present; only the changed one is required call_args = self.cog.patch_user.call_args - self.assertEqual(call_args[0][0], after_user.id) - self.assertIn("updated_information", call_args[1]) + self.assertEqual(call_args.args[0], after_user.id) + self.assertIn("json", call_args.kwargs) + + self.assertIn("ignore_404", call_args.kwargs) + self.assertTrue(call_args.kwargs["ignore_404"]) - updated_information = call_args[1]["updated_information"] - self.assertIn(api_field, updated_information) - self.assertEqual(updated_information[api_field], api_value) + json = call_args.kwargs["json"] + self.assertIn(api_field, json) + self.assertEqual(json[api_field], api_value) else: self.cog.patch_user.assert_not_called() @@ -290,6 +332,7 @@ class SyncCogListenerTests(SyncCogTestCase): member = helpers.MockMember( discriminator="1234", roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + guild=self.guild, ) data = { @@ -334,6 +377,13 @@ class SyncCogListenerTests(SyncCogTestCase): self.bot.api_client.post.assert_not_called() + async def test_sync_cog_on_member_join_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_join(member) + self.bot.api_client.post.assert_not_awaited() + self.bot.api_client.put.assert_not_awaited() + class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): """Tests for the commands in the Sync cog.""" |