From 81711d77750be55a62a927b1c90f0eaf773e0567 Mon Sep 17 00:00:00 2001 From: ks123 Date: Tue, 3 Mar 2020 16:21:53 +0200 Subject: Created file for moderation utils tests + added setUp to this. --- tests/bot/cogs/moderation/test_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 tests/bot/cogs/moderation/test_utils.py (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py new file mode 100644 index 000000000..ed1d1ed59 --- /dev/null +++ b/tests/bot/cogs/moderation/test_utils.py @@ -0,0 +1,12 @@ +import unittest + + +from tests.helpers import MockBot, MockContext + + +class ModerationUtilsTests(unittest.TestCase): + """Tests Moderation utils.""" + + def setUp(self) -> None: + self.bot = MockBot() + self.ctx = MockContext(bot=self.bot) -- cgit v1.2.3 From 154969022cf62bd4a2bab2f7492ded08bb26ffba Mon Sep 17 00:00:00 2001 From: ks123 Date: Tue, 3 Mar 2020 16:32:13 +0200 Subject: (Moderation Utils Tests): Added imports, modified tests class instance and created new params for tests class --- tests/bot/cogs/moderation/test_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index ed1d1ed59..7d47715d4 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,12 +1,14 @@ import unittest +from unittest.mock import AsyncMock +from tests.helpers import MockBot, MockContext, MockMember -from tests.helpers import MockBot, MockContext - -class ModerationUtilsTests(unittest.TestCase): +class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" - def setUp(self) -> None: + def setUp(self): self.bot = MockBot() - self.ctx = MockContext(bot=self.bot) + self.member = MockMember(id=1234) + self.ctx = MockContext(bot=self.bot, author=self.member) + self.bot.api_client.get = AsyncMock() -- cgit v1.2.3 From fa6a0ae59958ce143f6a7acfbd41e477e940fa84 Mon Sep 17 00:00:00 2001 From: ks123 Date: Tue, 3 Mar 2020 17:29:20 +0200 Subject: (Moderation Utils Tests): Created tests for `has_active_infraction` function --- tests/bot/cogs/moderation/test_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 7d47715d4..d25fbfcb5 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import AsyncMock +from bot.cogs.moderation.utils import has_active_infraction from tests.helpers import MockBot, MockContext, MockMember @@ -12,3 +13,26 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.member = MockMember(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) self.bot.api_client.get = AsyncMock() + + async def test_user_has_active_infraction_true(self): + """Test does `has_active_infraction` return that user have active infraction.""" + self.bot.api_client.get.return_value = [{ + "id": 1, + "inserted_at": "2018-11-22T07:24:06.132307Z", + "expires_at": "5018-11-20T15:52:00Z", + "active": True, + "user": 1234, + "actor": 1234, + "type": "ban", + "reason": "Test", + "hidden": False + }] + self.assertTrue(await has_active_infraction(self.ctx, self.member, "ban"), "User should have active infraction") + + async def test_user_has_active_infraction_false(self): + """Test does `has_active_infraction` return that user don't have active infractions.""" + self.bot.api_client.get.return_value = [] + self.assertFalse( + await has_active_infraction(self.ctx, self.member, "ban"), + "User shouldn't have active infraction" + ) -- cgit v1.2.3 From 98f7a3777152b32bfda24f9d5add938479827c85 Mon Sep 17 00:00:00 2001 From: ks123 Date: Wed, 4 Mar 2020 18:15:54 +0200 Subject: (Moderation Utils Tests): Created tests for `notify_infraction` function. --- tests/bot/cogs/moderation/test_utils.py | 93 ++++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index d25fbfcb5..89f853262 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,8 +1,25 @@ import unittest from unittest.mock import AsyncMock -from bot.cogs.moderation.utils import has_active_infraction -from tests.helpers import MockBot, MockContext, MockMember +from discord import Embed + +from bot.cogs.moderation.utils import has_active_infraction, notify_infraction +from bot.constants import Colours, Icons +from tests.helpers import MockBot, MockContext, MockMember, MockUser + +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEAL_EMAIL = "appeals@pythondiscord.com" + +INFRACTION_TITLE = f"Please review our rules over at {RULES_URL}" +INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" +INFRACTION_AUTHOR_NAME = "Infraction information" +INFRACTION_COLOR = Colours.soft_red + +INFRACTION_DESCRIPTION_TEMPLATE = ( + "\n**Type:** {type}\n" + "**Expires:** {expires}\n" + "**Reason:** {reason}\n" +) class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @@ -11,6 +28,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.member = MockMember(id=1234) + self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) self.bot.api_client.get = AsyncMock() @@ -36,3 +54,74 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): await has_active_infraction(self.ctx, self.member, "ban"), "User shouldn't have active infraction" ) + + async def test_notify_infraction(self): + """Test does `notify_infraction` create correct embed.""" + test_cases = [ + { + "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), + "expected_output": { + "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ + "type": "Ban", + "expires": "2020-02-26 09:20 (23 hours and 59 minutes)", + "reason": "No reason provided." + }), + "icon_url": Icons.token_removed, + "footer": INFRACTION_APPEAL_FOOTER + } + }, + { + "args": (self.user, "warning", None, "Test reason."), + "expected_output": { + "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ + "type": "Warning", + "expires": "N/A", + "reason": "Test reason." + }), + "icon_url": Icons.token_removed, + "footer": Embed.Empty + } + }, + { + "args": (self.user, "note", None, None, Icons.defcon_denied), + "expected_output": { + "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ + "type": "Note", + "expires": "N/A", + "reason": "No reason provided." + }), + "icon_url": Icons.defcon_denied, + "footer": Embed.Empty + } + }, + { + "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), + "expected_output": { + "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ + "type": "Mute", + "expires": "2020-02-26 09:20 (23 hours and 59 minutes)", + "reason": "Test" + }), + "icon_url": Icons.defcon_denied, + "footer": INFRACTION_APPEAL_FOOTER + } + } + ] + + for case in test_cases: + args = case["args"] + expected = case["expected_output"] + + with self.subTest(args=case["args"], expected=case["expected_output"]): + await notify_infraction(*args) + + embed: Embed = self.user.send.call_args[1]["embed"] + + self.assertEqual(embed.title, INFRACTION_TITLE) + self.assertEqual(embed.colour.value, INFRACTION_COLOR) + self.assertEqual(embed.url, RULES_URL) + self.assertEqual(embed.author.name, INFRACTION_AUTHOR_NAME) + self.assertEqual(embed.author.url, RULES_URL) + self.assertEqual(embed.author.icon_url, expected["icon_url"]) + self.assertEqual(embed.footer.text, expected["footer"]) + self.assertEqual(embed.description, expected["description"]) -- cgit v1.2.3 From 4a746fc60b6c51e20e1fab92726665092405f93d Mon Sep 17 00:00:00 2001 From: ks123 Date: Wed, 4 Mar 2020 18:38:30 +0200 Subject: (Moderation Utils Tests): Created tests for `notify_pardon` function. --- tests/bot/cogs/moderation/test_utils.py | 41 +++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 89f853262..05e71e695 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock from discord import Embed -from bot.cogs.moderation.utils import has_active_infraction, notify_infraction +from bot.cogs.moderation.utils import has_active_infraction, notify_infraction, notify_pardon from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser @@ -21,6 +21,8 @@ INFRACTION_DESCRIPTION_TEMPLATE = ( "**Reason:** {reason}\n" ) +PARDON_COLOR = Colours.soft_green + class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" @@ -112,7 +114,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): args = case["args"] expected = case["expected_output"] - with self.subTest(args=case["args"], expected=case["expected_output"]): + with self.subTest(args=args, expected=expected): await notify_infraction(*args) embed: Embed = self.user.send.call_args[1]["embed"] @@ -125,3 +127,38 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.author.icon_url, expected["icon_url"]) self.assertEqual(embed.footer.text, expected["footer"]) self.assertEqual(embed.description, expected["description"]) + + async def test_notify_pardon(self): + """Test does `notify_pardon` create correct embed.""" + test_cases = [ + { + "args": (self.user, "Test title", "Example content"), + "expected_output": { + "description": "Example content", + "title": "Test title", + "icon_url": Icons.user_verified + } + }, + { + "args": (self.user, "Test title 1", "Example content 1", Icons.user_update), + "expected_output": { + "description": "Example content 1", + "title": "Test title 1", + "icon_url": Icons.user_update + } + } + ] + + for case in test_cases: + args = case["args"] + expected = case["expected_output"] + + with self.subTest(args=args, expected=expected): + await notify_pardon(*args) + + embed: Embed = self.user.send.call_args[1]["embed"] + + self.assertEqual(embed.description, expected["description"]) + self.assertEqual(embed.colour.value, PARDON_COLOR) + self.assertEqual(embed.author.name, expected["title"]) + self.assertEqual(embed.author.icon_url, expected["icon_url"]) -- cgit v1.2.3 From 615ffaa97cb14d83c7c57e0efd675aae0b58abd1 Mon Sep 17 00:00:00 2001 From: ks123 Date: Wed, 4 Mar 2020 19:10:36 +0200 Subject: (Moderation Utils Tests): Created tests for `post_user` function. --- tests/bot/cogs/moderation/test_utils.py | 60 ++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 05e71e695..c8c1f9e1a 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -3,7 +3,8 @@ from unittest.mock import AsyncMock from discord import Embed -from bot.cogs.moderation.utils import has_active_infraction, notify_infraction, notify_pardon +from bot.api import ResponseCodeError +from bot.cogs.moderation.utils import has_active_infraction, notify_infraction, notify_pardon, post_user from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser @@ -162,3 +163,60 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.colour.value, PARDON_COLOR) self.assertEqual(embed.author.name, expected["title"]) self.assertEqual(embed.author.icon_url, expected["icon_url"]) + + async def test_post_user(self): + """Test does `post_user` work correctly.""" + test_cases = [ + { + "args": (self.ctx, self.user), + "post_result": [ + { + "id": 1234, + "avatar": "test", + "name": "Test", + "discriminator": 1234, + "roles": [ + 1234, + 5678 + ], + "in_guild": True + } + ], + "raise_error": False + }, + { + "args": (self.ctx, self.user), + "post_result": [ + { + "id": 1234, + "avatar": "test", + "name": "Test", + "discriminator": 1234, + "roles": [ + 1234, + 5678 + ], + "in_guild": True + } + ], + "raise_error": True + } + ] + + for case in test_cases: + args = case["args"] + expected = case["post_result"] + error = case["raise_error"] + + with self.subTest(args=args, result=expected, error=error): + self.ctx.bot.api_client.post.return_value = expected + + if error: + self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(response_code=400), expected) + + result = await post_user(*args) + + if error: + self.assertIsNone(result) + else: + self.assertEqual(result, expected) -- cgit v1.2.3 From af71a7775d190a11ed92c0d88b52801cdf3804d8 Mon Sep 17 00:00:00 2001 From: ks123 Date: Wed, 4 Mar 2020 19:26:40 +0200 Subject: (Moderation Utils Tests): Created tests for `send_private_embed` function + Fixed errors. --- tests/bot/cogs/moderation/test_utils.py | 47 ++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index c8c1f9e1a..c1cc11724 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,10 +1,13 @@ import unittest +from typing import Union from unittest.mock import AsyncMock -from discord import Embed +from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError -from bot.cogs.moderation.utils import has_active_infraction, notify_infraction, notify_pardon, post_user +from bot.cogs.moderation.utils import ( + has_active_infraction, notify_infraction, notify_pardon, post_user, send_private_embed +) from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser @@ -212,7 +215,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx.bot.api_client.post.return_value = expected if error: - self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(response_code=400), expected) + self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), expected) result = await post_user(*args) @@ -220,3 +223,41 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(result) else: self.assertEqual(result, expected) + + async def test_send_private_embed(self): + """Test does `send_private_embed` return correct value.""" + test_cases = [ + { + "args": (self.user, Embed(title="Test", description="Test val")), + "expected_output": True, + "raised_exception": None + }, + { + "args": (self.user, Embed(title="Test", description="Test val")), + "expected_output": False, + "raised_exception": HTTPException + }, + { + "args": (self.user, Embed(title="Test", description="Test val")), + "expected_output": False, + "raised_exception": Forbidden + }, + { + "args": (self.user, Embed(title="Test", description="Test val")), + "expected_output": False, + "raised_exception": NotFound + } + ] + + for case in test_cases: + args = case["args"] + expected = case["expected_output"] + raised: Union[Forbidden, HTTPException, NotFound, None] = case["raised_exception"] + + with self.subTest(args=args, expected=expected, raised=raised): + if raised: + self.user.send.side_effect = raised(AsyncMock(), AsyncMock()) + + result = await send_private_embed(*args) + + self.assertEqual(result, expected) -- cgit v1.2.3 From 1c7675ba55342e29fa3e3b82cf36a6e321f76bf8 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 08:31:14 +0200 Subject: (Moderation Utils Tests): Created tests for `post_infraction` function, created __init__.py for moderation tests --- tests/bot/cogs/moderation/__init__.py | 0 tests/bot/cogs/moderation/test_utils.py | 69 ++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 tests/bot/cogs/moderation/__init__.py (limited to 'tests') diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/cogs/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index c1cc11724..984a8aa41 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,4 +1,5 @@ import unittest +from datetime import datetime from typing import Union from unittest.mock import AsyncMock @@ -6,7 +7,7 @@ from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError from bot.cogs.moderation.utils import ( - has_active_infraction, notify_infraction, notify_pardon, post_user, send_private_embed + has_active_infraction, notify_infraction, notify_pardon, post_infraction, post_user, send_private_embed ) from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser @@ -261,3 +262,69 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): result = await send_private_embed(*args) self.assertEqual(result, expected) + + async def test_post_infraction(self): + """Test does `post_infraction` return correct value.""" + test_cases = [ + { + "args": (self.ctx, self.member, "ban", "Test Ban"), + "expected_output": [ + { + "id": 1, + "inserted_at": "2018-11-22T07:24:06.132307Z", + "expires_at": "5018-11-20T15:52:00Z", + "active": True, + "user": 1234, + "actor": 1234, + "type": "ban", + "reason": "Test Ban", + "hidden": False + } + ], + "raised_error": None + }, + { + "args": (self.ctx, self.member, "note", "Test Ban"), + "expected_output": None, + "raised_error": ResponseCodeError(AsyncMock(), AsyncMock()) + }, + { + "args": (self.ctx, self.member, "mute", "Test Ban"), + "expected_output": None, + "raised_error": ResponseCodeError(AsyncMock(), {'user': 1234}) + }, + { + "args": (self.ctx, self.member, "ban", "Test Ban", datetime.now()), + "expected_output": [ + { + "id": 1, + "inserted_at": "2018-11-22T07:24:06.132307Z", + "expires_at": "5018-11-20T15:52:00Z", + "active": True, + "user": 1234, + "actor": 1234, + "type": "ban", + "reason": "Test Ban", + "hidden": False + } + ], + "raised_error": None + }, + ] + + for case in test_cases: + args = case["args"] + expected = case["expected_output"] + raised = case["raised_error"] + + with self.subTest(args=args, expected=expected, raised=raised): + if raised: + self.ctx.bot.api_client.post.side_effect = raised + + self.ctx.bot.api_client.post.return_value = expected + + result = await post_infraction(*args) + + self.assertEqual(result, expected) + + self.ctx.bot.api_client.post.reset_mock(side_effect=True) -- cgit v1.2.3 From 3b3b9f72807fe4c2dfaedb98aa714150b01d46ba Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 08:34:01 +0200 Subject: (Moderation Utils Tests): `send_private_embed` moved exception creating from cases testing to test cases listing, added side_effect resetting. --- tests/bot/cogs/moderation/test_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 984a8aa41..2a07cdc6b 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -236,17 +236,17 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": False, - "raised_exception": HTTPException + "raised_exception": HTTPException(AsyncMock(), AsyncMock()) }, { "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": False, - "raised_exception": Forbidden + "raised_exception": Forbidden(AsyncMock(), AsyncMock()) }, { "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": False, - "raised_exception": NotFound + "raised_exception": NotFound(AsyncMock(), AsyncMock()) } ] @@ -257,12 +257,14 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): with self.subTest(args=args, expected=expected, raised=raised): if raised: - self.user.send.side_effect = raised(AsyncMock(), AsyncMock()) + self.user.send.side_effect = raised result = await send_private_embed(*args) self.assertEqual(result, expected) + self.user.send.reset_mock(side_effect=True) + async def test_post_infraction(self): """Test does `post_infraction` return correct value.""" test_cases = [ -- cgit v1.2.3 From 30e090be63c96b5844087c979a37a321ccd170df Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 08:43:18 +0200 Subject: (Moderation Utils Tests): Moved `has_active_infraction` tests to one test. --- tests/bot/cogs/moderation/test_utils.py | 57 ++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 2a07cdc6b..18794136c 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -39,28 +39,41 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(bot=self.bot, author=self.member) self.bot.api_client.get = AsyncMock() - async def test_user_has_active_infraction_true(self): - """Test does `has_active_infraction` return that user have active infraction.""" - self.bot.api_client.get.return_value = [{ - "id": 1, - "inserted_at": "2018-11-22T07:24:06.132307Z", - "expires_at": "5018-11-20T15:52:00Z", - "active": True, - "user": 1234, - "actor": 1234, - "type": "ban", - "reason": "Test", - "hidden": False - }] - self.assertTrue(await has_active_infraction(self.ctx, self.member, "ban"), "User should have active infraction") - - async def test_user_has_active_infraction_false(self): - """Test does `has_active_infraction` return that user don't have active infractions.""" - self.bot.api_client.get.return_value = [] - self.assertFalse( - await has_active_infraction(self.ctx, self.member, "ban"), - "User shouldn't have active infraction" - ) + async def test_user_has_active_infraction(self): + """Test does `has_active_infraction` return correct value.""" + test_cases = [ + { + "args": (self.ctx, self.member, "ban"), + "get_return_value": [], + "expected_output": False + }, + { + "args": (self.ctx, self.member, "ban"), + "get_return_value": [{ + "id": 1, + "inserted_at": "2018-11-22T07:24:06.132307Z", + "expires_at": "5018-11-20T15:52:00Z", + "active": True, + "user": 1234, + "actor": 1234, + "type": "ban", + "reason": "Test", + "hidden": False + }], + "expected_output": True + } + ] + + for case in test_cases: + args = case["args"] + return_value = case["get_return_value"] + expected = case["expected_output"] + + with self.subTest(args=args, return_value=return_value, expected=expected): + self.bot.api_client.get.return_value = return_value + + result = await has_active_infraction(*args) + self.assertEqual(result, expected) async def test_notify_infraction(self): """Test does `notify_infraction` create correct embed.""" -- cgit v1.2.3 From 9211baaf987277b115bd1e2092f69f29389ac887 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 15:26:35 +0200 Subject: (Moderation Utils Tests): Removed unnecessary `AsyncMock()` from `__init__` (`self.bot.api_client.get`) --- tests/bot/cogs/moderation/test_utils.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 18794136c..60d7efa5e 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -37,7 +37,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.member = MockMember(id=1234) self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) - self.bot.api_client.get = AsyncMock() async def test_user_has_active_infraction(self): """Test does `has_active_infraction` return correct value.""" -- cgit v1.2.3 From 7d988453fe6536df06f47aeef9b5ff36f5d64c39 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 15:32:25 +0200 Subject: (Moderation Utils Tests): Use `bot.cogs.moderation.utils`'s `RULES_URL` instead creating new one --- tests/bot/cogs/moderation/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 60d7efa5e..ea5aadc59 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -7,12 +7,11 @@ from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError from bot.cogs.moderation.utils import ( - has_active_infraction, notify_infraction, notify_pardon, post_infraction, post_user, send_private_embed + RULES_URL, has_active_infraction, notify_infraction, notify_pardon, post_infraction, post_user, send_private_embed ) from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser -RULES_URL = "https://pythondiscord.com/pages/rules" APPEAL_EMAIL = "appeals@pythondiscord.com" INFRACTION_TITLE = f"Please review our rules over at {RULES_URL}" -- cgit v1.2.3 From b0ae911f3ecb4c5229c6944f5fed77eced2fc79b Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 15:52:43 +0200 Subject: (Moderation Utils Tests): Added following new assertions to `has_active_infraction` tests: `ctx.send` and `bot.api_client.get` calling. --- tests/bot/cogs/moderation/test_utils.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index ea5aadc59..40159f6d9 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -43,7 +43,13 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.ctx, self.member, "ban"), "get_return_value": [], - "expected_output": False + "expected_output": False, + "get_call": { + "active": "true", + "type": "ban", + "user__id": str(self.member.id) + }, + "send_params": None }, { "args": (self.ctx, self.member, "ban"), @@ -58,7 +64,16 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "reason": "Test", "hidden": False }], - "expected_output": True + "expected_output": True, + "get_call": { + "active": "true", + "type": "ban", + "user__id": str(self.member.id) + }, + "send_params": ( + f":x: According to my records, this user already has a ban infraction. " + f"See infraction **#1**." + ) } ] @@ -66,12 +81,21 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): args = case["args"] return_value = case["get_return_value"] expected = case["expected_output"] + get = case["get_call"] + send_vals = case["send_params"] - with self.subTest(args=args, return_value=return_value, expected=expected): + with self.subTest(args=args, return_value=return_value, expected=expected, get=get, send_vals=send_vals): self.bot.api_client.get.return_value = return_value result = await has_active_infraction(*args) self.assertEqual(result, expected) + self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=get) + + if result: + self.ctx.send.assert_awaited_once_with(send_vals) + + self.bot.api_client.get.reset_mock() + self.ctx.send.reset_mock() async def test_notify_infraction(self): """Test does `notify_infraction` create correct embed.""" -- cgit v1.2.3 From 05c3a21f34b5cc87ac5e439b0256fffc10682f54 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 16:39:36 +0200 Subject: (Moderation Utils Tests): Added new assertions to `post_infraction`, added `ctx.send` raising errors, added check for return values and `send_private_embed` call. --- tests/bot/cogs/moderation/test_utils.py | 44 ++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 40159f6d9..609ec2642 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,7 +1,7 @@ import unittest from datetime import datetime from typing import Union -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch from discord import Embed, Forbidden, HTTPException, NotFound @@ -97,8 +97,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.reset_mock() self.ctx.send.reset_mock() - async def test_notify_infraction(self): - """Test does `notify_infraction` create correct embed.""" + @patch("bot.cogs.moderation.utils.send_private_embed") + async def test_notify_infraction(self, send_private_embed_mock): + """Test does `notify_infraction` create correct result.""" test_cases = [ { "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), @@ -109,8 +110,10 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "reason": "No reason provided." }), "icon_url": Icons.token_removed, - "footer": INFRACTION_APPEAL_FOOTER - } + "footer": INFRACTION_APPEAL_FOOTER, + }, + "send_result": True, + "send_raise": None }, { "args": (self.user, "warning", None, "Test reason."), @@ -122,7 +125,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): }), "icon_url": Icons.token_removed, "footer": Embed.Empty - } + }, + "send_result": False, + "send_raise": Forbidden(AsyncMock(), AsyncMock()) }, { "args": (self.user, "note", None, None, Icons.defcon_denied), @@ -134,7 +139,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): }), "icon_url": Icons.defcon_denied, "footer": Embed.Empty - } + }, + "send_result": False, + "send_raise": NotFound(AsyncMock(), AsyncMock()) }, { "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), @@ -146,18 +153,28 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): }), "icon_url": Icons.defcon_denied, "footer": INFRACTION_APPEAL_FOOTER - } + }, + "send_result": False, + "send_raise": HTTPException(AsyncMock(), AsyncMock()) } ] for case in test_cases: args = case["args"] expected = case["expected_output"] + send, send_raise = case["send_result"], case["send_raise"] - with self.subTest(args=args, expected=expected): - await notify_infraction(*args) + with self.subTest(args=args, expected=expected, send=send, send_raise=send_raise): + if send_raise: + self.ctx.send.side_effect = send_raise - embed: Embed = self.user.send.call_args[1]["embed"] + send_private_embed_mock.return_value = send + + result = await notify_infraction(*args) + + self.assertEqual(send, result) + + embed = send_private_embed_mock.call_args[0][1] self.assertEqual(embed.title, INFRACTION_TITLE) self.assertEqual(embed.colour.value, INFRACTION_COLOR) @@ -168,6 +185,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.footer.text, expected["footer"]) self.assertEqual(embed.description, expected["description"]) + send_private_embed_mock.assert_awaited_once_with(args[0], embed) + + self.ctx.send.reset_mock(side_effect=True) + send_private_embed_mock.reset_mock() + async def test_notify_pardon(self): """Test does `notify_pardon` create correct embed.""" test_cases = [ -- cgit v1.2.3 From 87e5bdb3ff8f591e05e5eb410bbc5139afcc8d23 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 16:46:15 +0200 Subject: (Moderation Utils Tests): Added new assertions to `notify_pardon`, added `ctx.send` raising errors, added check for return values and `send_private_embed` call. --- tests/bot/cogs/moderation/test_utils.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 609ec2642..f38f4557b 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -190,8 +190,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.reset_mock(side_effect=True) send_private_embed_mock.reset_mock() - async def test_notify_pardon(self): - """Test does `notify_pardon` create correct embed.""" + @patch("bot.cogs.moderation.utils.send_private_embed") + async def test_notify_pardon(self, send_private_embed_mock): + """Test does `notify_pardon` create correct result.""" test_cases = [ { "args": (self.user, "Test title", "Example content"), @@ -199,7 +200,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "description": "Example content", "title": "Test title", "icon_url": Icons.user_verified - } + }, + "send_result": True, + "send_raise": None }, { "args": (self.user, "Test title 1", "Example content 1", Icons.user_update), @@ -207,24 +210,39 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "description": "Example content 1", "title": "Test title 1", "icon_url": Icons.user_update - } + }, + "send_result": False, + "send_raise": NotFound(AsyncMock(), AsyncMock()) } ] for case in test_cases: args = case["args"] expected = case["expected_output"] + send, send_raise = case["send_result"], case["send_raise"] with self.subTest(args=args, expected=expected): - await notify_pardon(*args) + if send_raise: + self.ctx.send.side_effect = send_raise - embed: Embed = self.user.send.call_args[1]["embed"] + send_private_embed_mock.return_value = send + + result = await notify_pardon(*args) + + self.assertEqual(send, result) + + embed = send_private_embed_mock.call_args[0][1] self.assertEqual(embed.description, expected["description"]) self.assertEqual(embed.colour.value, PARDON_COLOR) self.assertEqual(embed.author.name, expected["title"]) self.assertEqual(embed.author.icon_url, expected["icon_url"]) + send_private_embed_mock.assert_awaited_once_with(args[0], embed) + + self.ctx.send.reset_mock(side_effect=True) + send_private_embed_mock.reset_mock() + async def test_post_user(self): """Test does `post_user` work correctly.""" test_cases = [ -- cgit v1.2.3 From ded64749940525ea9b1f613560e4e30ec74c0c01 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 16:54:00 +0200 Subject: (Moderation Utils Tests): Added API POST call assertion to `test_post_user`. --- tests/bot/cogs/moderation/test_utils.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index f38f4557b..847ba8465 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -258,10 +258,18 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): 1234, 5678 ], - "in_guild": True + "in_guild": False } ], - "raise_error": False + "raise_error": False, + "payload": { + "avatar_hash": getattr(self.user, "avatar", 0), + "discriminator": int(getattr(self.user, "discriminator", 0)), + "id": self.user.id, + "in_guild": False, + "name": getattr(self.user, "name", "Name unknown"), + "roles": [] + } }, { "args": (self.ctx, self.user), @@ -275,10 +283,18 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): 1234, 5678 ], - "in_guild": True + "in_guild": False } ], - "raise_error": True + "raise_error": True, + "payload": { + "avatar_hash": getattr(self.user, "avatar", 0), + "discriminator": int(getattr(self.user, "discriminator", 0)), + "id": self.user.id, + "in_guild": False, + "name": getattr(self.user, "name", "Name unknown"), + "roles": [] + } } ] @@ -286,8 +302,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): args = case["args"] expected = case["post_result"] error = case["raise_error"] + payload = case["payload"] - with self.subTest(args=args, result=expected, error=error): + with self.subTest(args=args, result=expected, error=error, payload=payload): self.ctx.bot.api_client.post.return_value = expected if error: @@ -300,6 +317,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): else: self.assertEqual(result, expected) + self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) + async def test_send_private_embed(self): """Test does `send_private_embed` return correct value.""" test_cases = [ -- cgit v1.2.3 From d8a00abd3860df68dab1805f213f6467085d78fd Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 16:57:02 +0200 Subject: (Moderation Utils Tests): Added `user.send` call assertion to `test_send_private_embed`. --- tests/bot/cogs/moderation/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 847ba8465..300f0b80d 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -356,6 +356,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): result = await send_private_embed(*args) self.assertEqual(result, expected) + if expected: + args[0].send.assert_awaited_once_with(embed=args[1]) self.user.send.reset_mock(side_effect=True) -- cgit v1.2.3 From c1b97d0d6132175910ca8e66d35e908444ef512f Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 19:54:44 +0200 Subject: (Moderation Utils Tests): Added additional assertions to `post_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 62 ++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 300f0b80d..c5b8f380f 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -361,8 +361,10 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.user.send.reset_mock(side_effect=True) - async def test_post_infraction(self): + @patch("bot.cogs.moderation.utils.post_user") + async def test_post_infraction(self, post_user_mock): """Test does `post_infraction` return correct value.""" + now = datetime.now() test_cases = [ { "args": (self.ctx, self.member, "ban", "Test Ban"), @@ -379,20 +381,44 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "hidden": False } ], - "raised_error": None + "raised_error": None, + "payload": { + "actor": self.ctx.message.author.id, + "hidden": False, + "reason": "Test Ban", + "type": "ban", + "user": self.member.id, + "active": True + } }, { "args": (self.ctx, self.member, "note", "Test Ban"), "expected_output": None, - "raised_error": ResponseCodeError(AsyncMock(), AsyncMock()) + "raised_error": ResponseCodeError(AsyncMock(), AsyncMock()), + "payload": { + "actor": self.ctx.message.author.id, + "hidden": False, + "reason": "Test Ban", + "type": "note", + "user": self.member.id, + "active": True + } }, { "args": (self.ctx, self.member, "mute", "Test Ban"), "expected_output": None, - "raised_error": ResponseCodeError(AsyncMock(), {'user': 1234}) + "raised_error": ResponseCodeError(AsyncMock(status=400), {'user': 1234}), + "payload": { + "actor": self.ctx.message.author.id, + "hidden": False, + "reason": "Test Ban", + "type": "mute", + "user": self.member.id, + "active": True + } }, { - "args": (self.ctx, self.member, "ban", "Test Ban", datetime.now()), + "args": (self.ctx, self.member, "ban", "Test Ban", now, True, False), "expected_output": [ { "id": 1, @@ -406,7 +432,16 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "hidden": False } ], - "raised_error": None + "raised_error": None, + "payload": { + "actor": self.ctx.message.author.id, + "hidden": True, + "reason": "Test Ban", + "type": "ban", + "user": self.member.id, + "active": False, + "expires_at": now.isoformat() + } }, ] @@ -414,15 +449,26 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): args = case["args"] expected = case["expected_output"] raised = case["raised_error"] + payload = case["payload"] + + with self.subTest(args=args, expected=expected, raised=raised, payload=payload): + self.ctx.bot.api_client.post.reset_mock(side_effect=True) + post_user_mock.reset_mock() - with self.subTest(args=args, expected=expected, raised=raised): if raised: self.ctx.bot.api_client.post.side_effect = raised + post_user_mock.return_value = "foo" + self.ctx.bot.api_client.post.return_value = expected result = await post_infraction(*args) self.assertEqual(result, expected) - self.ctx.bot.api_client.post.reset_mock(side_effect=True) + if not raised: + self.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + + if hasattr(raised, "status") and hasattr(raised, "response_json"): + if raised.status == 400 and "user" in raised.response_json: + post_user_mock.assert_awaited_once_with(args[0], args[1]) -- cgit v1.2.3 From 7dfac36ab5d513fada631e6d473915e05eafe778 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 20:01:34 +0200 Subject: (Moderation Utils Tests): Fixed errors, added checks before assertions for errors --- tests/bot/cogs/moderation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index c5b8f380f..7f94f20e8 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -317,7 +317,10 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): else: self.assertEqual(result, expected) - self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) + if not error: + self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) + + self.bot.api_client.post.reset_mock(side_effect=True) async def test_send_private_embed(self): """Test does `send_private_embed` return correct value.""" -- cgit v1.2.3 From 1e0170481624d4a5ec52058cd4a57dd461439fd4 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:02:21 +0200 Subject: (Moderation Utils Tests): Fixed docstrings, added more information to these. --- tests/bot/cogs/moderation/test_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 7f94f20e8..e2345ea37 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -38,7 +38,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(bot=self.bot, author=self.member) async def test_user_has_active_infraction(self): - """Test does `has_active_infraction` return correct value.""" + """ + Test does `has_active_infraction` return call at least once `ctx.send` API get, check does return correct bool. + """ test_cases = [ { "args": (self.ctx, self.member, "ban"), @@ -99,7 +101,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_infraction(self, send_private_embed_mock): - """Test does `notify_infraction` create correct result.""" + """Test does `notify_infraction` create correct embed and return correct boolean.""" test_cases = [ { "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), @@ -192,7 +194,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): - """Test does `notify_pardon` create correct result.""" + """Test does `notify_pardon` create correct embed and return correct bool.""" test_cases = [ { "args": (self.user, "Test title", "Example content"), @@ -244,7 +246,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.reset_mock() async def test_post_user(self): - """Test does `post_user` work correctly.""" + """Test does `post_user` handle errors and results correctly.""" test_cases = [ { "args": (self.ctx, self.user), @@ -323,7 +325,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.post.reset_mock(side_effect=True) async def test_send_private_embed(self): - """Test does `send_private_embed` return correct value.""" + """Test does `send_private_embed` return correct bool.""" test_cases = [ { "args": (self.user, Embed(title="Test", description="Test val")), @@ -366,7 +368,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_user") async def test_post_infraction(self, post_user_mock): - """Test does `post_infraction` return correct value.""" + """Test does `post_infraction` call functions correctly and return `None` or `Dict`.""" now = datetime.now() test_cases = [ { -- cgit v1.2.3 From 87ecf72a328b05c922d1f7c0d6e8a1c86ab405c8 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:07:04 +0200 Subject: (Moderation Utils Tests): Removed large `utils` parts import, use import `utils` instead and added `utils` before variables and function that was imported directly before. --- tests/bot/cogs/moderation/test_utils.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index e2345ea37..6722c2d16 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -6,15 +6,13 @@ from unittest.mock import AsyncMock, patch from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError -from bot.cogs.moderation.utils import ( - RULES_URL, has_active_infraction, notify_infraction, notify_pardon, post_infraction, post_user, send_private_embed -) +from bot.cogs.moderation import utils from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser APPEAL_EMAIL = "appeals@pythondiscord.com" -INFRACTION_TITLE = f"Please review our rules over at {RULES_URL}" +INFRACTION_TITLE = f"Please review our rules over at {utils.RULES_URL}" INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" INFRACTION_AUTHOR_NAME = "Infraction information" INFRACTION_COLOR = Colours.soft_red @@ -89,7 +87,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): with self.subTest(args=args, return_value=return_value, expected=expected, get=get, send_vals=send_vals): self.bot.api_client.get.return_value = return_value - result = await has_active_infraction(*args) + result = await utils.has_active_infraction(*args) self.assertEqual(result, expected) self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=get) @@ -172,7 +170,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.return_value = send - result = await notify_infraction(*args) + result = await utils.notify_infraction(*args) self.assertEqual(send, result) @@ -180,9 +178,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.title, INFRACTION_TITLE) self.assertEqual(embed.colour.value, INFRACTION_COLOR) - self.assertEqual(embed.url, RULES_URL) + self.assertEqual(embed.url, utils.RULES_URL) self.assertEqual(embed.author.name, INFRACTION_AUTHOR_NAME) - self.assertEqual(embed.author.url, RULES_URL) + self.assertEqual(embed.author.url, utils.RULES_URL) self.assertEqual(embed.author.icon_url, expected["icon_url"]) self.assertEqual(embed.footer.text, expected["footer"]) self.assertEqual(embed.description, expected["description"]) @@ -229,7 +227,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.return_value = send - result = await notify_pardon(*args) + result = await utils.notify_pardon(*args) self.assertEqual(send, result) @@ -312,7 +310,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if error: self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), expected) - result = await post_user(*args) + result = await utils.post_user(*args) if error: self.assertIsNone(result) @@ -358,7 +356,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if raised: self.user.send.side_effect = raised - result = await send_private_embed(*args) + result = await utils.send_private_embed(*args) self.assertEqual(result, expected) if expected: @@ -467,7 +465,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx.bot.api_client.post.return_value = expected - result = await post_infraction(*args) + result = await utils.post_infraction(*args) self.assertEqual(result, expected) -- cgit v1.2.3 From 2faa982722f2e9ed9a0710e0030a6078ecab421a Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:10:29 +0200 Subject: (Moderation Utils Tests): Hard-coded API get request params for `has_active_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 6722c2d16..56bf6d67e 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -44,11 +44,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.ctx, self.member, "ban"), "get_return_value": [], "expected_output": False, - "get_call": { - "active": "true", - "type": "ban", - "user__id": str(self.member.id) - }, "send_params": None }, { @@ -65,11 +60,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "hidden": False }], "expected_output": True, - "get_call": { - "active": "true", - "type": "ban", - "user__id": str(self.member.id) - }, "send_params": ( f":x: According to my records, this user already has a ban infraction. " f"See infraction **#1**." @@ -81,15 +71,18 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): args = case["args"] return_value = case["get_return_value"] expected = case["expected_output"] - get = case["get_call"] send_vals = case["send_params"] - with self.subTest(args=args, return_value=return_value, expected=expected, get=get, send_vals=send_vals): + with self.subTest(args=args, return_value=return_value, expected=expected, send_vals=send_vals): self.bot.api_client.get.return_value = return_value result = await utils.has_active_infraction(*args) self.assertEqual(result, expected) - self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=get) + self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params={ + "active": "true", + "type": "ban", + "user__id": str(self.member.id) + }) if result: self.ctx.send.assert_awaited_once_with(send_vals) -- cgit v1.2.3 From 94d3b1303ca55039e19a65043da3abe1ef09280b Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:29:01 +0200 Subject: (Moderation Utils Tests): Cleaned up `has_active_infraction` test cases, hard-coded args, moved mocks resetting to beginning of subtest, added `ctx.send` check only is infraction nr and type in sent string. --- tests/bot/cogs/moderation/test_utils.py | 41 +++++++++------------------------ 1 file changed, 11 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 56bf6d67e..5868da61f 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -41,43 +41,26 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """ test_cases = [ { - "args": (self.ctx, self.member, "ban"), "get_return_value": [], "expected_output": False, - "send_params": None + "infraction_nr": None }, { - "args": (self.ctx, self.member, "ban"), - "get_return_value": [{ - "id": 1, - "inserted_at": "2018-11-22T07:24:06.132307Z", - "expires_at": "5018-11-20T15:52:00Z", - "active": True, - "user": 1234, - "actor": 1234, - "type": "ban", - "reason": "Test", - "hidden": False - }], + "get_return_value": [{"id": 1}], "expected_output": True, - "send_params": ( - f":x: According to my records, this user already has a ban infraction. " - f"See infraction **#1**." - ) + "infraction_nr": "**#1**" } ] for case in test_cases: - args = case["args"] - return_value = case["get_return_value"] - expected = case["expected_output"] - send_vals = case["send_params"] + with self.subTest(return_value=case["get_return_value"], expected=case["expected_output"]): + self.bot.api_client.get.reset_mock() + self.ctx.send.reset_mock() - with self.subTest(args=args, return_value=return_value, expected=expected, send_vals=send_vals): - self.bot.api_client.get.return_value = return_value + self.bot.api_client.get.return_value = case["get_return_value"] - result = await utils.has_active_infraction(*args) - self.assertEqual(result, expected) + result = await utils.has_active_infraction(self.ctx, self.member, "ban") + self.assertEqual(result, case["expected_output"]) self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params={ "active": "true", "type": "ban", @@ -85,10 +68,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): }) if result: - self.ctx.send.assert_awaited_once_with(send_vals) - - self.bot.api_client.get.reset_mock() - self.ctx.send.reset_mock() + self.assertTrue(case["infraction_nr"] in self.ctx.send.call_args[0][0]) + self.assertTrue("ban" in self.ctx.send.call_args[0][0]) @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_infraction(self, send_private_embed_mock): -- cgit v1.2.3 From f4bb6849f8f345ff99f6295e707aa0712af070a7 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:36:42 +0200 Subject: (Moderation Utils Tests): Removed `Dict` unpacking in `notify_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 40 ++++++++++++++++----------------- 1 file changed, 20 insertions(+), 20 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 5868da61f..d6e300c89 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -78,11 +78,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ - "type": "Ban", - "expires": "2020-02-26 09:20 (23 hours and 59 minutes)", - "reason": "No reason provided." - }), + "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Ban", + expires="2020-02-26 09:20 (23 hours and 59 minutes)", + reason="No reason provided." + ), "icon_url": Icons.token_removed, "footer": INFRACTION_APPEAL_FOOTER, }, @@ -92,11 +92,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "warning", None, "Test reason."), "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ - "type": "Warning", - "expires": "N/A", - "reason": "Test reason." - }), + "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Warning", + expires="N/A", + reason="Test reason." + ), "icon_url": Icons.token_removed, "footer": Embed.Empty }, @@ -106,11 +106,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "note", None, None, Icons.defcon_denied), "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ - "type": "Note", - "expires": "N/A", - "reason": "No reason provided." - }), + "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Note", + expires="N/A", + reason="No reason provided." + ), "icon_url": Icons.defcon_denied, "footer": Embed.Empty }, @@ -120,11 +120,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format(**{ - "type": "Mute", - "expires": "2020-02-26 09:20 (23 hours and 59 minutes)", - "reason": "Test" - }), + "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Mute", + expires="2020-02-26 09:20 (23 hours and 59 minutes)", + reason="Test" + ), "icon_url": Icons.defcon_denied, "footer": INFRACTION_APPEAL_FOOTER }, -- cgit v1.2.3 From 2870472eae2f62982283d160378ca6953231da4e Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:39:40 +0200 Subject: (Moderation Utils Tests): Removed unnecessary `ctx.send` `side_effect` and removed these in test cases too in `notify_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index d6e300c89..5ab279391 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -86,8 +86,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "icon_url": Icons.token_removed, "footer": INFRACTION_APPEAL_FOOTER, }, - "send_result": True, - "send_raise": None + "send_result": True }, { "args": (self.user, "warning", None, "Test reason."), @@ -100,8 +99,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "icon_url": Icons.token_removed, "footer": Embed.Empty }, - "send_result": False, - "send_raise": Forbidden(AsyncMock(), AsyncMock()) + "send_result": False }, { "args": (self.user, "note", None, None, Icons.defcon_denied), @@ -114,8 +112,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "icon_url": Icons.defcon_denied, "footer": Embed.Empty }, - "send_result": False, - "send_raise": NotFound(AsyncMock(), AsyncMock()) + "send_result": False }, { "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), @@ -128,19 +125,16 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "icon_url": Icons.defcon_denied, "footer": INFRACTION_APPEAL_FOOTER }, - "send_result": False, - "send_raise": HTTPException(AsyncMock(), AsyncMock()) + "send_result": False } ] for case in test_cases: args = case["args"] expected = case["expected_output"] - send, send_raise = case["send_result"], case["send_raise"] + send = case["send_result"] - with self.subTest(args=args, expected=expected, send=send, send_raise=send_raise): - if send_raise: - self.ctx.send.side_effect = send_raise + with self.subTest(args=args, expected=expected, send=send): send_private_embed_mock.return_value = send -- cgit v1.2.3 From f4fff7139ddffa08b12973d69c8f4bd7c47c0224 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:41:23 +0200 Subject: (Moderation Utils Tests): Removed unnecessary `ctx.send` mock resetting, moved `send_private_embed` mock reset to beginning of subtest. --- tests/bot/cogs/moderation/test_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 5ab279391..5637ff508 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -135,9 +135,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send = case["send_result"] with self.subTest(args=args, expected=expected, send=send): + send_private_embed_mock.reset_mock() send_private_embed_mock.return_value = send - result = await utils.notify_infraction(*args) self.assertEqual(send, result) @@ -155,9 +155,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(args[0], embed) - self.ctx.send.reset_mock(side_effect=True) - send_private_embed_mock.reset_mock() - @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): """Test does `notify_pardon` create correct embed and return correct bool.""" -- cgit v1.2.3 From 8c638bfa67c5e471089fad199bf2c5d64c0be163 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 08:49:32 +0200 Subject: (Moderation Utils Tests): Moved `notify_infraction` embed check from dict to `Embed`. --- tests/bot/cogs/moderation/test_utils.py | 71 ++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 28 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 5637ff508..9844c02f9 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -77,54 +77,76 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): test_cases = [ { "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), - "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + "expected_output": Embed( + title=INFRACTION_TITLE, + description=INFRACTION_DESCRIPTION_TEMPLATE.format( type="Ban", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." ), - "icon_url": Icons.token_removed, - "footer": INFRACTION_APPEAL_FOOTER, - }, + colour=INFRACTION_COLOR, + url=utils.RULES_URL + ).set_author( + name=INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.token_removed + ).set_footer(text=INFRACTION_APPEAL_FOOTER), "send_result": True }, { "args": (self.user, "warning", None, "Test reason."), - "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + "expected_output": Embed( + title=INFRACTION_TITLE, + description=INFRACTION_DESCRIPTION_TEMPLATE.format( type="Warning", expires="N/A", reason="Test reason." ), - "icon_url": Icons.token_removed, - "footer": Embed.Empty - }, + colour=INFRACTION_COLOR, + url=utils.RULES_URL + ).set_author( + name=INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.token_removed + ), "send_result": False }, { "args": (self.user, "note", None, None, Icons.defcon_denied), - "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + "expected_output": Embed( + title=INFRACTION_TITLE, + description=INFRACTION_DESCRIPTION_TEMPLATE.format( type="Note", expires="N/A", reason="No reason provided." ), - "icon_url": Icons.defcon_denied, - "footer": Embed.Empty - }, + colour=INFRACTION_COLOR, + url=utils.RULES_URL + ).set_author( + name=INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.defcon_denied + ), "send_result": False }, { "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), - "expected_output": { - "description": INFRACTION_DESCRIPTION_TEMPLATE.format( + "expected_output": Embed( + title=INFRACTION_TITLE, + description=INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" ), - "icon_url": Icons.defcon_denied, - "footer": INFRACTION_APPEAL_FOOTER - }, + colour=INFRACTION_COLOR, + url=utils.RULES_URL + ).set_author( + name=INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.defcon_denied + ).set_footer( + text=INFRACTION_APPEAL_FOOTER + ), "send_result": False } ] @@ -144,14 +166,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): embed = send_private_embed_mock.call_args[0][1] - self.assertEqual(embed.title, INFRACTION_TITLE) - self.assertEqual(embed.colour.value, INFRACTION_COLOR) - self.assertEqual(embed.url, utils.RULES_URL) - self.assertEqual(embed.author.name, INFRACTION_AUTHOR_NAME) - self.assertEqual(embed.author.url, utils.RULES_URL) - self.assertEqual(embed.author.icon_url, expected["icon_url"]) - self.assertEqual(embed.footer.text, expected["footer"]) - self.assertEqual(embed.description, expected["description"]) + self.assertEqual(embed.to_dict(), expected.to_dict()) send_private_embed_mock.assert_awaited_once_with(args[0], embed) -- cgit v1.2.3 From 4260d3cf60f01f0de55a95290dd038d8a5c079ca Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 09:01:18 +0200 Subject: (Moderation Utils Tests): Removed unnecessary `ctx.send` `side_effect` from `notify_pardon`, applied changes to test cases. --- tests/bot/cogs/moderation/test_utils.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 9844c02f9..3616b3cf0 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -181,8 +181,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "title": "Test title", "icon_url": Icons.user_verified }, - "send_result": True, - "send_raise": None + "send_result": True }, { "args": (self.user, "Test title 1", "Example content 1", Icons.user_update), @@ -191,24 +190,21 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "title": "Test title 1", "icon_url": Icons.user_update }, - "send_result": False, - "send_raise": NotFound(AsyncMock(), AsyncMock()) + "send_result": False } ] for case in test_cases: args = case["args"] expected = case["expected_output"] - send, send_raise = case["send_result"], case["send_raise"] + send = case["send_result"] with self.subTest(args=args, expected=expected): - if send_raise: - self.ctx.send.side_effect = send_raise + send_private_embed_mock.reset_mock() send_private_embed_mock.return_value = send result = await utils.notify_pardon(*args) - self.assertEqual(send, result) embed = send_private_embed_mock.call_args[0][1] @@ -220,9 +216,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(args[0], embed) - self.ctx.send.reset_mock(side_effect=True) - send_private_embed_mock.reset_mock() - async def test_post_user(self): """Test does `post_user` handle errors and results correctly.""" test_cases = [ -- cgit v1.2.3 From b01c2cd813b2df1f8a12e7b493e5085f4a8b9a6e Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 09:05:21 +0200 Subject: (Moderation Utils Tests): Moved `expected_output` from `Dict` to `discord.Embed` in `notify_pardon` test. --- tests/bot/cogs/moderation/test_utils.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 3616b3cf0..f8fbee4e2 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -176,20 +176,18 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): test_cases = [ { "args": (self.user, "Test title", "Example content"), - "expected_output": { - "description": "Example content", - "title": "Test title", - "icon_url": Icons.user_verified - }, + "expected_output": Embed( + description="Example content", + colour=PARDON_COLOR + ).set_author(name="Test title", icon_url=Icons.user_verified), "send_result": True }, { "args": (self.user, "Test title 1", "Example content 1", Icons.user_update), - "expected_output": { - "description": "Example content 1", - "title": "Test title 1", - "icon_url": Icons.user_update - }, + "expected_output": Embed( + description="Example content 1", + colour=PARDON_COLOR + ).set_author(name="Test title 1", icon_url=Icons.user_update), "send_result": False } ] @@ -208,11 +206,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(send, result) embed = send_private_embed_mock.call_args[0][1] - - self.assertEqual(embed.description, expected["description"]) - self.assertEqual(embed.colour.value, PARDON_COLOR) - self.assertEqual(embed.author.name, expected["title"]) - self.assertEqual(embed.author.icon_url, expected["icon_url"]) + self.assertEqual(embed.to_dict(), expected.to_dict()) send_private_embed_mock.assert_awaited_once_with(args[0], embed) -- cgit v1.2.3 From fc8b796d3c9d88cff959e8d5035bf62a257a7c9c Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 10:33:11 +0200 Subject: (Moderation Utils Tests): Added new check to `post_user` test (`ctx.send` content test), improved test cases. --- tests/bot/cogs/moderation/test_utils.py | 51 +++++++++++---------------------- 1 file changed, 16 insertions(+), 35 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index f8fbee4e2..5e9c627bb 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -212,54 +212,31 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): async def test_post_user(self): """Test does `post_user` handle errors and results correctly.""" + user = MockUser(avatar="abc", discriminator=5678, id=1234, name="Test user") test_cases = [ { - "args": (self.ctx, self.user), - "post_result": [ - { - "id": 1234, - "avatar": "test", - "name": "Test", - "discriminator": 1234, - "roles": [ - 1234, - 5678 - ], - "in_guild": False - } - ], + "args": (self.ctx, user), + "post_result": "bar", "raise_error": False, "payload": { - "avatar_hash": getattr(self.user, "avatar", 0), - "discriminator": int(getattr(self.user, "discriminator", 0)), + "avatar_hash": "abc", + "discriminator": 5678, "id": self.user.id, "in_guild": False, - "name": getattr(self.user, "name", "Name unknown"), + "name": "Test user", "roles": [] } }, { - "args": (self.ctx, self.user), - "post_result": [ - { - "id": 1234, - "avatar": "test", - "name": "Test", - "discriminator": 1234, - "roles": [ - 1234, - 5678 - ], - "in_guild": False - } - ], + "args": (self.ctx, self.member), + "post_result": "foo", "raise_error": True, "payload": { - "avatar_hash": getattr(self.user, "avatar", 0), - "discriminator": int(getattr(self.user, "discriminator", 0)), - "id": self.user.id, + "avatar_hash": 0, + "discriminator": 0, + "id": self.member.id, "in_guild": False, - "name": getattr(self.user, "name", "Name unknown"), + "name": "Name unknown", "roles": [] } } @@ -276,6 +253,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if error: self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), expected) + err = self.ctx.bot.api_client.post.side_effect + err.status = 400 result = await utils.post_user(*args) @@ -286,6 +265,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if not error: self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) + else: + self.assertTrue(str(err.status) in self.ctx.send.call_args[0][0]) self.bot.api_client.post.reset_mock(side_effect=True) -- cgit v1.2.3 From 50582f1eeae46d25653eb545455e720df1d4b162 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 10:37:43 +0200 Subject: (Moderation Utils Tests): Hard-coded args for `send_private_embed` test. --- tests/bot/cogs/moderation/test_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 5e9c627bb..b6bf1a96e 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,6 +1,5 @@ import unittest from datetime import datetime -from typing import Union from unittest.mock import AsyncMock, patch from discord import Embed, Forbidden, HTTPException, NotFound @@ -272,43 +271,40 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): async def test_send_private_embed(self): """Test does `send_private_embed` return correct bool.""" + embed = Embed(title="Test", description="Test val") + test_cases = [ { - "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": True, "raised_exception": None }, { - "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": False, "raised_exception": HTTPException(AsyncMock(), AsyncMock()) }, { - "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": False, "raised_exception": Forbidden(AsyncMock(), AsyncMock()) }, { - "args": (self.user, Embed(title="Test", description="Test val")), "expected_output": False, "raised_exception": NotFound(AsyncMock(), AsyncMock()) } ] for case in test_cases: - args = case["args"] expected = case["expected_output"] - raised: Union[Forbidden, HTTPException, NotFound, None] = case["raised_exception"] + raised = case["raised_exception"] - with self.subTest(args=args, expected=expected, raised=raised): + with self.subTest(expected=expected, raised=raised): if raised: self.user.send.side_effect = raised - result = await utils.send_private_embed(*args) + result = await utils.send_private_embed(self.user, embed) self.assertEqual(result, expected) if expected: - args[0].send.assert_awaited_once_with(embed=args[1]) + self.user.send.assert_awaited_once_with(embed=embed) self.user.send.reset_mock(side_effect=True) -- cgit v1.2.3 From 4b211f5278dc5e14871468d05d0414d2b2f7de3c Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 10:38:56 +0200 Subject: (Moderation Utils Tests): Removed unnecessary `if` check from `send_private_embed` test --- tests/bot/cogs/moderation/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index b6bf1a96e..7291e42c6 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -297,8 +297,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): raised = case["raised_exception"] with self.subTest(expected=expected, raised=raised): - if raised: - self.user.send.side_effect = raised + self.user.send.side_effect = raised result = await utils.send_private_embed(self.user, embed) -- cgit v1.2.3 From 181971424f4e6c494f8ecb8f75919e27b784dcf5 Mon Sep 17 00:00:00 2001 From: ks123 Date: Sun, 8 Mar 2020 10:40:58 +0200 Subject: (Moderation Utils Tests): Moved mock resetting to beginning of subtest in `post_user` and `send_private_embed` test. --- tests/bot/cogs/moderation/test_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 7291e42c6..d43269b19 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -248,6 +248,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): payload = case["payload"] with self.subTest(args=args, result=expected, error=error, payload=payload): + self.bot.api_client.post.reset_mock(side_effect=True) self.ctx.bot.api_client.post.return_value = expected if error: @@ -267,8 +268,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): else: self.assertTrue(str(err.status) in self.ctx.send.call_args[0][0]) - self.bot.api_client.post.reset_mock(side_effect=True) - async def test_send_private_embed(self): """Test does `send_private_embed` return correct bool.""" embed = Embed(title="Test", description="Test val") @@ -297,6 +296,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): raised = case["raised_exception"] with self.subTest(expected=expected, raised=raised): + self.user.send.reset_mock(side_effect=True) self.user.send.side_effect = raised result = await utils.send_private_embed(self.user, embed) @@ -305,8 +305,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if expected: self.user.send.assert_awaited_once_with(embed=embed) - self.user.send.reset_mock(side_effect=True) - @patch("bot.cogs.moderation.utils.post_user") async def test_post_infraction(self, post_user_mock): """Test does `post_infraction` call functions correctly and return `None` or `Dict`.""" -- cgit v1.2.3 From 708e2165ff44b19d31bd6f2d8fdd7d3b408a9ef3 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 19:51:17 +0200 Subject: (Moderation Utils Tests): Create extra new tests set for `post_infraction` testing, removed old. --- tests/bot/cogs/moderation/test_utils.py | 163 ++++++++++++-------------------- 1 file changed, 60 insertions(+), 103 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index d43269b19..f34a56d50 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,6 +1,6 @@ import unittest from datetime import datetime -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from discord import Embed, Forbidden, HTTPException, NotFound @@ -305,114 +305,71 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if expected: self.user.send.assert_awaited_once_with(embed=embed) - @patch("bot.cogs.moderation.utils.post_user") - async def test_post_infraction(self, post_user_mock): - """Test does `post_infraction` call functions correctly and return `None` or `Dict`.""" - now = datetime.now() - test_cases = [ - { - "args": (self.ctx, self.member, "ban", "Test Ban"), - "expected_output": [ - { - "id": 1, - "inserted_at": "2018-11-22T07:24:06.132307Z", - "expires_at": "5018-11-20T15:52:00Z", - "active": True, - "user": 1234, - "actor": 1234, - "type": "ban", - "reason": "Test Ban", - "hidden": False - } - ], - "raised_error": None, - "payload": { - "actor": self.ctx.message.author.id, - "hidden": False, - "reason": "Test Ban", - "type": "ban", - "user": self.member.id, - "active": True - } - }, - { - "args": (self.ctx, self.member, "note", "Test Ban"), - "expected_output": None, - "raised_error": ResponseCodeError(AsyncMock(), AsyncMock()), - "payload": { - "actor": self.ctx.message.author.id, - "hidden": False, - "reason": "Test Ban", - "type": "note", - "user": self.member.id, - "active": True - } - }, - { - "args": (self.ctx, self.member, "mute", "Test Ban"), - "expected_output": None, - "raised_error": ResponseCodeError(AsyncMock(status=400), {'user': 1234}), - "payload": { - "actor": self.ctx.message.author.id, - "hidden": False, - "reason": "Test Ban", - "type": "mute", - "user": self.member.id, - "active": True - } - }, - { - "args": (self.ctx, self.member, "ban", "Test Ban", now, True, False), - "expected_output": [ - { - "id": 1, - "inserted_at": "2018-11-22T07:24:06.132307Z", - "expires_at": "5018-11-20T15:52:00Z", - "active": True, - "user": 1234, - "actor": 1234, - "type": "ban", - "reason": "Test Ban", - "hidden": False - } - ], - "raised_error": None, - "payload": { - "actor": self.ctx.message.author.id, - "hidden": True, - "reason": "Test Ban", - "type": "ban", - "user": self.member.id, - "active": False, - "expires_at": now.isoformat() - } - }, - ] - for case in test_cases: - args = case["args"] - expected = case["expected_output"] - raised = case["raised_error"] - payload = case["payload"] +class TestPostInfraction(unittest.IsolatedAsyncioTestCase): + """Tests for `post_infraction` function.""" - with self.subTest(args=args, expected=expected, raised=raised, payload=payload): - self.ctx.bot.api_client.post.reset_mock(side_effect=True) - post_user_mock.reset_mock() + def setUp(self): + self.bot = MockBot() + self.member = MockMember(id=1234) + self.user = MockUser(id=1234) + self.ctx = MockContext(bot=self.bot, author=self.member) - if raised: - self.ctx.bot.api_client.post.side_effect = raised + async def test_normal_post_infraction(self): + """Test does `post_infraction` return correct value when no errors raise.""" + now = datetime.now() + payload = { + "actor": self.ctx.message.author.id, + "hidden": True, + "reason": "Test reason", + "type": "ban", + "user": self.member.id, + "active": False, + "expires_at": now.isoformat() + } - post_user_mock.return_value = "foo" + self.ctx.bot.api_client.post.return_value = "foo" + actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) - self.ctx.bot.api_client.post.return_value = expected + self.assertEqual(actual, "foo") + self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) - result = await utils.post_infraction(*args) + async def test_unknown_error_post_infraction(self): + """Test does `post_infraction` send info about fail to chat (`ctx.send`).""" + self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), AsyncMock()) + self.ctx.bot.api_client.post.side_effect.status = 500 - self.assertEqual(result, expected) + actual = await utils.post_infraction(self.ctx, self.user, "ban", "Test reason") + self.assertIsNone(actual) - if not raised: - self.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + self.assertTrue("500" in self.ctx.send.call_args[0][0]) - if hasattr(raised, "status") and hasattr(raised, "response_json"): - if raised.status == 400 and "user" in raised.response_json: - post_user_mock.assert_awaited_once_with(args[0], args[1]) + @patch("bot.cogs.moderation.utils.post_user") + async def test_user_not_found_none_post_infraction(self, post_user_mock): + """Test does `post_infraction` return `None` correctly due can't create new user.""" + self.bot.api_client.post.side_effect = ResponseCodeError(MagicMock(status=400), {"user": "foo"}) + post_user_mock.return_value = None + + actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") + self.assertIsNone(actual) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) + + @patch("bot.cogs.moderation.utils.post_user") + async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): + """Test does `post_infraction` fail first time and return correct result 2nd time when new user posted.""" + payload = { + "actor": self.ctx.message.author.id, + "hidden": False, + "reason": "Test reason", + "type": "mute", + "user": self.user.id, + "active": True + } + + self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] + post_user_mock.return_value = "bar" + + actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") + self.assertEqual(actual, "foo") + self.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From f793c0772c7b1c5c4edb457fb372e9a57a8200ff Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 19:54:37 +0200 Subject: (Moderation Utils Tests): Added params to variable in `has_active_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index f34a56d50..3432ff595 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -56,15 +56,17 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.reset_mock() self.ctx.send.reset_mock() + params = { + "active": "true", + "type": "ban", + "user__id": str(self.member.id) + } + self.bot.api_client.get.return_value = case["get_return_value"] result = await utils.has_active_infraction(self.ctx, self.member, "ban") self.assertEqual(result, case["expected_output"]) - self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params={ - "active": "true", - "type": "ban", - "user__id": str(self.member.id) - }) + self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=params) if result: self.assertTrue(case["infraction_nr"] in self.ctx.send.call_args[0][0]) -- cgit v1.2.3 From cd1193ec09c5259ff2f2c5906faf20ed788326c9 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:00:24 +0200 Subject: (Moderation Utils Tests): Moved embed generating to test cases loop from test cases listing, added icon to test cases in `notify_pardon` test --- tests/bot/cogs/moderation/test_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 3432ff595..7f5e441b7 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -177,27 +177,27 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): test_cases = [ { "args": (self.user, "Test title", "Example content"), - "expected_output": Embed( - description="Example content", - colour=PARDON_COLOR - ).set_author(name="Test title", icon_url=Icons.user_verified), + "icon": Icons.user_verified, "send_result": True }, { - "args": (self.user, "Test title 1", "Example content 1", Icons.user_update), - "expected_output": Embed( - description="Example content 1", - colour=PARDON_COLOR - ).set_author(name="Test title 1", icon_url=Icons.user_update), + "args": (self.user, "Test title", "Example content", Icons.user_update), + "icon": Icons.user_update, "send_result": False } ] for case in test_cases: args = case["args"] - expected = case["expected_output"] send = case["send_result"] + expected = Embed( + description="Example content", + colour=PARDON_COLOR).set_author( + name="Test title", + icon_url=case["icon"] + ) + with self.subTest(args=args, expected=expected): send_private_embed_mock.reset_mock() -- cgit v1.2.3 From 6376b6a47033c51b80d32e0cf00e6d13ca9d05c3 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:03:00 +0200 Subject: (Moderation Utils Tests): Removed unnecessary symbols from `has_active_infraction` test `infraction_nr` variable and changes this to more unique number. --- tests/bot/cogs/moderation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 7f5e441b7..2f66904d8 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -45,9 +45,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "infraction_nr": None }, { - "get_return_value": [{"id": 1}], + "get_return_value": [{"id": 123987}], "expected_output": True, - "infraction_nr": "**#1**" + "infraction_nr": "123987" } ] -- cgit v1.2.3 From f93be96dd484e4b484a3a7e8c1c3b79062b6c386 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:05:03 +0200 Subject: (Moderation Utils Tests): Fixed formatting in `notify_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 2f66904d8..61fb618d4 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -145,9 +145,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): name=INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, icon_url=Icons.defcon_denied - ).set_footer( - text=INFRACTION_APPEAL_FOOTER - ), + ).set_footer(text=INFRACTION_APPEAL_FOOTER), "send_result": False } ] -- cgit v1.2.3 From b70d2fc557bb2bdbc32f905132a0e80272f174a2 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:07:27 +0200 Subject: (Moderation Utils Tests): Hard-coded `self.ctx` argument to `post_user` test, renamed current `args` to `user`, applied this in code. --- tests/bot/cogs/moderation/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 61fb618d4..3f721d182 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -214,7 +214,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): user = MockUser(avatar="abc", discriminator=5678, id=1234, name="Test user") test_cases = [ { - "args": (self.ctx, user), + "user": user, "post_result": "bar", "raise_error": False, "payload": { @@ -227,7 +227,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): } }, { - "args": (self.ctx, self.member), + "user": self.member, "post_result": "foo", "raise_error": True, "payload": { @@ -242,12 +242,12 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ] for case in test_cases: - args = case["args"] + test_user = case["user"] expected = case["post_result"] error = case["raise_error"] payload = case["payload"] - with self.subTest(args=args, result=expected, error=error, payload=payload): + with self.subTest(user=test_user, result=expected, error=error, payload=payload): self.bot.api_client.post.reset_mock(side_effect=True) self.ctx.bot.api_client.post.return_value = expected @@ -256,7 +256,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): err = self.ctx.bot.api_client.post.side_effect err.status = 400 - result = await utils.post_user(*args) + result = await utils.post_user(self.ctx, test_user) if error: self.assertIsNone(result) -- cgit v1.2.3 From 4724b66c4a337b1735f7b08cb60c4cbcb68a6e3c Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:13:53 +0200 Subject: (Moderation Utils Tests): Move errors from booleans to actual errors in `post_user` test. --- tests/bot/cogs/moderation/test_utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 3f721d182..9afa5ab0b 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -216,7 +216,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "user": user, "post_result": "bar", - "raise_error": False, + "raise_error": None, "payload": { "avatar_hash": "abc", "discriminator": 5678, @@ -229,7 +229,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "user": self.member, "post_result": "foo", - "raise_error": True, + "raise_error": ResponseCodeError(MagicMock(status=400), "foo"), "payload": { "avatar_hash": 0, "discriminator": 0, @@ -251,10 +251,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.post.reset_mock(side_effect=True) self.ctx.bot.api_client.post.return_value = expected - if error: - self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), expected) - err = self.ctx.bot.api_client.post.side_effect - err.status = 400 + self.ctx.bot.api_client.post.side_effect = error result = await utils.post_user(self.ctx, test_user) @@ -266,7 +263,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if not error: self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) else: - self.assertTrue(str(err.status) in self.ctx.send.call_args[0][0]) + self.assertTrue(str(error.status) in self.ctx.send.call_args[0][0]) async def test_send_private_embed(self): """Test does `send_private_embed` return correct bool.""" -- cgit v1.2.3 From 35ffc216e62a46aa6ba6fcb7d0717354d176e175 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:15:28 +0200 Subject: (Moderation Utils Tests): Added call check for `ctx.send` in `post_user` test. --- tests/bot/cogs/moderation/test_utils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 9afa5ab0b..e0af13a46 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -263,6 +263,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if not error: self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) else: + self.ctx.send.assert_awaited_once() self.assertTrue(str(error.status) in self.ctx.send.call_args[0][0]) async def test_send_private_embed(self): -- cgit v1.2.3 From fe504bd30360df0c18fbfccfb27d2652ac19e9b8 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 12 Mar 2020 20:19:43 +0200 Subject: (Moderation Utils Tests): Added mock reset due fail. --- tests/bot/cogs/moderation/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index e0af13a46..2cba37e3a 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -355,6 +355,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_user") async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Test does `post_infraction` fail first time and return correct result 2nd time when new user posted.""" + self.bot.api_client.post.reset_mock() + payload = { "actor": self.ctx.message.author.id, "hidden": False, -- cgit v1.2.3 From 5ac2aa48109f16e96195dda60e3a70b65b9562fa Mon Sep 17 00:00:00 2001 From: Karlis S <45097959+ks129@users.noreply.github.com> Date: Thu, 12 Mar 2020 21:10:23 +0200 Subject: (Moderation Utils Tests): Removed `once` from `post_infraction` test due tests failing. --- tests/bot/cogs/moderation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 2cba37e3a..e23585c99 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -371,5 +371,5 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertEqual(actual, "foo") - self.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + self.bot.api_client.post.assert_awaited_with("bot/infractions", json=payload) post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From dac8e758201e938c5e694efb63b96485fb771274 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 12 Mar 2020 19:22:27 -0700 Subject: Revise docstrings for moderation util tests --- tests/bot/cogs/moderation/test_utils.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index e23585c99..ca951250f 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -36,7 +36,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): async def test_user_has_active_infraction(self): """ - Test does `has_active_infraction` return call at least once `ctx.send` API get, check does return correct bool. + Should request the API for active infractions and return `True` if the user has one or `False` otherwise. + + A message should be sent to the context indicating a user already has an infraction, if that's the case. """ test_cases = [ { @@ -74,7 +76,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_infraction(self, send_private_embed_mock): - """Test does `notify_infraction` create correct embed and return correct boolean.""" + """ + Should send an embed of a certain format as a DM and return `True` if DM successful. + + Appealable infractions should have the appeal message in the embed's footer. + """ test_cases = [ { "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), @@ -171,7 +177,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): - """Test does `notify_pardon` create correct embed and return correct bool.""" + """Should send an embed of a certain format as a DM and return `True` if DM successful.""" test_cases = [ { "args": (self.user, "Test title", "Example content"), @@ -210,7 +216,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(args[0], embed) async def test_post_user(self): - """Test does `post_user` handle errors and results correctly.""" + """Should POST a new user and return the response if successful or otherwise send an error message.""" user = MockUser(avatar="abc", discriminator=5678, id=1234, name="Test user") test_cases = [ { @@ -267,7 +273,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(str(error.status) in self.ctx.send.call_args[0][0]) async def test_send_private_embed(self): - """Test does `send_private_embed` return correct bool.""" + """Should DM the user and return `True` on success or `False` on failure.""" embed = Embed(title="Test", description="Test val") test_cases = [ @@ -305,7 +311,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): class TestPostInfraction(unittest.IsolatedAsyncioTestCase): - """Tests for `post_infraction` function.""" + """Tests for the `post_infraction` function.""" def setUp(self): self.bot = MockBot() @@ -314,7 +320,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(bot=self.bot, author=self.member) async def test_normal_post_infraction(self): - """Test does `post_infraction` return correct value when no errors raise.""" + """Should return response from POST request if there are no errors.""" now = datetime.now() payload = { "actor": self.ctx.message.author.id, @@ -333,7 +339,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) async def test_unknown_error_post_infraction(self): - """Test does `post_infraction` send info about fail to chat (`ctx.send`).""" + """Should send an error message to chat when a non-400 error occurs.""" self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), AsyncMock()) self.ctx.bot.api_client.post.side_effect.status = 500 @@ -344,7 +350,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_user") async def test_user_not_found_none_post_infraction(self, post_user_mock): - """Test does `post_infraction` return `None` correctly due can't create new user.""" + """Should abort and return `None` when a new user fails to be posted.""" self.bot.api_client.post.side_effect = ResponseCodeError(MagicMock(status=400), {"user": "foo"}) post_user_mock.return_value = None @@ -354,7 +360,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_user") async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): - """Test does `post_infraction` fail first time and return correct result 2nd time when new user posted.""" + """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" self.bot.api_client.post.reset_mock() payload = { -- cgit v1.2.3 From 3043dd1d565943f180a5ae16e46e6daa531466c7 Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 07:25:43 +0000 Subject: (Moderation Utils Tests): Added 2 call check to `post_infraction` test. --- tests/bot/cogs/moderation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index ca951250f..659884d93 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,6 +1,6 @@ import unittest from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch from discord import Embed, Forbidden, HTTPException, NotFound @@ -377,5 +377,5 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertEqual(actual, "foo") - self.bot.api_client.post.assert_awaited_with("bot/infractions", json=payload) + self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From 7806e11a017698ff43494a1fa1a908e11bb63e33 Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 07:27:50 +0000 Subject: (Moderation Utils Tests): Removed unnecessary mock resetting. --- tests/bot/cogs/moderation/test_utils.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 659884d93..03f086ba9 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -361,8 +361,6 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_user") async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" - self.bot.api_client.post.reset_mock() - payload = { "actor": self.ctx.message.author.id, "hidden": False, -- cgit v1.2.3 From 1052ad4213348ede7ec6e495d32e21b3818153e0 Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 07:31:00 +0000 Subject: (Moderation Utils Tests): Moved `return_value` to `patch` decorator. --- tests/bot/cogs/moderation/test_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 03f086ba9..6702372d6 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -348,17 +348,16 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.assertTrue("500" in self.ctx.send.call_args[0][0]) - @patch("bot.cogs.moderation.utils.post_user") + @patch("bot.cogs.moderation.utils.post_user", return_value=None) async def test_user_not_found_none_post_infraction(self, post_user_mock): """Should abort and return `None` when a new user fails to be posted.""" self.bot.api_client.post.side_effect = ResponseCodeError(MagicMock(status=400), {"user": "foo"}) - post_user_mock.return_value = None actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertIsNone(actual) post_user_mock.assert_awaited_once_with(self.ctx, self.user) - @patch("bot.cogs.moderation.utils.post_user") + @patch("bot.cogs.moderation.utils.post_user", return_value="bar") async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" payload = { @@ -371,7 +370,6 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): } self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] - post_user_mock.return_value = "bar" actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertEqual(actual, "foo") -- cgit v1.2.3 From 1d486096e20dde3bcf6bc95ced0557840625e84d Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 07:32:21 +0000 Subject: (Moderation Utils Tests): Fixed formatting in `notify_pardon` test. --- tests/bot/cogs/moderation/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 6702372d6..2e4c31836 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -197,7 +197,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): expected = Embed( description="Example content", - colour=PARDON_COLOR).set_author( + colour=PARDON_COLOR + ).set_author( name="Test title", icon_url=case["icon"] ) -- cgit v1.2.3 From 897b378a09c1a058f10a92aa23c21e2f737e6819 Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 17:42:39 +0000 Subject: (Moderation Utils Tests): Removed Infraction Color constant. --- tests/bot/cogs/moderation/test_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 2e4c31836..f30e85b12 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -14,7 +14,6 @@ APPEAL_EMAIL = "appeals@pythondiscord.com" INFRACTION_TITLE = f"Please review our rules over at {utils.RULES_URL}" INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" INFRACTION_AUTHOR_NAME = "Infraction information" -INFRACTION_COLOR = Colours.soft_red INFRACTION_DESCRIPTION_TEMPLATE = ( "\n**Type:** {type}\n" @@ -91,7 +90,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." ), - colour=INFRACTION_COLOR, + colour=Colours.soft_red, url=utils.RULES_URL ).set_author( name=INFRACTION_AUTHOR_NAME, @@ -109,7 +108,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): expires="N/A", reason="Test reason." ), - colour=INFRACTION_COLOR, + colour=Colours.soft_red, url=utils.RULES_URL ).set_author( name=INFRACTION_AUTHOR_NAME, @@ -127,7 +126,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): expires="N/A", reason="No reason provided." ), - colour=INFRACTION_COLOR, + colour=Colours.soft_red, url=utils.RULES_URL ).set_author( name=INFRACTION_AUTHOR_NAME, @@ -145,7 +144,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" ), - colour=INFRACTION_COLOR, + colour=Colours.soft_red, url=utils.RULES_URL ).set_author( name=INFRACTION_AUTHOR_NAME, -- cgit v1.2.3 From ea2e8bbe320996d4292f958240e231d622d0481d Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 17:45:15 +0000 Subject: (Moderation Utils Tests): Removed Pardon Color constant. --- tests/bot/cogs/moderation/test_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index f30e85b12..52bdb5fbc 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -21,8 +21,6 @@ INFRACTION_DESCRIPTION_TEMPLATE = ( "**Reason:** {reason}\n" ) -PARDON_COLOR = Colours.soft_green - class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" @@ -196,7 +194,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): expected = Embed( description="Example content", - colour=PARDON_COLOR + colour=Colours.soft_green ).set_author( name="Test title", icon_url=case["icon"] -- cgit v1.2.3 From 99e1239f4734d0ed34688fa77d5094f8984b9209 Mon Sep 17 00:00:00 2001 From: Karlis S Date: Fri, 13 Mar 2020 18:03:40 +0000 Subject: (Mod Utils + Tests): Moved constants from tests to utils, applied change --- bot/cogs/moderation/utils.py | 28 ++++++++++++++++------- tests/bot/cogs/moderation/test_utils.py | 40 ++++++++++++--------------------- 2 files changed, 34 insertions(+), 34 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index 5052b9048..8121a0af8 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -28,6 +28,18 @@ UserObject = t.Union[discord.Member, discord.User] UserSnowflake = t.Union[UserObject, discord.Object] Infraction = t.Dict[str, t.Union[str, int, bool]] +APPEAL_EMAIL = "appeals@pythondiscord.com" + +INFRACTION_TITLE = f"Please review our rules over at {RULES_URL}" +INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" +INFRACTION_AUTHOR_NAME = "Infraction information" + +INFRACTION_DESCRIPTION_TEMPLATE = ( + "\n**Type:** {type}\n" + "**Expires:** {expires}\n" + "**Reason:** {reason}\n" +) + async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: """ @@ -132,21 +144,21 @@ async def notify_infraction( log.trace(f"Sending {user} a DM about their {infr_type} infraction.") embed = discord.Embed( - description=textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """), + description=INFRACTION_DESCRIPTION_TEMPLATE.format( + type=infr_type.capitalize(), + expires=expires_at or "N/A", + reason=reason or "No reason provided." + ), colour=Colours.soft_red ) - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" + embed.set_author(name=INFRACTION_AUTHOR_NAME, icon_url=icon_url, url=RULES_URL) + embed.title = INFRACTION_TITLE embed.url = RULES_URL if infr_type in APPEALABLE_INFRACTIONS: embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + text=INFRACTION_APPEAL_FOOTER ) return await send_private_embed(user, embed) diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 52bdb5fbc..4f81a2477 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -9,18 +9,6 @@ from bot.cogs.moderation import utils from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser -APPEAL_EMAIL = "appeals@pythondiscord.com" - -INFRACTION_TITLE = f"Please review our rules over at {utils.RULES_URL}" -INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" -INFRACTION_AUTHOR_NAME = "Infraction information" - -INFRACTION_DESCRIPTION_TEMPLATE = ( - "\n**Type:** {type}\n" - "**Expires:** {expires}\n" - "**Reason:** {reason}\n" -) - class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" @@ -82,8 +70,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": Embed( - title=INFRACTION_TITLE, - description=INFRACTION_DESCRIPTION_TEMPLATE.format( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Ban", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." @@ -91,17 +79,17 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): colour=Colours.soft_red, url=utils.RULES_URL ).set_author( - name=INFRACTION_AUTHOR_NAME, + name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, icon_url=Icons.token_removed - ).set_footer(text=INFRACTION_APPEAL_FOOTER), + ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), "send_result": True }, { "args": (self.user, "warning", None, "Test reason."), "expected_output": Embed( - title=INFRACTION_TITLE, - description=INFRACTION_DESCRIPTION_TEMPLATE.format( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Warning", expires="N/A", reason="Test reason." @@ -109,7 +97,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): colour=Colours.soft_red, url=utils.RULES_URL ).set_author( - name=INFRACTION_AUTHOR_NAME, + name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, icon_url=Icons.token_removed ), @@ -118,8 +106,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "note", None, None, Icons.defcon_denied), "expected_output": Embed( - title=INFRACTION_TITLE, - description=INFRACTION_DESCRIPTION_TEMPLATE.format( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Note", expires="N/A", reason="No reason provided." @@ -127,7 +115,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): colour=Colours.soft_red, url=utils.RULES_URL ).set_author( - name=INFRACTION_AUTHOR_NAME, + name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, icon_url=Icons.defcon_denied ), @@ -136,8 +124,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), "expected_output": Embed( - title=INFRACTION_TITLE, - description=INFRACTION_DESCRIPTION_TEMPLATE.format( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" @@ -145,10 +133,10 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): colour=Colours.soft_red, url=utils.RULES_URL ).set_author( - name=INFRACTION_AUTHOR_NAME, + name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, icon_url=Icons.defcon_denied - ).set_footer(text=INFRACTION_APPEAL_FOOTER), + ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), "send_result": False } ] -- cgit v1.2.3 From 5b11b248b945cd2a732c6d8d430d117fc062cc8d Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Thu, 7 May 2020 16:46:32 +0200 Subject: Remove tests from moved function. --- tests/bot/cogs/test_snekbox.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..d32d80ead 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -1,5 +1,4 @@ import asyncio -import logging import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch @@ -53,20 +52,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): raise_for_status=True ) - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - self.bot.http_session.post().__aenter__.return_value = resp - - log = logging.getLogger("bot.cogs.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - def test_prepare_input(self): cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), -- cgit v1.2.3 From 14c670dfa87e142e24c027e2976fa02b07c4d7ac Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Thu, 7 May 2020 17:11:56 +0200 Subject: Adjust behaviour for new func usage. --- tests/bot/cogs/test_snekbox.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index d32d80ead..f4c13fc43 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -35,21 +35,12 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) self.assertEqual(result, "too long to upload") - async def test_upload_output(self): + @patch("bot.cogs.snekbox.send_to_paste_service") + async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - self.bot.http_session.post().__aenter__.return_value = resp - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True + await self.cog.upload_output("Test output.") + mock_paste_util.assert_called_once_with( + self.bot.http_session, "Test output.", extension="txt" ) def test_prepare_input(self): -- cgit v1.2.3 From 5d96e96a2e8982ec57c1a19d1a085ceccd35a6d7 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Fri, 8 May 2020 01:38:14 +0200 Subject: Add tests for `send_to_paste_service`. --- tests/bot/utils/test_init.py | 74 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/bot/utils/test_init.py (limited to 'tests') diff --git a/tests/bot/utils/test_init.py b/tests/bot/utils/test_init.py new file mode 100644 index 000000000..f3a8f5939 --- /dev/null +++ b/tests/bot/utils/test_init.py @@ -0,0 +1,74 @@ +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +from aiohttp import ClientConnectorError + +from bot.utils import FAILED_REQUEST_ATTEMPTS, send_to_paste_service + + +class PasteTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.http_session = MagicMock() + + @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") + async def test_url_and_sent_contents(self): + """Correct url was used and post was called with expected data.""" + response = MagicMock( + json=AsyncMock(return_value={"key": ""}) + ) + self.http_session.post().__aenter__.return_value = response + self.http_session.post.reset_mock() + await send_to_paste_service(self.http_session, "Content") + self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + + @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") + async def test_paste_returns_correct_url_on_success(self): + """Url with specified extension is returned on successful requests.""" + key = "paste_key" + test_cases = ( + (f"https://paste_service.com/{key}.txt", "txt"), + (f"https://paste_service.com/{key}.py", "py"), + (f"https://paste_service.com/{key}", ""), + ) + response = MagicMock( + json=AsyncMock(return_value={"key": key}) + ) + self.http_session.post().__aenter__.return_value = response + + for expected_output, extension in test_cases: + with self.subTest(msg=f"Send contents with extension {repr(extension)}"): + self.assertEqual( + await send_to_paste_service(self.http_session, "", extension=extension), + expected_output + ) + + async def test_request_repeated_on_json_errors(self): + """Json with error message and invalid json are handled as errors and requests repeated.""" + test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) + self.http_session.post().__aenter__.return_value = response = MagicMock() + self.http_session.post.reset_mock() + + for error_json in test_cases: + with self.subTest(error_json=error_json): + response.json = AsyncMock(return_value=error_json) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + self.http_session.post.reset_mock() + + async def test_request_repeated_on_connection_errors(self): + """Requests are repeated in the case of connection errors.""" + self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + async def test_general_error_handled_and_request_repeated(self): + """All `Exception`s are handled, logged and request repeated.""" + self.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertLogs("bot.utils", logging.ERROR) + self.assertIsNone(result) -- cgit v1.2.3 From 72d2f662ff84c8bfca448870e8d7e60777301a68 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 14 May 2020 19:39:45 +0300 Subject: Mod Utils Tests: Replace `has_active_infraction` with `get_active_infraction` --- tests/bot/cogs/moderation/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 4f81a2477..248adbcb8 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -19,21 +19,21 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) - async def test_user_has_active_infraction(self): + async def test_user_get_active_infraction(self): """ - Should request the API for active infractions and return `True` if the user has one or `False` otherwise. + Should request the API for active infractions and return infraction if the user has one or `None` otherwise. A message should be sent to the context indicating a user already has an infraction, if that's the case. """ test_cases = [ { "get_return_value": [], - "expected_output": False, + "expected_output": None, "infraction_nr": None }, { "get_return_value": [{"id": 123987}], - "expected_output": True, + "expected_output": {"id": 123987}, "infraction_nr": "123987" } ] @@ -51,7 +51,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = case["get_return_value"] - result = await utils.has_active_infraction(self.ctx, self.member, "ban") + result = await utils.get_active_infraction(self.ctx, self.member, "ban") self.assertEqual(result, case["expected_output"]) self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=params) -- cgit v1.2.3 From 6c58ecb647b046c6a9a1e2b6d9b4d0e0f326e9bd Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Fri, 12 Jun 2020 16:54:55 +0300 Subject: Remove deprecated avatar hash in `test_post_user` --- tests/bot/cogs/moderation/test_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 248adbcb8..596f077b5 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -203,14 +203,13 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): async def test_post_user(self): """Should POST a new user and return the response if successful or otherwise send an error message.""" - user = MockUser(avatar="abc", discriminator=5678, id=1234, name="Test user") + user = MockUser(discriminator=5678, id=1234, name="Test user") test_cases = [ { "user": user, "post_result": "bar", "raise_error": None, "payload": { - "avatar_hash": "abc", "discriminator": 5678, "id": self.user.id, "in_guild": False, @@ -223,7 +222,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "post_result": "foo", "raise_error": ResponseCodeError(MagicMock(status=400), "foo"), "payload": { - "avatar_hash": 0, "discriminator": 0, "id": self.member.id, "in_guild": False, -- cgit v1.2.3 From 4cc6f759f53ebe31d5025ff902189ab211409d4f Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Fri, 12 Jun 2020 19:30:30 +0300 Subject: Implement description shortening to infraction notify tests --- tests/bot/cogs/moderation/test_utils.py | 35 +++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 596f077b5..363d8938a 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,3 +1,4 @@ +import textwrap import unittest from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch @@ -71,11 +72,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Ban", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." - ), + ), width=2048, placeholder="..."), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -89,11 +90,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "warning", None, "Test reason."), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Warning", expires="N/A", reason="Test reason." - ), + ), width=2048, placeholder="..."), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -107,11 +108,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "note", None, None, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Note", expires="N/A", reason="No reason provided." - ), + ), width=2048, placeholder="..."), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -125,11 +126,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" - ), + ), width=2048, placeholder="..."), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -138,6 +139,24 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): icon_url=Icons.defcon_denied ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), "send_result": False + }, + { + "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied), + "expected_output": Embed( + title=utils.INFRACTION_TITLE, + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Mute", + expires="N/A", + reason="foo bar" * 4000 + ), width=2048, placeholder="..."), + colour=Colours.soft_red, + url=utils.RULES_URL + ).set_author( + name=utils.INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.defcon_denied + ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), + "send_result": True } ] -- cgit v1.2.3 From 0d0f4318dc2c08d87d473ecb2d66a5622d36cf9d Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Fri, 12 Jun 2020 19:53:00 +0300 Subject: Increase coverage of moderation utils tests --- tests/bot/cogs/moderation/test_utils.py | 41 +++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 363d8938a..77f926a48 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -30,12 +30,20 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): { "get_return_value": [], "expected_output": None, - "infraction_nr": None + "infraction_nr": None, + "send_msg": True }, { "get_return_value": [{"id": 123987}], "expected_output": {"id": 123987}, - "infraction_nr": "123987" + "infraction_nr": "123987", + "send_msg": False + }, + { + "get_return_value": [{"id": 123987}], + "expected_output": {"id": 123987}, + "infraction_nr": "123987", + "send_msg": True } ] @@ -52,13 +60,16 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = case["get_return_value"] - result = await utils.get_active_infraction(self.ctx, self.member, "ban") + result = await utils.get_active_infraction(self.ctx, self.member, "ban", send_msg=case["send_msg"]) self.assertEqual(result, case["expected_output"]) self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=params) - if result: + if case["send_msg"] and case["get_return_value"]: + self.ctx.send.assert_awaited_once() self.assertTrue(case["infraction_nr"] in self.ctx.send.call_args[0][0]) self.assertTrue("ban" in self.ctx.send.call_args[0][0]) + else: + self.ctx.send.assert_not_awaited() @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_infraction(self, send_private_embed_mock): @@ -220,9 +231,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(args[0], embed) - async def test_post_user(self): + @patch("bot.cogs.moderation.utils.log") + async def test_post_user(self, log_mock): """Should POST a new user and return the response if successful or otherwise send an error message.""" user = MockUser(discriminator=5678, id=1234, name="Test user") + some_mock = MagicMock(discriminator=3333) test_cases = [ { "user": user, @@ -247,6 +260,18 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "name": "Name unknown", "roles": [] } + }, + { + "user": some_mock, + "post_result": "bar", + "raise_error": None, + "payload": { + "discriminator": some_mock.discriminator, + "id": some_mock.id, + "in_guild": False, + "name": some_mock.name, + "roles": [] + } } ] @@ -257,6 +282,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): payload = case["payload"] with self.subTest(user=test_user, result=expected, error=error, payload=payload): + log_mock.reset_mock() self.bot.api_client.post.reset_mock(side_effect=True) self.ctx.bot.api_client.post.return_value = expected @@ -275,6 +301,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_awaited_once() self.assertTrue(str(error.status) in self.ctx.send.call_args[0][0]) + if isinstance(test_user, MagicMock): + log_mock.debug.assert_called_once() + else: + log_mock.debug.assert_not_called() + async def test_send_private_embed(self): """Should DM the user and return `True` on success or `False` on failure.""" embed = Embed(title="Test", description="Test val") -- cgit v1.2.3 From 5f0490aad5a8d22a5f05dc6debdb3485a0ed9671 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 14:45:51 +0300 Subject: Mod Utils Tests: Move INFRACTION_DESCRIPTION_TEMPLATE to tests file --- bot/cogs/moderation/utils.py | 6 ------ tests/bot/cogs/moderation/test_utils.py | 16 +++++++++++----- 2 files changed, 11 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index 5df282f80..104baf528 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -34,12 +34,6 @@ INFRACTION_TITLE = f"Please review our rules over at {RULES_URL}" INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" INFRACTION_AUTHOR_NAME = "Infraction information" -INFRACTION_DESCRIPTION_TEMPLATE = ( - "\n**Type:** {type}\n" - "**Expires:** {expires}\n" - "**Reason:** {reason}\n" -) - async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: """ diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 77f926a48..dde5b438d 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -10,6 +10,12 @@ from bot.cogs.moderation import utils from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser +INFRACTION_DESCRIPTION_TEMPLATE = ( + "\n**Type:** {type}\n" + "**Expires:** {expires}\n" + "**Reason:** {reason}\n" +) + class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" @@ -83,7 +89,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( type="Ban", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." @@ -101,7 +107,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "warning", None, "Test reason."), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( type="Warning", expires="N/A", reason="Test reason." @@ -119,7 +125,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "note", None, None, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( type="Note", expires="N/A", reason="No reason provided." @@ -137,7 +143,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" @@ -155,7 +161,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="N/A", reason="foo bar" * 4000 -- cgit v1.2.3 From 9a80e9cf2fea30f9760f5fd0a2d2f21ad5c828b4 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:01:05 +0300 Subject: Mod Utils Tests: Move some test cases to `namedtuple` --- tests/bot/cogs/moderation/test_utils.py | 95 ++++++++++----------------------- 1 file changed, 29 insertions(+), 66 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index dde5b438d..e54c0d240 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,5 +1,6 @@ import textwrap import unittest +from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch @@ -32,29 +33,15 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): A message should be sent to the context indicating a user already has an infraction, if that's the case. """ + test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"]) test_cases = [ - { - "get_return_value": [], - "expected_output": None, - "infraction_nr": None, - "send_msg": True - }, - { - "get_return_value": [{"id": 123987}], - "expected_output": {"id": 123987}, - "infraction_nr": "123987", - "send_msg": False - }, - { - "get_return_value": [{"id": 123987}], - "expected_output": {"id": 123987}, - "infraction_nr": "123987", - "send_msg": True - } + test_case([], None, None, True), + test_case([{"id": 123987}], {"id": 123987}, "123987", False), + test_case([{"id": 123987}], {"id": 123987}, "123987", True) ] for case in test_cases: - with self.subTest(return_value=case["get_return_value"], expected=case["expected_output"]): + with self.subTest(return_value=case.get_return_value, expected=case.expected_output): self.bot.api_client.get.reset_mock() self.ctx.send.reset_mock() @@ -64,15 +51,15 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "user__id": str(self.member.id) } - self.bot.api_client.get.return_value = case["get_return_value"] + self.bot.api_client.get.return_value = case.get_return_value - result = await utils.get_active_infraction(self.ctx, self.member, "ban", send_msg=case["send_msg"]) - self.assertEqual(result, case["expected_output"]) + result = await utils.get_active_infraction(self.ctx, self.member, "ban", send_msg=case.send_msg) + self.assertEqual(result, case.expected_output) self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=params) - if case["send_msg"] and case["get_return_value"]: + if case.send_msg and case.get_return_value: self.ctx.send.assert_awaited_once() - self.assertTrue(case["infraction_nr"] in self.ctx.send.call_args[0][0]) + self.assertTrue(case.infraction_nr in self.ctx.send.call_args[0][0]) self.assertTrue("ban" in self.ctx.send.call_args[0][0]) else: self.ctx.send.assert_not_awaited() @@ -199,43 +186,33 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): """Should send an embed of a certain format as a DM and return `True` if DM successful.""" + test_case = namedtuple("test_case", ["args", "icon", "send_result"]) test_cases = [ - { - "args": (self.user, "Test title", "Example content"), - "icon": Icons.user_verified, - "send_result": True - }, - { - "args": (self.user, "Test title", "Example content", Icons.user_update), - "icon": Icons.user_update, - "send_result": False - } + test_case((self.user, "Test title", "Example content"), Icons.user_verified, True), + test_case((self.user, "Test title", "Example content", Icons.user_update), Icons.user_update, False) ] for case in test_cases: - args = case["args"] - send = case["send_result"] - expected = Embed( description="Example content", colour=Colours.soft_green ).set_author( name="Test title", - icon_url=case["icon"] + icon_url=case.icon ) - with self.subTest(args=args, expected=expected): + with self.subTest(args=case.args, expected=expected): send_private_embed_mock.reset_mock() - send_private_embed_mock.return_value = send + send_private_embed_mock.return_value = case.send_result - result = await utils.notify_pardon(*args) - self.assertEqual(send, result) + result = await utils.notify_pardon(*case.args) + self.assertEqual(case.send_result, result) embed = send_private_embed_mock.call_args[0][1] self.assertEqual(embed.to_dict(), expected.to_dict()) - send_private_embed_mock.assert_awaited_once_with(args[0], embed) + send_private_embed_mock.assert_awaited_once_with(case.args[0], embed) @patch("bot.cogs.moderation.utils.log") async def test_post_user(self, log_mock): @@ -316,37 +293,23 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Should DM the user and return `True` on success or `False` on failure.""" embed = Embed(title="Test", description="Test val") + test_case = namedtuple("test_case", ["expected_output", "raised_exception"]) test_cases = [ - { - "expected_output": True, - "raised_exception": None - }, - { - "expected_output": False, - "raised_exception": HTTPException(AsyncMock(), AsyncMock()) - }, - { - "expected_output": False, - "raised_exception": Forbidden(AsyncMock(), AsyncMock()) - }, - { - "expected_output": False, - "raised_exception": NotFound(AsyncMock(), AsyncMock()) - } + test_case(True, None), + test_case(False, HTTPException(AsyncMock(), AsyncMock())), + test_case(False, Forbidden(AsyncMock(), AsyncMock())), + test_case(False, NotFound(AsyncMock(), AsyncMock())) ] for case in test_cases: - expected = case["expected_output"] - raised = case["raised_exception"] - - with self.subTest(expected=expected, raised=raised): + with self.subTest(expected=case.expected_output, raised=case.raised_exception): self.user.send.reset_mock(side_effect=True) - self.user.send.side_effect = raised + self.user.send.side_effect = case.raised_exception result = await utils.send_private_embed(self.user, embed) - self.assertEqual(result, expected) - if expected: + self.assertEqual(result, case.expected_output) + if case.expected_output: self.user.send.assert_awaited_once_with(embed=embed) -- cgit v1.2.3 From 024633a470d86d84189c714d194e750507f47d47 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:03:38 +0300 Subject: Mod Utils Tests: Change `True` assert to `In` assert for message check --- tests/bot/cogs/moderation/test_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index e54c0d240..aaa0861e5 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -59,8 +59,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if case.send_msg and case.get_return_value: self.ctx.send.assert_awaited_once() - self.assertTrue(case.infraction_nr in self.ctx.send.call_args[0][0]) - self.assertTrue("ban" in self.ctx.send.call_args[0][0]) + sent_message = self.ctx.send.call_args[0][0] + self.assertIn(case.infraction_nr, sent_message) + self.assertIn("ban", sent_message) else: self.ctx.send.assert_not_awaited() -- cgit v1.2.3 From c205f6303a6533cee6cb02cf85dba30b43e0630f Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:04:31 +0300 Subject: Mod Utils Tests: Remove unnecessary `user` from test name --- tests/bot/cogs/moderation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index aaa0861e5..a104b969a 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -27,7 +27,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) - async def test_user_get_active_infraction(self): + async def test_get_active_infraction(self): """ Should request the API for active infractions and return infraction if the user has one or `None` otherwise. -- cgit v1.2.3 From 2123cdb2f7f438491093ef0195cedd432466f9b8 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:13:38 +0300 Subject: Remove case variable definitions in `test_notify_infraction` --- tests/bot/cogs/moderation/test_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index a104b969a..7e8e6d9f0 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -166,23 +166,19 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ] for case in test_cases: - args = case["args"] - expected = case["expected_output"] - send = case["send_result"] - - with self.subTest(args=args, expected=expected, send=send): + with self.subTest(args=case["args"], expected=case["expected_output"], send=case["send_result"]): send_private_embed_mock.reset_mock() - send_private_embed_mock.return_value = send - result = await utils.notify_infraction(*args) + send_private_embed_mock.return_value = case["send_result"] + result = await utils.notify_infraction(*case["args"]) - self.assertEqual(send, result) + self.assertEqual(case["send_result"], result) embed = send_private_embed_mock.call_args[0][1] - self.assertEqual(embed.to_dict(), expected.to_dict()) + self.assertEqual(embed.to_dict(), case["expected"].to_dict()) - send_private_embed_mock.assert_awaited_once_with(args[0], embed) + send_private_embed_mock.assert_awaited_once_with(case["args"][0], embed) @patch("bot.cogs.moderation.utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): -- cgit v1.2.3 From efde49c677650599b097955a1606dae0d122c97d Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:17:53 +0300 Subject: Sync keys, variable names and kwargs in `test_post_user` --- tests/bot/cogs/moderation/test_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 7e8e6d9f0..b434737ea 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -256,32 +256,32 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ] for case in test_cases: - test_user = case["user"] - expected = case["post_result"] - error = case["raise_error"] + user = case["user"] + post_result = case["post_result"] + raise_error = case["raise_error"] payload = case["payload"] - with self.subTest(user=test_user, result=expected, error=error, payload=payload): + with self.subTest(user=user, post_result=post_result, raise_error=raise_error, payload=payload): log_mock.reset_mock() self.bot.api_client.post.reset_mock(side_effect=True) - self.ctx.bot.api_client.post.return_value = expected + self.ctx.bot.api_client.post.return_value = post_result - self.ctx.bot.api_client.post.side_effect = error + self.ctx.bot.api_client.post.side_effect = raise_error - result = await utils.post_user(self.ctx, test_user) + result = await utils.post_user(self.ctx, user) - if error: + if raise_error: self.assertIsNone(result) else: - self.assertEqual(result, expected) + self.assertEqual(result, post_result) - if not error: + if not raise_error: self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) else: self.ctx.send.assert_awaited_once() - self.assertTrue(str(error.status) in self.ctx.send.call_args[0][0]) + self.assertTrue(str(raise_error.status) in self.ctx.send.call_args[0][0]) - if isinstance(test_user, MagicMock): + if isinstance(user, MagicMock): log_mock.debug.assert_called_once() else: log_mock.debug.assert_not_called() -- cgit v1.2.3 From 3d4c50c498647a6537eef747e84690f8852d388c Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:18:49 +0300 Subject: Replace `True` test with `In` test on `test_post_user` --- tests/bot/cogs/moderation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index b434737ea..5be703bc6 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -279,7 +279,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) else: self.ctx.send.assert_awaited_once() - self.assertTrue(str(raise_error.status) in self.ctx.send.call_args[0][0]) + self.assertIn(str(raise_error.status), self.ctx.send.call_args[0][0]) if isinstance(user, MagicMock): log_mock.debug.assert_called_once() -- cgit v1.2.3 From 4430e590ece503c262419324e2bc47dbaa5823d2 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:21:00 +0300 Subject: Merge 2 if-else branches is `test_post_user` --- tests/bot/cogs/moderation/test_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 5be703bc6..f4c634936 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -272,14 +272,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): if raise_error: self.assertIsNone(result) + self.ctx.send.assert_awaited_once() + self.assertIn(str(raise_error.status), self.ctx.send.call_args[0][0]) else: self.assertEqual(result, post_result) - - if not raise_error: self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) - else: - self.ctx.send.assert_awaited_once() - self.assertIn(str(raise_error.status), self.ctx.send.call_args[0][0]) if isinstance(user, MagicMock): log_mock.debug.assert_called_once() -- cgit v1.2.3 From 136ebd22a73318620e8a3fa6136d28f5390ddeaf Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:22:18 +0300 Subject: Remove unnecessary `log.debug` assert in `test_post_user` --- tests/bot/cogs/moderation/test_utils.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index f4c634936..e6eac6831 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -211,8 +211,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(case.args[0], embed) - @patch("bot.cogs.moderation.utils.log") - async def test_post_user(self, log_mock): + async def test_post_user(self): """Should POST a new user and return the response if successful or otherwise send an error message.""" user = MockUser(discriminator=5678, id=1234, name="Test user") some_mock = MagicMock(discriminator=3333) @@ -262,7 +261,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): payload = case["payload"] with self.subTest(user=user, post_result=post_result, raise_error=raise_error, payload=payload): - log_mock.reset_mock() self.bot.api_client.post.reset_mock(side_effect=True) self.ctx.bot.api_client.post.return_value = post_result @@ -278,11 +276,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(result, post_result) self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) - if isinstance(user, MagicMock): - log_mock.debug.assert_called_once() - else: - log_mock.debug.assert_not_called() - async def test_send_private_embed(self): """Should DM the user and return `True` on success or `False` on failure.""" embed = Embed(title="Test", description="Test val") -- cgit v1.2.3 From b2a70712ac4aaa067edfbb7a8940cf9b78f44e53 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:28:22 +0300 Subject: Add other parameters to `test_post_user` `not_user` mock --- tests/bot/cogs/moderation/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index e6eac6831..f89f41d25 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -214,7 +214,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): async def test_post_user(self): """Should POST a new user and return the response if successful or otherwise send an error message.""" user = MockUser(discriminator=5678, id=1234, name="Test user") - some_mock = MagicMock(discriminator=3333) + not_user = MagicMock(discriminator=3333, id=5678, name="Wrong user") test_cases = [ { "user": user, @@ -241,14 +241,14 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): } }, { - "user": some_mock, + "user": not_user, "post_result": "bar", "raise_error": None, "payload": { - "discriminator": some_mock.discriminator, - "id": some_mock.id, + "discriminator": not_user.discriminator, + "id": not_user.id, "in_guild": False, - "name": some_mock.name, + "name": not_user.name, "roles": [] } } -- cgit v1.2.3 From 9b1538878221c966b62dc5c9d0be2af1fd475325 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:29:52 +0300 Subject: Fix test case key name in `test_notify_infraction` --- tests/bot/cogs/moderation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index f89f41d25..c4d0d6f16 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -176,7 +176,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): embed = send_private_embed_mock.call_args[0][1] - self.assertEqual(embed.to_dict(), case["expected"].to_dict()) + self.assertEqual(embed.to_dict(), case["expected_output"].to_dict()) send_private_embed_mock.assert_awaited_once_with(case["args"][0], embed) -- cgit v1.2.3 From 7b89d2cfad91cc9a56565ebc7700f4858814f149 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:36:34 +0300 Subject: Move infraction description template back to main file, apply it there --- bot/cogs/moderation/utils.py | 18 +++++++++++++----- tests/bot/cogs/moderation/test_utils.py | 16 +++++----------- 2 files changed, 18 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index 104baf528..cbef3420a 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -34,6 +34,12 @@ INFRACTION_TITLE = f"Please review our rules over at {RULES_URL}" INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEAL_EMAIL}" INFRACTION_AUTHOR_NAME = "Infraction information" +INFRACTION_DESCRIPTION_TEMPLATE = ( + "\n**Type:** {type}\n" + "**Expires:** {expires}\n" + "**Reason:** {reason}\n" +) + async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: """ @@ -148,11 +154,13 @@ async def notify_infraction( """DM a user about their new infraction and return True if the DM is successful.""" log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) + text = textwrap.dedent( + INFRACTION_DESCRIPTION_TEMPLATE.format( + type=infr_type.capitalize(), + expires=expires_at or "N/A", + reason=reason or "No reason provided." + ) + ) embed = discord.Embed( description=textwrap.shorten(text, width=2048, placeholder="..."), diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index c4d0d6f16..c35c0edf5 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -11,12 +11,6 @@ from bot.cogs.moderation import utils from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser -INFRACTION_DESCRIPTION_TEMPLATE = ( - "\n**Type:** {type}\n" - "**Expires:** {expires}\n" - "**Reason:** {reason}\n" -) - class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" @@ -77,7 +71,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Ban", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." @@ -95,7 +89,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "warning", None, "Test reason."), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Warning", expires="N/A", reason="Test reason." @@ -113,7 +107,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "note", None, None, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Note", expires="N/A", reason="No reason provided." @@ -131,7 +125,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" @@ -149,7 +143,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(INFRACTION_DESCRIPTION_TEMPLATE.format( + description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="N/A", reason="foo bar" * 4000 -- cgit v1.2.3 From 0c9fc3a1bbaf590d7ccf8737ffffcfb4b1b5b1b8 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 24 Jun 2020 15:40:46 +0300 Subject: Reorder tests order to match with original file --- tests/bot/cogs/moderation/test_utils.py | 130 ++++++++++++++++---------------- 1 file changed, 65 insertions(+), 65 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index c35c0edf5..0f6f9c469 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -21,6 +21,71 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) + async def test_post_user(self): + """Should POST a new user and return the response if successful or otherwise send an error message.""" + user = MockUser(discriminator=5678, id=1234, name="Test user") + not_user = MagicMock(discriminator=3333, id=5678, name="Wrong user") + test_cases = [ + { + "user": user, + "post_result": "bar", + "raise_error": None, + "payload": { + "discriminator": 5678, + "id": self.user.id, + "in_guild": False, + "name": "Test user", + "roles": [] + } + }, + { + "user": self.member, + "post_result": "foo", + "raise_error": ResponseCodeError(MagicMock(status=400), "foo"), + "payload": { + "discriminator": 0, + "id": self.member.id, + "in_guild": False, + "name": "Name unknown", + "roles": [] + } + }, + { + "user": not_user, + "post_result": "bar", + "raise_error": None, + "payload": { + "discriminator": not_user.discriminator, + "id": not_user.id, + "in_guild": False, + "name": not_user.name, + "roles": [] + } + } + ] + + for case in test_cases: + user = case["user"] + post_result = case["post_result"] + raise_error = case["raise_error"] + payload = case["payload"] + + with self.subTest(user=user, post_result=post_result, raise_error=raise_error, payload=payload): + self.bot.api_client.post.reset_mock(side_effect=True) + self.ctx.bot.api_client.post.return_value = post_result + + self.ctx.bot.api_client.post.side_effect = raise_error + + result = await utils.post_user(self.ctx, user) + + if raise_error: + self.assertIsNone(result) + self.ctx.send.assert_awaited_once() + self.assertIn(str(raise_error.status), self.ctx.send.call_args[0][0]) + else: + self.assertEqual(result, post_result) + self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) + async def test_get_active_infraction(self): """ Should request the API for active infractions and return infraction if the user has one or `None` otherwise. @@ -205,71 +270,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(case.args[0], embed) - async def test_post_user(self): - """Should POST a new user and return the response if successful or otherwise send an error message.""" - user = MockUser(discriminator=5678, id=1234, name="Test user") - not_user = MagicMock(discriminator=3333, id=5678, name="Wrong user") - test_cases = [ - { - "user": user, - "post_result": "bar", - "raise_error": None, - "payload": { - "discriminator": 5678, - "id": self.user.id, - "in_guild": False, - "name": "Test user", - "roles": [] - } - }, - { - "user": self.member, - "post_result": "foo", - "raise_error": ResponseCodeError(MagicMock(status=400), "foo"), - "payload": { - "discriminator": 0, - "id": self.member.id, - "in_guild": False, - "name": "Name unknown", - "roles": [] - } - }, - { - "user": not_user, - "post_result": "bar", - "raise_error": None, - "payload": { - "discriminator": not_user.discriminator, - "id": not_user.id, - "in_guild": False, - "name": not_user.name, - "roles": [] - } - } - ] - - for case in test_cases: - user = case["user"] - post_result = case["post_result"] - raise_error = case["raise_error"] - payload = case["payload"] - - with self.subTest(user=user, post_result=post_result, raise_error=raise_error, payload=payload): - self.bot.api_client.post.reset_mock(side_effect=True) - self.ctx.bot.api_client.post.return_value = post_result - - self.ctx.bot.api_client.post.side_effect = raise_error - - result = await utils.post_user(self.ctx, user) - - if raise_error: - self.assertIsNone(result) - self.ctx.send.assert_awaited_once() - self.assertIn(str(raise_error.status), self.ctx.send.call_args[0][0]) - else: - self.assertEqual(result, post_result) - self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) - async def test_send_private_embed(self): """Should DM the user and return `True` on success or `False` on failure.""" embed = Embed(title="Test", description="Test val") -- cgit v1.2.3 From 1a812b7c3ef7048d8058c8c5a7d5e3afd0f86317 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 25 Jun 2020 11:48:02 +0300 Subject: Remove unnecessary if statement from send_private_embed test --- tests/bot/cogs/moderation/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 0f6f9c469..029719669 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -290,8 +290,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): result = await utils.send_private_embed(self.user, embed) self.assertEqual(result, case.expected_output) - if case.expected_output: - self.user.send.assert_awaited_once_with(embed=embed) + self.user.send.assert_awaited_once_with(embed=embed) class TestPostInfraction(unittest.IsolatedAsyncioTestCase): -- cgit v1.2.3 From 604c6a7a09d7826870fb384b98e0a6d1463721b4 Mon Sep 17 00:00:00 2001 From: Karlis S Date: Mon, 6 Jul 2020 14:24:55 +0000 Subject: Restore newlines for `notify_infraction` embed description Truncate reason instead full content to avoid removing newlines --- bot/cogs/moderation/utils.py | 6 +++--- tests/bot/cogs/moderation/test_utils.py | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index 8b36210be..95820404a 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -35,7 +35,7 @@ INFRACTION_APPEAL_FOOTER = f"To appeal this infraction, send an e-mail to {APPEA INFRACTION_AUTHOR_NAME = "Infraction information" INFRACTION_DESCRIPTION_TEMPLATE = ( - "\n**Type:** {type}\n" + "**Type:** {type}\n" "**Expires:** {expires}\n" "**Reason:** {reason}\n" ) @@ -157,11 +157,11 @@ async def notify_infraction( text = INFRACTION_DESCRIPTION_TEMPLATE.format( type=infr_type.capitalize(), expires=expires_at or "N/A", - reason=reason or "No reason provided." + reason=textwrap.shorten(reason, 1000, placeholder="...") if reason else "No reason provided." ) embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), + description=text, colour=Colours.soft_red ) diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 029719669..c9a4e4040 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -136,11 +136,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Ban", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="No reason provided." - ), width=2048, placeholder="..."), + ), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -154,11 +154,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "warning", None, "Test reason."), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Warning", expires="N/A", reason="Test reason." - ), width=2048, placeholder="..."), + ), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -172,11 +172,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "note", None, None, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Note", expires="N/A", reason="No reason provided." - ), width=2048, placeholder="..."), + ), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -190,11 +190,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="2020-02-26 09:20 (23 hours and 59 minutes)", reason="Test" - ), width=2048, placeholder="..."), + ), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( @@ -208,11 +208,11 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, - description=textwrap.shorten(utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="N/A", - reason="foo bar" * 4000 - ), width=2048, placeholder="..."), + reason=textwrap.shorten("foo bar" * 4000, 1000, placeholder="...") + ), colour=Colours.soft_red, url=utils.RULES_URL ).set_author( -- cgit v1.2.3 From c115dcfb72e4d4a86b66bb84a72984705a2afcd4 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 15 Jul 2020 02:45:31 +0200 Subject: Change tests to work with the new file layout. 326beebe9b097731a39ecc9868e5e1f2bd762aae --- tests/bot/utils/test_init.py | 74 ---------------------------------------- tests/bot/utils/test_services.py | 74 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 74 deletions(-) delete mode 100644 tests/bot/utils/test_init.py create mode 100644 tests/bot/utils/test_services.py (limited to 'tests') diff --git a/tests/bot/utils/test_init.py b/tests/bot/utils/test_init.py deleted file mode 100644 index f3a8f5939..000000000 --- a/tests/bot/utils/test_init.py +++ /dev/null @@ -1,74 +0,0 @@ -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -from aiohttp import ClientConnectorError - -from bot.utils import FAILED_REQUEST_ATTEMPTS, send_to_paste_service - - -class PasteTests(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - self.http_session = MagicMock() - - @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") - async def test_url_and_sent_contents(self): - """Correct url was used and post was called with expected data.""" - response = MagicMock( - json=AsyncMock(return_value={"key": ""}) - ) - self.http_session.post().__aenter__.return_value = response - self.http_session.post.reset_mock() - await send_to_paste_service(self.http_session, "Content") - self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") - - @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") - async def test_paste_returns_correct_url_on_success(self): - """Url with specified extension is returned on successful requests.""" - key = "paste_key" - test_cases = ( - (f"https://paste_service.com/{key}.txt", "txt"), - (f"https://paste_service.com/{key}.py", "py"), - (f"https://paste_service.com/{key}", ""), - ) - response = MagicMock( - json=AsyncMock(return_value={"key": key}) - ) - self.http_session.post().__aenter__.return_value = response - - for expected_output, extension in test_cases: - with self.subTest(msg=f"Send contents with extension {repr(extension)}"): - self.assertEqual( - await send_to_paste_service(self.http_session, "", extension=extension), - expected_output - ) - - async def test_request_repeated_on_json_errors(self): - """Json with error message and invalid json are handled as errors and requests repeated.""" - test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) - self.http_session.post().__aenter__.return_value = response = MagicMock() - self.http_session.post.reset_mock() - - for error_json in test_cases: - with self.subTest(error_json=error_json): - response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) - - self.http_session.post.reset_mock() - - async def test_request_repeated_on_connection_errors(self): - """Requests are repeated in the case of connection errors.""" - self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) - - async def test_general_error_handled_and_request_repeated(self): - """All `Exception`s are handled, logged and request repeated.""" - self.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertLogs("bot.utils", logging.ERROR) - self.assertIsNone(result) diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py new file mode 100644 index 000000000..5e0855704 --- /dev/null +++ b/tests/bot/utils/test_services.py @@ -0,0 +1,74 @@ +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +from aiohttp import ClientConnectorError + +from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service + + +class PasteTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.http_session = MagicMock() + + @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") + async def test_url_and_sent_contents(self): + """Correct url was used and post was called with expected data.""" + response = MagicMock( + json=AsyncMock(return_value={"key": ""}) + ) + self.http_session.post().__aenter__.return_value = response + self.http_session.post.reset_mock() + await send_to_paste_service(self.http_session, "Content") + self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + + @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") + async def test_paste_returns_correct_url_on_success(self): + """Url with specified extension is returned on successful requests.""" + key = "paste_key" + test_cases = ( + (f"https://paste_service.com/{key}.txt", "txt"), + (f"https://paste_service.com/{key}.py", "py"), + (f"https://paste_service.com/{key}", ""), + ) + response = MagicMock( + json=AsyncMock(return_value={"key": key}) + ) + self.http_session.post().__aenter__.return_value = response + + for expected_output, extension in test_cases: + with self.subTest(msg=f"Send contents with extension {repr(extension)}"): + self.assertEqual( + await send_to_paste_service(self.http_session, "", extension=extension), + expected_output + ) + + async def test_request_repeated_on_json_errors(self): + """Json with error message and invalid json are handled as errors and requests repeated.""" + test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) + self.http_session.post().__aenter__.return_value = response = MagicMock() + self.http_session.post.reset_mock() + + for error_json in test_cases: + with self.subTest(error_json=error_json): + response.json = AsyncMock(return_value=error_json) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + self.http_session.post.reset_mock() + + async def test_request_repeated_on_connection_errors(self): + """Requests are repeated in the case of connection errors.""" + self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + async def test_general_error_handled_and_request_repeated(self): + """All `Exception`s are handled, logged and request repeated.""" + self.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertLogs("bot.utils", logging.ERROR) + self.assertIsNone(result) -- cgit v1.2.3 From 0fca2445e2979d6e4bebf6a974c974a5ddd14fbe Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 14 Jun 2020 17:29:43 -0700 Subject: Move extensions into sub-directories --- bot/cogs/alias.py | 2 +- bot/cogs/antimalware.py | 98 ---- bot/cogs/antispam.py | 288 ----------- bot/cogs/backend/__init__.py | 0 bot/cogs/backend/config_verifier.py | 40 ++ bot/cogs/backend/error_handler.py | 287 +++++++++++ bot/cogs/backend/logging.py | 42 ++ bot/cogs/backend/sync/__init__.py | 7 + bot/cogs/backend/sync/cog.py | 180 +++++++ bot/cogs/backend/sync/syncers.py | 347 +++++++++++++ bot/cogs/bot.py | 385 --------------- bot/cogs/clean.py | 272 ---------- bot/cogs/config_verifier.py | 40 -- bot/cogs/defcon.py | 258 ---------- bot/cogs/doc.py | 511 ------------------- bot/cogs/error_handler.py | 287 ----------- bot/cogs/eval.py | 202 -------- bot/cogs/extensions.py | 236 --------- bot/cogs/filter_lists.py | 273 ---------- bot/cogs/filtering.py | 575 ---------------------- bot/cogs/filters/__init__.py | 0 bot/cogs/filters/antimalware.py | 98 ++++ bot/cogs/filters/antispam.py | 288 +++++++++++ bot/cogs/filters/filter_lists.py | 273 ++++++++++ bot/cogs/filters/filtering.py | 575 ++++++++++++++++++++++ bot/cogs/filters/security.py | 31 ++ bot/cogs/filters/token_remover.py | 182 +++++++ bot/cogs/filters/webhook_remover.py | 84 ++++ bot/cogs/help.py | 375 -------------- bot/cogs/info/__init__.py | 0 bot/cogs/info/doc.py | 511 +++++++++++++++++++ bot/cogs/info/help.py | 375 ++++++++++++++ bot/cogs/info/information.py | 422 ++++++++++++++++ bot/cogs/info/python_news.py | 232 +++++++++ bot/cogs/info/reddit.py | 304 ++++++++++++ bot/cogs/info/site.py | 146 ++++++ bot/cogs/info/source.py | 141 ++++++ bot/cogs/info/stats.py | 129 +++++ bot/cogs/info/tags.py | 277 +++++++++++ bot/cogs/info/wolfram.py | 280 +++++++++++ bot/cogs/information.py | 422 ---------------- bot/cogs/jams.py | 150 ------ bot/cogs/logging.py | 42 -- bot/cogs/moderation/__init__.py | 6 +- bot/cogs/moderation/defcon.py | 258 ++++++++++ bot/cogs/moderation/infraction/__init__.py | 0 bot/cogs/moderation/infraction/infractions.py | 370 ++++++++++++++ bot/cogs/moderation/infraction/management.py | 305 ++++++++++++ bot/cogs/moderation/infraction/scheduler.py | 463 +++++++++++++++++ bot/cogs/moderation/infraction/superstarify.py | 239 +++++++++ bot/cogs/moderation/infraction/utils.py | 201 ++++++++ bot/cogs/moderation/infractions.py | 370 -------------- bot/cogs/moderation/management.py | 305 ------------ bot/cogs/moderation/scheduler.py | 463 ----------------- bot/cogs/moderation/superstarify.py | 239 --------- bot/cogs/moderation/utils.py | 201 -------- bot/cogs/moderation/verification.py | 191 +++++++ bot/cogs/moderation/watchchannels/__init__.py | 9 + bot/cogs/moderation/watchchannels/bigbrother.py | 165 +++++++ bot/cogs/moderation/watchchannels/talentpool.py | 264 ++++++++++ bot/cogs/moderation/watchchannels/watchchannel.py | 348 +++++++++++++ bot/cogs/python_news.py | 232 --------- bot/cogs/reddit.py | 304 ------------ bot/cogs/reminders.py | 427 ---------------- bot/cogs/security.py | 31 -- bot/cogs/site.py | 146 ------ bot/cogs/snekbox.py | 349 ------------- bot/cogs/source.py | 141 ------ bot/cogs/stats.py | 129 ----- bot/cogs/sync/__init__.py | 7 - bot/cogs/sync/cog.py | 180 ------- bot/cogs/sync/syncers.py | 347 ------------- bot/cogs/tags.py | 277 ----------- bot/cogs/token_remover.py | 182 ------- bot/cogs/utils.py | 265 ---------- bot/cogs/utils/__init__.py | 0 bot/cogs/utils/bot.py | 385 +++++++++++++++ bot/cogs/utils/clean.py | 272 ++++++++++ bot/cogs/utils/eval.py | 202 ++++++++ bot/cogs/utils/extensions.py | 236 +++++++++ bot/cogs/utils/jams.py | 150 ++++++ bot/cogs/utils/reminders.py | 427 ++++++++++++++++ bot/cogs/utils/snekbox.py | 349 +++++++++++++ bot/cogs/utils/utils.py | 265 ++++++++++ bot/cogs/verification.py | 191 ------- bot/cogs/watchchannels/__init__.py | 9 - bot/cogs/watchchannels/bigbrother.py | 165 ------- bot/cogs/watchchannels/talentpool.py | 264 ---------- bot/cogs/watchchannels/watchchannel.py | 348 ------------- bot/cogs/webhook_remover.py | 84 ---- bot/cogs/wolfram.py | 280 ----------- tests/bot/cogs/moderation/test_infractions.py | 2 +- tests/bot/cogs/sync/test_base.py | 2 +- tests/bot/cogs/sync/test_cog.py | 4 +- tests/bot/cogs/sync/test_roles.py | 2 +- tests/bot/cogs/sync/test_users.py | 2 +- tests/bot/cogs/test_antimalware.py | 2 +- tests/bot/cogs/test_antispam.py | 2 +- tests/bot/cogs/test_information.py | 2 +- tests/bot/cogs/test_security.py | 2 +- tests/bot/cogs/test_snekbox.py | 4 +- tests/bot/cogs/test_token_remover.py | 4 +- 102 files changed, 10368 insertions(+), 10368 deletions(-) delete mode 100644 bot/cogs/antimalware.py delete mode 100644 bot/cogs/antispam.py create mode 100644 bot/cogs/backend/__init__.py create mode 100644 bot/cogs/backend/config_verifier.py create mode 100644 bot/cogs/backend/error_handler.py create mode 100644 bot/cogs/backend/logging.py create mode 100644 bot/cogs/backend/sync/__init__.py create mode 100644 bot/cogs/backend/sync/cog.py create mode 100644 bot/cogs/backend/sync/syncers.py delete mode 100644 bot/cogs/bot.py delete mode 100644 bot/cogs/clean.py delete mode 100644 bot/cogs/config_verifier.py delete mode 100644 bot/cogs/defcon.py delete mode 100644 bot/cogs/doc.py delete mode 100644 bot/cogs/error_handler.py delete mode 100644 bot/cogs/eval.py delete mode 100644 bot/cogs/extensions.py delete mode 100644 bot/cogs/filter_lists.py delete mode 100644 bot/cogs/filtering.py create mode 100644 bot/cogs/filters/__init__.py create mode 100644 bot/cogs/filters/antimalware.py create mode 100644 bot/cogs/filters/antispam.py create mode 100644 bot/cogs/filters/filter_lists.py create mode 100644 bot/cogs/filters/filtering.py create mode 100644 bot/cogs/filters/security.py create mode 100644 bot/cogs/filters/token_remover.py create mode 100644 bot/cogs/filters/webhook_remover.py delete mode 100644 bot/cogs/help.py create mode 100644 bot/cogs/info/__init__.py create mode 100644 bot/cogs/info/doc.py create mode 100644 bot/cogs/info/help.py create mode 100644 bot/cogs/info/information.py create mode 100644 bot/cogs/info/python_news.py create mode 100644 bot/cogs/info/reddit.py create mode 100644 bot/cogs/info/site.py create mode 100644 bot/cogs/info/source.py create mode 100644 bot/cogs/info/stats.py create mode 100644 bot/cogs/info/tags.py create mode 100644 bot/cogs/info/wolfram.py delete mode 100644 bot/cogs/information.py delete mode 100644 bot/cogs/jams.py delete mode 100644 bot/cogs/logging.py create mode 100644 bot/cogs/moderation/defcon.py create mode 100644 bot/cogs/moderation/infraction/__init__.py create mode 100644 bot/cogs/moderation/infraction/infractions.py create mode 100644 bot/cogs/moderation/infraction/management.py create mode 100644 bot/cogs/moderation/infraction/scheduler.py create mode 100644 bot/cogs/moderation/infraction/superstarify.py create mode 100644 bot/cogs/moderation/infraction/utils.py delete mode 100644 bot/cogs/moderation/infractions.py delete mode 100644 bot/cogs/moderation/management.py delete mode 100644 bot/cogs/moderation/scheduler.py delete mode 100644 bot/cogs/moderation/superstarify.py delete mode 100644 bot/cogs/moderation/utils.py create mode 100644 bot/cogs/moderation/verification.py create mode 100644 bot/cogs/moderation/watchchannels/__init__.py create mode 100644 bot/cogs/moderation/watchchannels/bigbrother.py create mode 100644 bot/cogs/moderation/watchchannels/talentpool.py create mode 100644 bot/cogs/moderation/watchchannels/watchchannel.py delete mode 100644 bot/cogs/python_news.py delete mode 100644 bot/cogs/reddit.py delete mode 100644 bot/cogs/reminders.py delete mode 100644 bot/cogs/security.py delete mode 100644 bot/cogs/site.py delete mode 100644 bot/cogs/snekbox.py delete mode 100644 bot/cogs/source.py delete mode 100644 bot/cogs/stats.py delete mode 100644 bot/cogs/sync/__init__.py delete mode 100644 bot/cogs/sync/cog.py delete mode 100644 bot/cogs/sync/syncers.py delete mode 100644 bot/cogs/tags.py delete mode 100644 bot/cogs/token_remover.py delete mode 100644 bot/cogs/utils.py create mode 100644 bot/cogs/utils/__init__.py create mode 100644 bot/cogs/utils/bot.py create mode 100644 bot/cogs/utils/clean.py create mode 100644 bot/cogs/utils/eval.py create mode 100644 bot/cogs/utils/extensions.py create mode 100644 bot/cogs/utils/jams.py create mode 100644 bot/cogs/utils/reminders.py create mode 100644 bot/cogs/utils/snekbox.py create mode 100644 bot/cogs/utils/utils.py delete mode 100644 bot/cogs/verification.py delete mode 100644 bot/cogs/watchchannels/__init__.py delete mode 100644 bot/cogs/watchchannels/bigbrother.py delete mode 100644 bot/cogs/watchchannels/talentpool.py delete mode 100644 bot/cogs/watchchannels/watchchannel.py delete mode 100644 bot/cogs/webhook_remover.py delete mode 100644 bot/cogs/wolfram.py (limited to 'tests') diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 55c7efe65..3c5a35c24 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -8,7 +8,7 @@ from discord.ext.commands import ( ) from bot.bot import Bot -from bot.cogs.extensions import Extension +from bot.cogs.utils.extensions import Extension from bot.converters import FetchedMember, TagNameConverter from bot.pagination import LinePaginator diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py deleted file mode 100644 index c76bd2c60..000000000 --- a/bot/cogs/antimalware.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -import typing as t -from os.path import splitext - -from discord import Embed, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, STAFF_ROLES, URLs - -log = logging.getLogger(__name__) - -PY_EMBED_DESCRIPTION = ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" -) - -TXT_EMBED_DESCRIPTION = ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - "{cmd_channel_mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" -) - -DISALLOWED_EMBED_DESCRIPTION = ( - "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - "We currently allow the following file types: **{joined_whitelist}**.\n\n" - "Feel free to ask in {meta_channel_mention} if you think this is a mistake." -) - - -class AntiMalware(Cog): - """Delete messages which contain attachments with non-whitelisted file extensions.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_whitelisted_file_formats(self) -> list: - """Get the file formats currently on the whitelist.""" - return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() - - def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: - """Get an iterable containing all the disallowed extensions of attachments.""" - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) - return extensions_blocked - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Identify messages with prohibited attachments.""" - # Return when message don't have attachment and don't moderate DMs - if not message.attachments or not message.guild: - return - - # Check if user is staff, if is, return - # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance - if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): - return - - embed = Embed() - extensions_blocked = self._get_disallowed_extensions(message) - blocked_extensions_str = ', '.join(extensions_blocked) - if ".py" in extensions_blocked: - # Short-circuit on *.py files to provide a pastebin link - embed.description = PY_EMBED_DESCRIPTION - elif ".txt" in extensions_blocked: - # Work around Discord AutoConversion of messages longer than 2000 chars to .txt - cmd_channel = self.bot.get_channel(Channels.bot_commands) - embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) - elif extensions_blocked: - meta_channel = self.bot.get_channel(Channels.meta) - embed.description = DISALLOWED_EMBED_DESCRIPTION.format( - joined_whitelist=', '.join(self._get_whitelisted_file_formats()), - blocked_extensions_str=blocked_extensions_str, - meta_channel_mention=meta_channel.mention, - ) - - if embed.description: - log.info( - f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", - extra={"attachment_list": [attachment.filename for attachment in message.attachments]} - ) - - await message.channel.send(f"Hey {message.author.mention}!", embed=embed) - - # Delete the offending message: - try: - await message.delete() - except NotFound: - log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - - -def setup(bot: Bot) -> None: - """Load the AntiMalware cog.""" - bot.add_cog(AntiMalware(bot)) diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py deleted file mode 100644 index 0bcca578d..000000000 --- a/bot/cogs/antispam.py +++ /dev/null @@ -1,288 +0,0 @@ -import asyncio -import logging -from collections.abc import Mapping -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from operator import itemgetter -from typing import Dict, Iterable, List, Set - -from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Cog - -from bot import rules -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import ( - AntiSpam as AntiSpamConfig, Channels, - Colours, DEBUG_MODE, Event, Filter, - Guild as GuildConfig, Icons, - STAFF_ROLES, -) -from bot.converters import Duration -from bot.utils.messages import send_attachments - - -log = logging.getLogger(__name__) - -RULE_FUNCTION_MAPPING = { - 'attachments': rules.apply_attachments, - 'burst': rules.apply_burst, - 'burst_shared': rules.apply_burst_shared, - 'chars': rules.apply_chars, - 'discord_emojis': rules.apply_discord_emojis, - 'duplicates': rules.apply_duplicates, - 'links': rules.apply_links, - 'mentions': rules.apply_mentions, - 'newlines': rules.apply_newlines, - 'role_mentions': rules.apply_role_mentions -} - - -@dataclass -class DeletionContext: - """Represents a Deletion Context for a single spam event.""" - - channel: TextChannel - members: Dict[int, Member] = field(default_factory=dict) - rules: Set[str] = field(default_factory=set) - messages: Dict[int, Message] = field(default_factory=dict) - attachments: List[List[str]] = field(default_factory=list) - - async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: - """Adds new rule violation events to the deletion context.""" - self.rules.add(rule_name) - - for member in members: - if member.id not in self.members: - self.members[member.id] = member - - for message in messages: - if message.id not in self.messages: - self.messages[message.id] = message - - # Re-upload attachments - destination = message.guild.get_channel(Channels.attachment_log) - urls = await send_attachments(message, destination, link_large=False) - self.attachments.append(urls) - - async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: - """Method that takes care of uploading the queue and posting modlog alert.""" - triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) - - mod_alert_message = ( - f"**Triggered by:** {triggered_by_users}\n" - f"**Channel:** {self.channel.mention}\n" - f"**Rules:** {', '.join(rule for rule in self.rules)}\n" - ) - - # For multiple messages or those with excessive newlines, use the logs API - if len(self.messages) > 1 or 'newlines' in self.rules: - url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) - mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" - else: - mod_alert_message += "Message:\n" - [message] = self.messages.values() - content = message.clean_content - remaining_chars = 2040 - len(mod_alert_message) - - if len(content) > remaining_chars: - content = content[:remaining_chars] + "..." - - mod_alert_message += f"{content}" - - *_, last_message = self.messages.values() - await modlog.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title="Spam detected!", - text=mod_alert_message, - thumbnail=last_message.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=AntiSpamConfig.ping_everyone - ) - - -class AntiSpam(Cog): - """Cog that controls our anti-spam measures.""" - - def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: - self.bot = bot - self.validation_errors = validation_errors - role_id = AntiSpamConfig.punishment['role_id'] - self.muted_role = Object(role_id) - self.expiration_date_converter = Duration() - - self.message_deletion_queue = dict() - - self.bot.loop.create_task(self.alert_on_validation_error()) - - @property - def mod_log(self) -> ModLog: - """Allows for easy access of the ModLog cog.""" - return self.bot.get_cog("ModLog") - - async def alert_on_validation_error(self) -> None: - """Unloads the cog and alerts admins if configuration validation failed.""" - await self.bot.wait_until_guild_available() - if self.validation_errors: - body = "**The following errors were encountered:**\n" - body += "\n".join(f"- {error}" for error in self.validation_errors.values()) - body += "\n\n**The cog has been unloaded.**" - - await self.mod_log.send_log_message( - title="Error: AntiSpam configuration validation failed!", - text=body, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Colour.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Applies the antispam rules to each received message.""" - if ( - not message.guild - or message.guild.id != GuildConfig.id - or message.author.bot - or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) - or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) - ): - return - - # Fetch the rule configuration with the highest rule interval. - max_interval_config = max( - AntiSpamConfig.rules.values(), - key=itemgetter('interval') - ) - max_interval = max_interval_config['interval'] - - # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. - earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] - - for rule_name in AntiSpamConfig.rules: - rule_config = AntiSpamConfig.rules[rule_name] - rule_function = RULE_FUNCTION_MAPPING[rule_name] - - # Create a list of messages that were sent in the interval that the rule cares about. - latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] - result = await rule_function(message, messages_for_rule, rule_config) - - # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` - # which contains the reason for why the message violated the rule and - # an iterable of all members that violated the rule. - if result is not None: - self.bot.stats.incr(f"mod_alerts.{rule_name}") - reason, members, relevant_messages = result - full_reason = f"`{rule_name}` rule: {reason}" - - # If there's no spam event going on for this channel, start a new Message Deletion Context - channel = message.channel - if channel.id not in self.message_deletion_queue: - log.trace(f"Creating queue for channel `{channel.id}`") - self.message_deletion_queue[message.channel.id] = DeletionContext(channel) - self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) - - # Add the relevant of this trigger to the Deletion Context - await self.message_deletion_queue[message.channel.id].add( - rule_name=rule_name, - members=members, - messages=relevant_messages - ) - - for member in members: - - # Fire it off as a background task to ensure - # that the sleep doesn't block further tasks - self.bot.loop.create_task( - self.punish(message, member, full_reason) - ) - - await self.maybe_delete_messages(channel, relevant_messages) - break - - async def punish(self, msg: Message, member: Member, reason: str) -> None: - """Punishes the given member for triggering an antispam rule.""" - if not any(role.id == self.muted_role.id for role in member.roles): - remove_role_after = AntiSpamConfig.punishment['remove_after'] - - # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes - context = await self.bot.get_context(msg) - context.author = self.bot.user - context.message.author = self.bot.user - - # Since we're going to invoke the tempmute command directly, we need to manually call the converter. - dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") - await context.invoke( - self.bot.get_command('tempmute'), - member, - dt_remove_role_after, - reason=reason - ) - - async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: - """Cleans the messages if cleaning is configured.""" - if AntiSpamConfig.clean_offending: - # If we have more than one message, we can use bulk delete. - if len(messages) > 1: - message_ids = [message.id for message in messages] - self.mod_log.ignore(Event.message_delete, *message_ids) - await channel.delete_messages(messages) - - # Otherwise, the bulk delete endpoint will throw up. - # Delete the message directly instead. - else: - self.mod_log.ignore(Event.message_delete, messages[0].id) - try: - await messages[0].delete() - except NotFound: - log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - - async def _process_deletion_context(self, context_id: int) -> None: - """Processes the Deletion Context queue.""" - log.trace("Sleeping before processing message deletion queue.") - await asyncio.sleep(10) - - if context_id not in self.message_deletion_queue: - log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") - return - - deletion_context = self.message_deletion_queue.pop(context_id) - await deletion_context.upload_messages(self.bot.user.id, self.mod_log) - - -def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: - """Validates the antispam configs.""" - validation_errors = {} - for name, config in rules_.items(): - if name not in RULE_FUNCTION_MAPPING: - log.error( - f"Unrecognized antispam rule `{name}`. " - f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" - ) - validation_errors[name] = f"`{name}` is not recognized as an antispam rule." - continue - for required_key in ('interval', 'max'): - if required_key not in config: - log.error( - f"`{required_key}` is required but was not " - f"set in rule `{name}`'s configuration." - ) - validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" - return validation_errors - - -def setup(bot: Bot) -> None: - """Validate the AntiSpam configs and load the AntiSpam cog.""" - validation_errors = validate_config() - bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/cogs/backend/__init__.py b/bot/cogs/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/backend/config_verifier.py b/bot/cogs/backend/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/cogs/backend/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/backend/error_handler.py b/bot/cogs/backend/error_handler.py new file mode 100644 index 000000000..f9d4de638 --- /dev/null +++ b/bot/cogs/backend/error_handler.py @@ -0,0 +1,287 @@ +import contextlib +import logging +import typing as t + +from discord import Embed +from discord.ext.commands import Cog, Context, errors +from sentry_sdk import push_scope + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Colours +from bot.converters import TagNameConverter +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + + +class ErrorHandler(Cog): + """Handles errors emitted from commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_error_embed(self, title: str, body: str) -> Embed: + """Return an embed that contains the exception.""" + return Embed( + title=title, + colour=Colours.soft_red, + description=body + ) + + @Cog.listener() + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: + """ + Provide generic command error handling. + + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command: + * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. + Otherwise if it matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` + """ + command = ctx.command + + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if await self.try_silence(ctx): + return + if ctx.channel.id != Channels.verification: + # Try to look for a tag with the command's name + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + @staticmethod + def get_help_command(ctx: Context) -> t.Coroutine: + """Return a prepared `help` command invocation coroutine.""" + if ctx.command: + return ctx.send_help(ctx.command) + + return ctx.send_help() + + async def try_silence(self, ctx: Context) -> bool: + """ + Attempt to invoke the silence or unsilence command if invoke with matches a pattern. + + Respecting the checks if: + * invoked with `shh+` silence channel for amount of h's*2 with max of 15. + * invoked with `unshh+` unsilence channel + Return bool depending on success of command. + """ + command = ctx.invoked_with.lower() + silence_command = self.bot.get_command("silence") + ctx.invoked_from_error_handler = True + try: + if not await silence_command.can_run(ctx): + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + except errors.CommandError: + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + if command.startswith("shh"): + await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) + return True + elif command.startswith("unshh"): + await ctx.invoke(self.bot.get_command("unsilence")) + return True + return False + + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + + try: + tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) + except errors.BadArgument: + log.debug( + f"{ctx.author} tried to use an invalid command " + f"and the fallback tag failed validation in TagNameConverter." + ) + else: + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=tag_name) + # Return to not raise the exception + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + prepared_help_command = self.get_help_command(ctx) + + if isinstance(e, errors.MissingRequiredArgument): + embed = self._get_error_embed("Missing required argument", e.param.name) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.missing_required_argument") + elif isinstance(e, errors.TooManyArguments): + embed = self._get_error_embed("Too many arguments", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.too_many_arguments") + elif isinstance(e, errors.BadArgument): + embed = self._get_error_embed("Bad argument", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.bad_argument") + elif isinstance(e, errors.BadUnionArgument): + embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") + await ctx.send(embed=embed) + self.bot.stats.incr("errors.bad_union_argument") + elif isinstance(e, errors.ArgumentParsingError): + embed = self._get_error_embed("Argument parsing error", str(e)) + await ctx.send(embed=embed) + self.bot.stats.incr("errors.argument_parsing_error") + else: + embed = self._get_error_embed( + "Input error", + "Something about your input seems off. Check the arguments and try again." + ) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.other_user_input_error") + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InWhitelistCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + ctx.bot.stats.incr("errors.bot_permission_error") + await ctx.send( + "Sorry, it looks like I don't have the permissions or roles I need to do that." + ) + elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): + ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") + await ctx.send(e) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + ctx.bot.stats.incr("errors.api_error_404") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + ctx.bot.stats.incr("errors.api_error_400") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") + ctx.bot.stats.incr("errors.api_internal_server_error") + else: + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") + ctx.bot.stats.incr(f"errors.api_error_{e.status}") + + @staticmethod + async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n" + f"```{e.__class__.__name__}: {e}```" + ) + + ctx.bot.stats.incr("errors.unexpected") + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) + + +def setup(bot: Bot) -> None: + """Load the ErrorHandler cog.""" + bot.add_cog(ErrorHandler(bot)) diff --git a/bot/cogs/backend/logging.py b/bot/cogs/backend/logging.py new file mode 100644 index 000000000..94fa2b139 --- /dev/null +++ b/bot/cogs/backend/logging.py @@ -0,0 +1,42 @@ +import logging + +from discord import Embed +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, DEBUG_MODE + + +log = logging.getLogger(__name__) + + +class Logging(Cog): + """Debug logging module.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.bot.loop.create_task(self.startup_greeting()) + + async def startup_greeting(self) -> None: + """Announce our presence to the configured devlog channel.""" + await self.bot.wait_until_guild_available() + log.info("Bot connected!") + + embed = Embed(description="Connected!") + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=( + "https://raw.githubusercontent.com/" + "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" + ) + ) + + if not DEBUG_MODE: + await self.bot.get_channel(Channels.dev_log).send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the Logging cog.""" + bot.add_cog(Logging(bot)) diff --git a/bot/cogs/backend/sync/__init__.py b/bot/cogs/backend/sync/__init__.py new file mode 100644 index 000000000..fe7df4e9b --- /dev/null +++ b/bot/cogs/backend/sync/__init__.py @@ -0,0 +1,7 @@ +from bot.bot import Bot +from .cog import Sync + + +def setup(bot: Bot) -> None: + """Load the Sync cog.""" + bot.add_cog(Sync(bot)) diff --git a/bot/cogs/backend/sync/cog.py b/bot/cogs/backend/sync/cog.py new file mode 100644 index 000000000..274845a50 --- /dev/null +++ b/bot/cogs/backend/sync/cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from . import syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = syncers.RoleSyncer(self.bot) + self.user_syncer = syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/syncers.py b/bot/cogs/backend/sync/syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/cogs/backend/sync/syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py deleted file mode 100644 index 79510739c..000000000 --- a/bot/cogs/bot.py +++ /dev/null @@ -1,385 +0,0 @@ -import ast -import logging -import re -import time -from typing import Optional, Tuple - -from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Cog, Context, command, group - -from bot.bot import Bot -from bot.cogs.token_remover import TokenRemover -from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.decorators import with_role -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -RE_MARKDOWN = re.compile(r'([*_~`|>])') - - -class BotCog(Cog, name="Bot"): - """Bot information commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - # Stores allowed channels plus epoch time since last call. - self.channel_cooldowns = { - Channels.python_discussion: 0, - } - - # These channels will also work, but will not be subject to cooldown - self.channel_whitelist = ( - Channels.bot_commands, - ) - - # Stores improperly formatted Python codeblock message ids and the corresponding bot message - self.codeblock_message_ids = {} - - @group(invoke_without_command=True, name="bot", hidden=True) - @with_role(Roles.verified) - async def botinfo_group(self, ctx: Context) -> None: - """Bot informational commands.""" - await ctx.send_help(ctx.command) - - @botinfo_group.command(name='about', aliases=('info',), hidden=True) - @with_role(Roles.verified) - async def about_command(self, ctx: Context) -> None: - """Get information about the bot.""" - embed = Embed( - description="A utility bot designed just for the Python server! Try `!help` for more info.", - url="https://github.com/python-discord/bot" - ) - - embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=URLs.bot_avatar - ) - - await ctx.send(embed=embed) - - @command(name='echo', aliases=('print',)) - @with_role(*MODERATION_ROLES) - async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Repeat the given message in either a specified channel or the current channel.""" - if channel is None: - await ctx.send(text) - else: - await channel.send(text) - - @command(name='embed') - @with_role(*MODERATION_ROLES) - async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Send the input within an embed to either a specified channel or the current channel.""" - embed = Embed(description=text) - - if channel is None: - await ctx.send(embed=embed) - else: - await channel.send(embed=embed) - - def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: - """ - Strip msg in order to find Python code. - - Tries to strip out Python code out of msg and returns the stripped block or - None if the block is a valid Python codeblock. - """ - if msg.count("\n") >= 3: - # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. - if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: - log.trace( - "Someone wrote a message that was already a " - "valid Python syntax highlighted code block. No action taken." - ) - return None - - else: - # Stripping backticks from every line of the message. - log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") - content = "" - for line in msg.splitlines(keepends=True): - content += line.strip("`") - - content = content.strip() - - # Remove "Python" or "Py" from start of the message if it exists. - log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") - pycode = False - if content.lower().startswith("python"): - content = content[6:] - pycode = True - elif content.lower().startswith("py"): - content = content[2:] - pycode = True - - if pycode: - content = content.splitlines(keepends=True) - - # Check if there might be code in the first line, and preserve it. - first_line = content[0] - if " " in content[0]: - first_space = first_line.index(" ") - content[0] = first_line[first_space:] - content = "".join(content) - - # If there's no code we can just get rid of the first line. - else: - content = "".join(content[1:]) - - # Strip it again to remove any leading whitespace. This is neccessary - # if the first line of the message looked like ```python - old = content.strip() - - # Strips REPL code out of the message if there is any. - content, repl_code = self.repl_stripping(old) - if old != content: - return (content, old), repl_code - - # Try to apply indentation fixes to the code. - content = self.fix_indentation(content) - - # Check if the code contains backticks, if it does ignore the message. - if "`" in content: - log.trace("Detected ` inside the code, won't reply") - return None - else: - log.trace(f"Returning message.\n\n{content}\n\n") - return (content,), repl_code - - def fix_indentation(self, msg: str) -> str: - """Attempts to fix badly indented code.""" - def unindent(code: str, skip_spaces: int = 0) -> str: - """Unindents all code down to the number of spaces given in skip_spaces.""" - final = "" - current = code[0] - leading_spaces = 0 - - # Get numbers of spaces before code in the first line. - while current == " ": - current = code[leading_spaces + 1] - leading_spaces += 1 - leading_spaces -= skip_spaces - - # If there are any, remove that number of spaces from every line. - if leading_spaces > 0: - for line in code.splitlines(keepends=True): - line = line[leading_spaces:] - final += line - return final - else: - return code - - # Apply fix for "all lines are overindented" case. - msg = unindent(msg) - - # If the first line does not end with a colon, we can be - # certain the next line will be on the same indentation level. - # - # If it does end with a colon, we will need to indent all successive - # lines one additional level. - first_line = msg.splitlines()[0] - code = "".join(msg.splitlines(keepends=True)[1:]) - if not first_line.endswith(":"): - msg = f"{first_line}\n{unindent(code)}" - else: - msg = f"{first_line}\n{unindent(code, 4)}" - return msg - - def repl_stripping(self, msg: str) -> Tuple[str, bool]: - """ - Strip msg in order to extract Python code out of REPL output. - - Tries to strip out REPL Python code out of msg and returns the stripped msg. - - Returns True for the boolean if REPL code was found in the input msg. - """ - final = "" - for line in msg.splitlines(keepends=True): - if line.startswith(">>>") or line.startswith("..."): - final += line[4:] - log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") - if not final: - log.trace(f"Found no REPL code in \n\n{msg}\n\n") - return msg, False - else: - log.trace(f"Found REPL code in \n\n{msg}\n\n") - return final.rstrip(), True - - def has_bad_ticks(self, msg: Message) -> bool: - """Check to see if msg contains ticks that aren't '`'.""" - not_backticks = [ - "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", - "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", - "\u3003\u3003\u3003" - ] - - return msg.content[:3] in not_backticks - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Detect poorly formatted Python code in new messages. - - If poorly formatted code is detected, send the user a helpful message explaining how to do - properly formatted Python syntax highlighting codeblocks. - """ - is_help_channel = ( - getattr(msg.channel, "category", None) - and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) - ) - parse_codeblock = ( - ( - is_help_channel - or msg.channel.id in self.channel_cooldowns - or msg.channel.id in self.channel_whitelist - ) - and not msg.author.bot - and len(msg.content.splitlines()) > 3 - and not TokenRemover.find_token_in_message(msg) - ) - - if parse_codeblock: # no token in the msg - on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 - if not on_cooldown or DEBUG_MODE: - try: - if self.has_bad_ticks(msg): - ticks = msg.content[:3] - content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) - if content is None: - return - - content, repl_code = content - - if len(content) == 2: - content = content[1] - else: - content = content[0] - - space_left = 204 - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto = ( - "It looks like you are trying to paste code into this channel.\n\n" - "You seem to be using the wrong symbols to indicate where the codeblock should start. " - f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" - "**Here is an example of how it should look:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - else: - howto = "" - content = self.codeblock_stripping(msg.content, False) - if content is None: - return - - content, repl_code = content - # Attempts to parse the message into an AST node. - # Invalid Python code will raise a SyntaxError. - tree = ast.parse(content[0]) - - # Multiple lines of single words could be interpreted as expressions. - # This check is to avoid all nodes being parsed as expressions. - # (e.g. words over multiple lines) - if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: - # Shorten the code to 10 lines and/or 204 characters. - space_left = 204 - if content and repl_code: - content = content[1] - else: - content = content[0] - - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto += ( - "It looks like you're trying to paste code into this channel.\n\n" - "Discord has support for Markdown, which allows you to post code with full " - "syntax highlighting. Please use these whenever you paste code, as this " - "helps improve the legibility and makes it easier for us to help you.\n\n" - f"**To do this, use the following method:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - log.debug(f"{msg.author} posted something that needed to be put inside python code " - "blocks. Sending the user some instructions.") - else: - log.trace("The code consists only of expressions, not sending instructions") - - if howto != "": - # Increase amount of codeblock correction in stats - self.bot.stats.incr("codeblock_corrections") - howto_embed = Embed(description=howto) - bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) - self.codeblock_message_ids[msg.id] = bot_message.id - - self.bot.loop.create_task( - wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) - ) - else: - return - - if msg.channel.id not in self.channel_whitelist: - self.channel_cooldowns[msg.channel.id] = time.time() - - except SyntaxError: - log.trace( - f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " - "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " - f"The message that was posted was:\n\n{msg.content}\n\n" - ) - - @Cog.listener() - async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: - """Check to see if an edited message (previously called out) still contains poorly formatted code.""" - if ( - # Checks to see if the message was called out by the bot - payload.message_id not in self.codeblock_message_ids - # Makes sure that there is content in the message - or payload.data.get("content") is None - # Makes sure there's a channel id in the message payload - or payload.data.get("channel_id") is None - ): - return - - # Retrieve channel and message objects for use later - channel = self.bot.get_channel(int(payload.data.get("channel_id"))) - user_message = await channel.fetch_message(payload.message_id) - - # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None - has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) - - # If the message is fixed, delete the bot message and the entry from the id dictionary - if has_fixed_codeblock is None: - bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - await bot_message.delete() - del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") - - -def setup(bot: Bot) -> None: - """Load the Bot cog.""" - bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py deleted file mode 100644 index f436e531a..000000000 --- a/bot/cogs/clean.py +++ /dev/null @@ -1,272 +0,0 @@ -import logging -import random -import re -from typing import Iterable, Optional - -from discord import Colour, Embed, Message, TextChannel, User -from discord.ext import commands -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import ( - Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES -) -from bot.decorators import with_role - -log = logging.getLogger(__name__) - - -class Clean(Cog): - """ - A cog that allows messages to be deleted in bulk, while applying various filters. - - You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a - specific regular expression. - - The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be - used to view the messages in the Discord dark theme style. - """ - - def __init__(self, bot: Bot): - self.bot = bot - self.cleaning = False - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def _clean_messages( - self, - amount: int, - ctx: Context, - channels: Iterable[TextChannel], - bots_only: bool = False, - user: User = None, - regex: Optional[str] = None, - until_message: Optional[Message] = None, - ) -> None: - """A helper function that does the actual message cleaning.""" - def predicate_bots_only(message: Message) -> bool: - """Return True if the message was sent by a bot.""" - return message.author.bot - - def predicate_specific_user(message: Message) -> bool: - """Return True if the message was sent by the user provided in the _clean_messages call.""" - return message.author == user - - def predicate_regex(message: Message) -> bool: - """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" - content = [message.content] - - # Add the content for all embed attributes - for embed in message.embeds: - content.append(embed.title) - content.append(embed.description) - content.append(embed.footer.text) - content.append(embed.author.name) - for field in embed.fields: - content.append(field.name) - content.append(field.value) - - # Get rid of empty attributes and turn it into a string - content = [attr for attr in content if attr] - content = "\n".join(content) - - # Now let's see if there's a regex match - if not content: - return False - else: - return bool(re.search(regex.lower(), content.lower())) - - # Is this an acceptable amount of messages to clean? - if amount > CleanMessages.message_limit: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description=f"You cannot clean more than {CleanMessages.message_limit} messages." - ) - await ctx.send(embed=embed) - return - - # Are we already performing a clean? - if self.cleaning: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description="Please wait for the currently ongoing clean operation to complete." - ) - await ctx.send(embed=embed) - return - - # Set up the correct predicate - if bots_only: - predicate = predicate_bots_only # Delete messages from bots - elif user: - predicate = predicate_specific_user # Delete messages from specific user - elif regex: - predicate = predicate_regex # Delete messages that match regex - else: - predicate = None # Delete all messages - - # Default to using the invoking context's channel - if not channels: - channels = [ctx.channel] - - # Delete the invocation first - self.mod_log.ignore(Event.message_delete, ctx.message.id) - await ctx.message.delete() - - messages = [] - message_ids = [] - self.cleaning = True - - # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. - for channel in channels: - async for message in channel.history(limit=amount): - - # If at any point the cancel command is invoked, we should stop. - if not self.cleaning: - return - - # If we are looking for specific message. - if until_message: - - # we could use ID's here however in case if the message we are looking for gets deleted, - # we won't have a way to figure that out thus checking for datetime should be more reliable - if message.created_at < until_message.created_at: - # means we have found the message until which we were supposed to be deleting. - break - - # Since we will be using `delete_messages` method of a TextChannel and we need message objects to - # use it as well as to send logs we will start appending messages here instead adding them from - # purge. - messages.append(message) - - # If the message passes predicate, let's save it. - if predicate is None or predicate(message): - message_ids.append(message.id) - - self.cleaning = False - - # Now let's delete the actual messages with purge. - self.mod_log.ignore(Event.message_delete, *message_ids) - for channel in channels: - if until_message: - for i in range(0, len(messages), 100): - # while purge automatically handles the amount of messages - # delete_messages only allows for up to 100 messages at once - # thus we need to paginate the amount to always be <= 100 - await channel.delete_messages(messages[i:i + 100]) - else: - messages += await channel.purge(limit=amount, check=predicate) - - # Reverse the list to restore chronological order - if messages: - messages = reversed(messages) - log_url = await self.mod_log.upload_log(messages, ctx.author.id) - else: - # Can't build an embed, nothing to clean! - embed = Embed( - color=Colour(Colours.soft_red), - description="No matching messages could be found." - ) - await ctx.send(embed=embed, delete_after=10) - return - - # Build the embed and send it - target_channels = ", ".join(channel.mention for channel in channels) - - message = ( - f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" - f"A log of the deleted messages can be found [here]({log_url})." - ) - - await self.mod_log.send_log_message( - icon_url=Icons.message_bulk_delete, - colour=Colour(Colours.soft_red), - title="Bulk message delete", - text=message, - channel_id=Channels.mod_log, - ) - - @group(invoke_without_command=True, name="clean", aliases=["purge"]) - @with_role(*MODERATION_ROLES) - async def clean_group(self, ctx: Context) -> None: - """Commands for cleaning messages in channels.""" - await ctx.send_help(ctx.command) - - @clean_group.command(name="user", aliases=["users"]) - @with_role(*MODERATION_ROLES) - async def clean_user( - self, - ctx: Context, - user: User, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user, channels=channels) - - @clean_group.command(name="all", aliases=["everything"]) - @with_role(*MODERATION_ROLES) - async def clean_all( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, channels=channels) - - @clean_group.command(name="bots", aliases=["bot"]) - @with_role(*MODERATION_ROLES) - async def clean_bots( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True, channels=channels) - - @clean_group.command(name="regex", aliases=["word", "expression"]) - @with_role(*MODERATION_ROLES) - async def clean_regex( - self, - ctx: Context, - regex: str, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex, channels=channels) - - @clean_group.command(name="message", aliases=["messages"]) - @with_role(*MODERATION_ROLES) - async def clean_message(self, ctx: Context, message: Message) -> None: - """Delete all messages until certain message, stop cleaning after hitting the `message`.""" - await self._clean_messages( - CleanMessages.message_limit, - ctx, - channels=[message.channel], - until_message=message - ) - - @clean_group.command(name="stop", aliases=["cancel", "abort"]) - @with_role(*MODERATION_ROLES) - async def clean_cancel(self, ctx: Context) -> None: - """If there is an ongoing cleaning process, attempt to immediately cancel it.""" - self.cleaning = False - - embed = Embed( - color=Colour.blurple(), - description="Clean interrupted." - ) - await ctx.send(embed=embed, delete_after=10) - - -def setup(bot: Bot) -> None: - """Load the Clean cog.""" - bot.add_cog(Clean(bot)) diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py deleted file mode 100644 index d72c6c22e..000000000 --- a/bot/cogs/config_verifier.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging - -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot - - -log = logging.getLogger(__name__) - - -class ConfigVerifier(Cog): - """Verify config on startup.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) - - async def verify_channels(self) -> None: - """ - Verify channels. - - If any channels in config aren't present in server, log them in a warning. - """ - await self.bot.wait_until_guild_available() - server = self.bot.get_guild(constants.Guild.id) - - server_channel_ids = {channel.id for channel in server.channels} - invalid_channels = [ - channel_name for channel_name, channel_id in constants.Channels - if channel_id not in server_channel_ids - ] - - if invalid_channels: - log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") - - -def setup(bot: Bot) -> None: - """Load the ConfigVerifier cog.""" - bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py deleted file mode 100644 index 4c0ad5914..000000000 --- a/bot/cogs/defcon.py +++ /dev/null @@ -1,258 +0,0 @@ -from __future__ import annotations - -import logging -from collections import namedtuple -from datetime import datetime, timedelta -from enum import Enum - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -REJECTION_MESSAGE = """ -Hi, {user} - Thanks for your interest in our server! - -Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since -your account is relatively new, we're unable to provide access to the server at this time. - -Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation -will be resolved soon. In the meantime, please feel free to peruse the resources on our site at -, and have a nice day! -""" - -BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" - - -class Action(Enum): - """Defcon Action.""" - - ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) - - ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") - DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") - UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") - - -class Defcon(Cog): - """Time-sensitive server defense mechanisms.""" - - days = None # type: timedelta - enabled = False # type: bool - - def __init__(self, bot: Bot): - self.bot = bot - self.channel = None - self.days = timedelta(days=0) - - self.bot.loop.create_task(self.sync_settings()) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def sync_settings(self) -> None: - """On cog load, try to synchronize DEFCON settings to the API.""" - await self.bot.wait_until_guild_available() - self.channel = await self.bot.fetch_channel(Channels.defcon) - - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - except Exception: # Yikes! - log.exception("Unable to get DEFCON settings!") - await self.bot.get_channel(Channels.dev_log).send( - f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" - ) - - else: - if data["enabled"]: - self.enabled = True - self.days = timedelta(days=data["days"]) - log.info(f"DEFCON enabled: {self.days.days} days") - - else: - self.enabled = False - self.days = timedelta(days=0) - log.info("DEFCON disabled") - - await self.update_channel_topic() - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" - if self.enabled and self.days.days > 0: - now = datetime.utcnow() - - if now - member.created_at < self.days: - log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") - - message_sent = False - - try: - await member.send(REJECTION_MESSAGE.format(user=member.mention)) - - message_sent = True - except Exception: - log.exception(f"Unable to send rejection message to user: {member}") - - await member.kick(reason="DEFCON active, user is too new") - self.bot.stats.incr("defcon.leaves") - - message = ( - f"{member} (`{member.id}`) was denied entry because their account is too new." - ) - - if not message_sent: - message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." - - await self.mod_log.send_log_message( - Icons.defcon_denied, Colours.soft_red, "Entry denied", - message, member.avatar_url_as(static_format="png") - ) - - @group(name='defcon', aliases=('dc',), invoke_without_command=True) - @with_role(Roles.admins, Roles.owners) - async def defcon_group(self, ctx: Context) -> None: - """Check the DEFCON status or run a subcommand.""" - await ctx.send_help(ctx.command) - - async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: - """Providing a structured way to do an defcon action.""" - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - if "enable_date" in data and action is Action.DISABLED: - enabled = datetime.fromisoformat(data["enable_date"]) - - delta = datetime.now() - enabled - - self.bot.stats.timing("defcon.enabled", delta) - except Exception: - pass - - error = None - try: - await self.bot.api_client.put( - 'bot/bot-settings/defcon', - json={ - 'name': 'defcon', - 'data': { - # TODO: retrieve old days count - 'days': days, - 'enabled': action is not Action.DISABLED, - 'enable_date': datetime.now().isoformat() - } - } - ) - except Exception as err: - log.exception("Unable to update DEFCON settings.") - error = err - finally: - await ctx.send(self.build_defcon_msg(action, error)) - await self.send_defcon_log(action, ctx.author, error) - - self.bot.stats.gauge("defcon.threshold", days) - - @defcon_group.command(name='enable', aliases=('on', 'e')) - @with_role(Roles.admins, Roles.owners) - async def enable_command(self, ctx: Context) -> None: - """ - Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! - - Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, - in days. - """ - self.enabled = True - await self._defcon_action(ctx, days=0, action=Action.ENABLED) - await self.update_channel_topic() - - @defcon_group.command(name='disable', aliases=('off', 'd')) - @with_role(Roles.admins, Roles.owners) - async def disable_command(self, ctx: Context) -> None: - """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" - self.enabled = False - await self._defcon_action(ctx, days=0, action=Action.DISABLED) - await self.update_channel_topic() - - @defcon_group.command(name='status', aliases=('s',)) - @with_role(Roles.admins, Roles.owners) - async def status_command(self, ctx: Context) -> None: - """Check the current status of DEFCON mode.""" - embed = Embed( - colour=Colour.blurple(), title="DEFCON Status", - description=f"**Enabled:** {self.enabled}\n" - f"**Days:** {self.days.days}" - ) - - await ctx.send(embed=embed) - - @defcon_group.command(name='days') - @with_role(Roles.admins, Roles.owners) - async def days_command(self, ctx: Context, days: int) -> None: - """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" - self.days = timedelta(days=days) - self.enabled = True - await self._defcon_action(ctx, days=days, action=Action.UPDATED) - await self.update_channel_topic() - - async def update_channel_topic(self) -> None: - """Update the #defcon channel topic with the current DEFCON status.""" - if self.enabled: - day_str = "days" if self.days.days > 1 else "day" - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" - else: - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" - - self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) - await self.channel.edit(topic=new_topic) - - def build_defcon_msg(self, action: Action, e: Exception = None) -> str: - """Build in-channel response string for DEFCON action.""" - if action is Action.ENABLED: - msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" - elif action is Action.DISABLED: - msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" - elif action is Action.UPDATED: - msg = ( - f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " - f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" - ) - - if e: - msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - return msg - - async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: - """Send log message for DEFCON action.""" - info = action.value - log_msg: str = ( - f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" - f"{info.template.format(days=self.days.days)}" - ) - status_msg = f"DEFCON {action.name.lower()}" - - if e: - log_msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) - - -def setup(bot: Bot) -> None: - """Load the Defcon cog.""" - bot.add_cog(Defcon(bot)) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py deleted file mode 100644 index 204cffb37..000000000 --- a/bot/cogs/doc.py +++ /dev/null @@ -1,511 +0,0 @@ -import asyncio -import functools -import logging -import re -import textwrap -from collections import OrderedDict -from contextlib import suppress -from types import SimpleNamespace -from typing import Any, Callable, Optional, Tuple - -import discord -from bs4 import BeautifulSoup -from bs4.element import PageElement, Tag -from discord.errors import NotFound -from discord.ext import commands -from markdownify import MarkdownConverter -from requests import ConnectTimeout, ConnectionError, HTTPError -from sphinx.ext import intersphinx -from urllib3.exceptions import ProtocolError - -from bot.bot import Bot -from bot.constants import MODERATION_ROLES, RedirectOutput -from bot.converters import ValidPythonIdentifier, ValidURL -from bot.decorators import with_role -from bot.pagination import LinePaginator - - -log = logging.getLogger(__name__) -logging.getLogger('urllib3').setLevel(logging.WARNING) - -# Since Intersphinx is intended to be used with Sphinx, -# we need to mock its configuration. -SPHINX_MOCK_APP = SimpleNamespace( - config=SimpleNamespace( - intersphinx_timeout=3, - tls_verify=True, - user_agent="python3:python-discord/bot:1.0.0" - ) -) - -NO_OVERRIDE_GROUPS = ( - "2to3fixer", - "token", - "label", - "pdbcommand", - "term", -) -NO_OVERRIDE_PACKAGES = ( - "python", -) - -SEARCH_END_TAG_ATTRS = ( - "data", - "function", - "class", - "exception", - "seealso", - "section", - "rubric", - "sphinxsidebar", -) -UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") -WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") - -FAILED_REQUEST_RETRY_AMOUNT = 3 -NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay - - -def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - # Assign the cache to the function itself so we can clear it from outside. - async_cache.cache = OrderedDict() - - def decorator(function: Callable) -> Callable: - """Define the async_cache decorator.""" - @functools.wraps(function) - async def wrapper(*args) -> Any: - """Decorator wrapper for the caching logic.""" - key = ':'.join(args[arg_offset:]) - - value = async_cache.cache.get(key) - if value is None: - if len(async_cache.cache) > max_size: - async_cache.cache.popitem(last=False) - - async_cache.cache[key] = await function(*args) - return async_cache.cache[key] - return wrapper - return decorator - - -class DocMarkdownConverter(MarkdownConverter): - """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" - - def convert_code(self, el: PageElement, text: str) -> str: - """Undo `markdownify`s underscore escaping.""" - return f"`{text}`".replace('\\', '') - - def convert_pre(self, el: PageElement, text: str) -> str: - """Wrap any codeblocks in `py` for syntax highlighting.""" - code = ''.join(el.strings) - return f"```py\n{code}```" - - -def markdownify(html: str) -> DocMarkdownConverter: - """Create a DocMarkdownConverter object from the input html.""" - return DocMarkdownConverter(bullets='•').convert(html) - - -class InventoryURL(commands.Converter): - """ - Represents an Intersphinx inventory URL. - - This converter checks whether intersphinx accepts the given inventory URL, and raises - `BadArgument` if that is not the case. - - Otherwise, it simply passes through the given URL. - """ - - @staticmethod - async def convert(ctx: commands.Context, url: str) -> str: - """Convert url to Intersphinx inventory URL.""" - try: - intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) - except AttributeError: - raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") - except ConnectionError: - if url.startswith('https'): - raise commands.BadArgument( - f"Cannot establish a connection to `{url}`. Does it support HTTPS?" - ) - raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") - except ValueError: - raise commands.BadArgument( - f"Failed to read Intersphinx inventory from URL `{url}`. " - "Are you sure that it's a valid inventory file?" - ) - return url - - -class Doc(commands.Cog): - """A set of commands for querying & displaying documentation.""" - - def __init__(self, bot: Bot): - self.base_urls = {} - self.bot = bot - self.inventories = {} - self.renamed_symbols = set() - - self.bot.loop.create_task(self.init_refresh_inventory()) - - async def init_refresh_inventory(self) -> None: - """Refresh documentation inventory on cog initialization.""" - await self.bot.wait_until_guild_available() - await self.refresh_inventory() - - async def update_single( - self, package_name: str, base_url: str, inventory_url: str - ) -> None: - """ - Rebuild the inventory for a single package. - - Where: - * `package_name` is the package name to use, appears in the log - * `base_url` is the root documentation URL for the specified package, used to build - absolute paths that link to specific symbols - * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running - `intersphinx.fetch_inventory` in an executor on the bot's event loop - """ - self.base_urls[package_name] = base_url - - package = await self._fetch_inventory(inventory_url) - if not package: - return None - - for group, value in package.items(): - for symbol, (package_name, _version, relative_doc_url, _) in value.items(): - absolute_doc_url = base_url + relative_doc_url - - if symbol in self.inventories: - group_name = group.split(":")[1] - symbol_base_url = self.inventories[symbol].split("/", 3)[2] - if ( - group_name in NO_OVERRIDE_GROUPS - or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) - ): - - symbol = f"{group_name}.{symbol}" - # If renamed `symbol` already exists, add library name in front to differentiate between them. - if symbol in self.renamed_symbols: - # Split `package_name` because of packages like Pillow that have spaces in them. - symbol = f"{package_name.split()[0]}.{symbol}" - - self.inventories[symbol] = absolute_doc_url - self.renamed_symbols.add(symbol) - continue - - self.inventories[symbol] = absolute_doc_url - - log.trace(f"Fetched inventory for {package_name}.") - - async def refresh_inventory(self) -> None: - """Refresh internal documentation inventory.""" - log.debug("Refreshing documentation inventory...") - - # Clear the old base URLS and inventories to ensure - # that we start from a fresh local dataset. - # Also, reset the cache used for fetching documentation. - self.base_urls.clear() - self.inventories.clear() - self.renamed_symbols.clear() - async_cache.cache = OrderedDict() - - # Run all coroutines concurrently - since each of them performs a HTTP - # request, this speeds up fetching the inventory data heavily. - coros = [ - self.update_single( - package["package"], package["base_url"], package["inventory_url"] - ) for package in await self.bot.api_client.get('bot/documentation-links') - ] - await asyncio.gather(*coros) - - async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: - """ - Given a Python symbol, return its signature and description. - - The first tuple element is the signature of the given symbol as a markup-free string, and - the second tuple element is the description of the given symbol with HTML markup included. - - If the given symbol is a module, returns a tuple `(None, str)` - else if the symbol could not be found, returns `None`. - """ - url = self.inventories.get(symbol) - if url is None: - return None - - async with self.bot.http_session.get(url) as response: - html = await response.text(encoding='utf-8') - - # Find the signature header and parse the relevant parts. - symbol_id = url.split('#')[-1] - soup = BeautifulSoup(html, 'lxml') - symbol_heading = soup.find(id=symbol_id) - search_html = str(soup) - - if symbol_heading is None: - return None - - if symbol_id == f"module-{symbol}": - # Get page content from the module headerlink to the - # first tag that has its class in `SEARCH_END_TAG_ATTRS` - start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) - if start_tag is None: - return [], "" - - end_tag = start_tag.find_next(self._match_end_tag) - if end_tag is None: - return [], "" - - description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) - description_end_index = search_html.find(str(end_tag)) - description = search_html[description_start_index:description_end_index] - signatures = None - - else: - signatures = [] - description = str(symbol_heading.find_next_sibling("dd")) - description_pos = search_html.find(description) - # Get text of up to 3 signatures, remove unwanted symbols - for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): - signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) - if signature and search_html.find(str(element)) < description_pos: - signatures.append(signature) - - return signatures, description.replace('¶', '') - - @async_cache(arg_offset=1) - async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: - """ - Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. - - If the symbol is known, an Embed with documentation about it is returned. - """ - scraped_html = await self.get_symbol_html(symbol) - if scraped_html is None: - return None - - signatures = scraped_html[0] - permalink = self.inventories[symbol] - description = markdownify(scraped_html[1]) - - # Truncate the description of the embed to the last occurrence - # of a double newline (interpreted as a paragraph) before index 1000. - if len(description) > 1000: - shortened = description[:1000] - description_cutoff = shortened.rfind('\n\n', 100) - if description_cutoff == -1: - # Search the shortened version for cutoff points in decreasing desirability, - # cutoff at 1000 if none are found. - for string in (". ", ", ", ",", " "): - description_cutoff = shortened.rfind(string) - if description_cutoff != -1: - break - else: - description_cutoff = 1000 - description = description[:description_cutoff] - - # If there is an incomplete code block, cut it out - if description.count("```") % 2: - codeblock_start = description.rfind('```py') - description = description[:codeblock_start].rstrip() - description += f"... [read more]({permalink})" - - description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if signatures is None: - # If symbol is a module, don't show signature. - embed_description = description - - elif not signatures: - # It's some "meta-page", for example: - # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - embed_description = "This appears to be a generic page not tied to a specific symbol." - - else: - embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) - embed_description += f"\n{description}" - - embed = discord.Embed( - title=f'`{symbol}`', - url=permalink, - description=embed_description - ) - # Show all symbols with the same name that were renamed in the footer. - embed.set_footer( - text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) - ) - return embed - - @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) - async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """Lookup documentation for Python symbols.""" - await ctx.invoke(self.get_command, symbol) - - @docs_group.command(name='get', aliases=('g',)) - async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """ - Return a documentation embed for a given symbol. - - If no symbol is given, return a list of all available inventories. - - Examples: - !docs - !docs aiohttp - !docs aiohttp.ClientSession - !docs get aiohttp.ClientSession - """ - if symbol is None: - inventory_embed = discord.Embed( - title=f"All inventories (`{len(self.base_urls)}` total)", - colour=discord.Colour.blue() - ) - - lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) - if self.base_urls: - await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) - - else: - inventory_embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=inventory_embed) - - else: - # Fetching documentation for a symbol (at least for the first time, since - # caching is used) takes quite some time, so let's send typing to indicate - # that we got the command, but are still working on it. - async with ctx.typing(): - doc_embed = await self.get_symbol_embed(symbol) - - if doc_embed is None: - error_embed = discord.Embed( - description=f"Sorry, I could not find any documentation for `{symbol}`.", - colour=discord.Colour.red() - ) - error_message = await ctx.send(embed=error_embed) - with suppress(NotFound): - await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) - await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) - else: - await ctx.send(embed=doc_embed) - - @docs_group.command(name='set', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def set_command( - self, ctx: commands.Context, package_name: ValidPythonIdentifier, - base_url: ValidURL, inventory_url: InventoryURL - ) -> None: - """ - Adds a new documentation metadata object to the site's database. - - The database will update the object, should an existing item with the specified `package_name` already exist. - - Example: - !docs set \ - python \ - https://docs.python.org/3/ \ - https://docs.python.org/3/objects.inv - """ - body = { - 'package': package_name, - 'base_url': base_url, - 'inventory_url': inventory_url - } - await self.bot.api_client.post('bot/documentation-links', json=body) - - log.info( - f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" - f"Package name: {package_name}\n" - f"Base url: {base_url}\n" - f"Inventory URL: {inventory_url}" - ) - - # Rebuilding the inventory can take some time, so lets send out a - # typing event to show that the Bot is still working. - async with ctx.typing(): - await self.refresh_inventory() - await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") - - @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: - """ - Removes the specified package from the database. - - Examples: - !docs delete aiohttp - """ - await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') - - async with ctx.typing(): - # Rebuild the inventory to ensure that everything - # that was from this package is properly deleted. - await self.refresh_inventory() - await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") - - @docs_group.command(name="refresh", aliases=("rfsh", "r")) - @with_role(*MODERATION_ROLES) - async def refresh_command(self, ctx: commands.Context) -> None: - """Refresh inventories and send differences to channel.""" - old_inventories = set(self.base_urls) - with ctx.typing(): - await self.refresh_inventory() - # Get differences of added and removed inventories - added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) - if added: - added = f"+ {added}" - - removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) - if removed: - removed = f"- {removed}" - - embed = discord.Embed( - title="Inventories refreshed", - description=f"```diff\n{added}\n{removed}```" if added or removed else "" - ) - await ctx.send(embed=embed) - - async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: - """Get and return inventory from `inventory_url`. If fetching fails, return None.""" - fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) - for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): - try: - package = await self.bot.loop.run_in_executor(None, fetch_func) - except ConnectTimeout: - log.error( - f"Fetching of inventory {inventory_url} timed out," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except ProtocolError: - log.error( - f"Connection lost while fetching inventory {inventory_url}," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except HTTPError as e: - log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") - return None - except ConnectionError: - log.error(f"Couldn't establish connection to inventory {inventory_url}.") - return None - else: - return package - log.error(f"Fetching of inventory {inventory_url} failed.") - return None - - @staticmethod - def _match_end_tag(tag: Tag) -> bool: - """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" - for attr in SEARCH_END_TAG_ATTRS: - if attr in tag.get("class", ()): - return True - - return tag.name == "table" - - -def setup(bot: Bot) -> None: - """Load the Doc cog.""" - bot.add_cog(Doc(bot)) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py deleted file mode 100644 index f9d4de638..000000000 --- a/bot/cogs/error_handler.py +++ /dev/null @@ -1,287 +0,0 @@ -import contextlib -import logging -import typing as t - -from discord import Embed -from discord.ext.commands import Cog, Context, errors -from sentry_sdk import push_scope - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Colours -from bot.converters import TagNameConverter -from bot.utils.checks import InWhitelistCheckFailure - -log = logging.getLogger(__name__) - - -class ErrorHandler(Cog): - """Handles errors emitted from commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_error_embed(self, title: str, body: str) -> Embed: - """Return an embed that contains the exception.""" - return Embed( - title=title, - colour=Colours.soft_red, - description=body - ) - - @Cog.listener() - async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: - """ - Provide generic command error handling. - - Error handling is deferred to any local error handler, if present. This is done by - checking for the presence of a `handled` attribute on the error. - - Error handling emits a single error message in the invoking context `ctx` and a log message, - prioritised as follows: - - 1. If the name fails to match a command: - * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. - Otherwise if it matches a tag, the tag is invoked - * If CommandNotFound is raised when invoking the tag (determined by the presence of the - `invoked_from_error_handler` attribute), this error is treated as being unexpected - and therefore sends an error message - * Commands in the verification channel are ignored - 2. UserInputError: see `handle_user_input_error` - 3. CheckFailure: see `handle_check_failure` - 4. CommandOnCooldown: send an error message in the invoking context - 5. ResponseCodeError: see `handle_api_error` - 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` - """ - command = ctx.command - - if hasattr(e, "handled"): - log.trace(f"Command {command} had its error already handled locally; ignoring.") - return - - if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): - if await self.try_silence(ctx): - return - if ctx.channel.id != Channels.verification: - # Try to look for a tag with the command's name - await self.try_get_tag(ctx) - return # Exit early to avoid logging. - elif isinstance(e, errors.UserInputError): - await self.handle_user_input_error(ctx, e) - elif isinstance(e, errors.CheckFailure): - await self.handle_check_failure(ctx, e) - elif isinstance(e, errors.CommandOnCooldown): - await ctx.send(e) - elif isinstance(e, errors.CommandInvokeError): - if isinstance(e.original, ResponseCodeError): - await self.handle_api_error(ctx, e.original) - else: - await self.handle_unexpected_error(ctx, e.original) - return # Exit early to avoid logging. - elif not isinstance(e, errors.DisabledCommand): - # ConversionError, MaxConcurrencyReached, ExtensionError - await self.handle_unexpected_error(ctx, e) - return # Exit early to avoid logging. - - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - - @staticmethod - def get_help_command(ctx: Context) -> t.Coroutine: - """Return a prepared `help` command invocation coroutine.""" - if ctx.command: - return ctx.send_help(ctx.command) - - return ctx.send_help() - - async def try_silence(self, ctx: Context) -> bool: - """ - Attempt to invoke the silence or unsilence command if invoke with matches a pattern. - - Respecting the checks if: - * invoked with `shh+` silence channel for amount of h's*2 with max of 15. - * invoked with `unshh+` unsilence channel - Return bool depending on success of command. - """ - command = ctx.invoked_with.lower() - silence_command = self.bot.get_command("silence") - ctx.invoked_from_error_handler = True - try: - if not await silence_command.can_run(ctx): - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - except errors.CommandError: - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - if command.startswith("shh"): - await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) - return True - elif command.startswith("unshh"): - await ctx.invoke(self.bot.get_command("unsilence")) - return True - return False - - async def try_get_tag(self, ctx: Context) -> None: - """ - Attempt to display a tag by interpreting the command name as a tag name. - - The invocation of tags get respects its checks. Any CommandErrors raised will be handled - by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to - the context to prevent infinite recursion in the case of a CommandNotFound exception. - """ - tags_get_command = self.bot.get_command("tags get") - ctx.invoked_from_error_handler = True - - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." - try: - if not await tags_get_command.can_run(ctx): - log.debug(log_msg) - return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - - try: - tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) - except errors.BadArgument: - log.debug( - f"{ctx.author} tried to use an invalid command " - f"and the fallback tag failed validation in TagNameConverter." - ) - else: - with contextlib.suppress(ResponseCodeError): - await ctx.invoke(tags_get_command, tag_name=tag_name) - # Return to not raise the exception - return - - async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: - """ - Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. - - * MissingRequiredArgument: send an error message with arg name and the help command - * TooManyArguments: send an error message and the help command - * BadArgument: send an error message and the help command - * BadUnionArgument: send an error message including the error produced by the last converter - * ArgumentParsingError: send an error message - * Other: send an error message and the help command - """ - prepared_help_command = self.get_help_command(ctx) - - if isinstance(e, errors.MissingRequiredArgument): - embed = self._get_error_embed("Missing required argument", e.param.name) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.missing_required_argument") - elif isinstance(e, errors.TooManyArguments): - embed = self._get_error_embed("Too many arguments", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.too_many_arguments") - elif isinstance(e, errors.BadArgument): - embed = self._get_error_embed("Bad argument", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.bad_argument") - elif isinstance(e, errors.BadUnionArgument): - embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") - await ctx.send(embed=embed) - self.bot.stats.incr("errors.bad_union_argument") - elif isinstance(e, errors.ArgumentParsingError): - embed = self._get_error_embed("Argument parsing error", str(e)) - await ctx.send(embed=embed) - self.bot.stats.incr("errors.argument_parsing_error") - else: - embed = self._get_error_embed( - "Input error", - "Something about your input seems off. Check the arguments and try again." - ) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.other_user_input_error") - - @staticmethod - async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: - """ - Send an error message in `ctx` for certain types of CheckFailure. - - The following types are handled: - - * BotMissingPermissions - * BotMissingRole - * BotMissingAnyRole - * NoPrivateMessage - * InWhitelistCheckFailure - """ - bot_missing_errors = ( - errors.BotMissingPermissions, - errors.BotMissingRole, - errors.BotMissingAnyRole - ) - - if isinstance(e, bot_missing_errors): - ctx.bot.stats.incr("errors.bot_permission_error") - await ctx.send( - "Sorry, it looks like I don't have the permissions or roles I need to do that." - ) - elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): - ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") - await ctx.send(e) - - @staticmethod - async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: - """Send an error message in `ctx` for ResponseCodeError and log it.""" - if e.status == 404: - await ctx.send("There does not seem to be anything matching your query.") - log.debug(f"API responded with 404 for command {ctx.command}") - ctx.bot.stats.incr("errors.api_error_404") - elif e.status == 400: - content = await e.response.json() - log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) - await ctx.send("According to the API, your request is malformed.") - ctx.bot.stats.incr("errors.api_error_400") - elif 500 <= e.status < 600: - await ctx.send("Sorry, there seems to be an internal issue with the API.") - log.warning(f"API responded with {e.status} for command {ctx.command}") - ctx.bot.stats.incr("errors.api_internal_server_error") - else: - await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") - log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") - ctx.bot.stats.incr(f"errors.api_error_{e.status}") - - @staticmethod - async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: - """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" - await ctx.send( - f"Sorry, an unexpected error occurred. Please let us know!\n\n" - f"```{e.__class__.__name__}: {e}```" - ) - - ctx.bot.stats.incr("errors.unexpected") - - with push_scope() as scope: - scope.user = { - "id": ctx.author.id, - "username": str(ctx.author) - } - - scope.set_tag("command", ctx.command.qualified_name) - scope.set_tag("message_id", ctx.message.id) - scope.set_tag("channel_id", ctx.channel.id) - - scope.set_extra("full_message", ctx.message.content) - - if ctx.guild is not None: - scope.set_extra( - "jump_to", - f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" - ) - - log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) - - -def setup(bot: Bot) -> None: - """Load the ErrorHandler cog.""" - bot.add_cog(ErrorHandler(bot)) diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py deleted file mode 100644 index eb8bfb1cf..000000000 --- a/bot/cogs/eval.py +++ /dev/null @@ -1,202 +0,0 @@ -import contextlib -import inspect -import logging -import pprint -import re -import textwrap -import traceback -from io import StringIO -from typing import Any, Optional, Tuple - -import discord -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role -from bot.interpreter import Interpreter - -log = logging.getLogger(__name__) - - -class CodeEval(Cog): - """Owner and admin feature that evaluates code and returns the result to the channel.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.env = {} - self.ln = 0 - self.stdout = StringIO() - - self.interpreter = Interpreter(bot) - - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: - """Format the eval output into a string & attempt to format it into an Embed.""" - self._ = out - - res = "" - - # Erase temp input we made - if inp.startswith("_ = "): - inp = inp[4:] - - # Get all non-empty lines - lines = [line for line in inp.split("\n") if line.strip()] - if len(lines) != 1: - lines += [""] - - # Create the input dialog - for i, line in enumerate(lines): - if i == 0: - # Start dialog - start = f"In [{self.ln}]: " - - else: - # Indent the 3 dots correctly; - # Normally, it's something like - # In [X]: - # ...: - # - # But if it's - # In [XX]: - # ...: - # - # You can see it doesn't look right. - # This code simply indents the dots - # far enough to align them. - # we first `str()` the line number - # then we get the length - # and use `str.rjust()` - # to indent it. - start = "...: ".rjust(len(str(self.ln)) + 7) - - if i == len(lines) - 2: - if line.startswith("return"): - line = line[6:].strip() - - # Combine everything - res += (start + line + "\n") - - self.stdout.seek(0) - text = self.stdout.read() - self.stdout.close() - self.stdout = StringIO() - - if text: - res += (text + "\n") - - if out is None: - # No output, return the input statement - return (res, None) - - res += f"Out[{self.ln}]: " - - if isinstance(out, discord.Embed): - # We made an embed? Send that as embed - res += "" - res = (res, out) - - else: - if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): - # Leave out the traceback message - out = "\n" + "\n".join(out.split("\n")[1:]) - - if isinstance(out, str): - pretty = out - else: - pretty = pprint.pformat(out, compact=True, width=60) - - if pretty != str(out): - # We're using the pretty version, start on the next line - res += "\n" - - if pretty.count("\n") > 20: - # Text too long, shorten - li = pretty.split("\n") - - pretty = ("\n".join(li[:3]) # First 3 lines - + "\n ...\n" # Ellipsis to indicate removed lines - + "\n".join(li[-3:])) # last 3 lines - - # Add the output - res += pretty - res = (res, None) - - return res # Return (text, embed) - - async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: - """Eval the input code string & send an embed to the invoking context.""" - self.ln += 1 - - if code.startswith("exit"): - self.ln = 0 - self.env = {} - return await ctx.send("```Reset history!```") - - env = { - "message": ctx.message, - "author": ctx.message.author, - "channel": ctx.channel, - "guild": ctx.guild, - "ctx": ctx, - "self": self, - "bot": self.bot, - "inspect": inspect, - "discord": discord, - "contextlib": contextlib - } - - self.env.update(env) - - # Ignore this code, it works - code_ = """ -async def func(): # (None,) -> Any - try: - with contextlib.redirect_stdout(self.stdout): -{0} - if '_' in locals(): - if inspect.isawaitable(_): - _ = await _ - return _ - finally: - self.env.update(locals()) -""".format(textwrap.indent(code, ' ')) - - try: - exec(code_, self.env) # noqa: B102,S102 - func = self.env['func'] - res = await func() - - except Exception: - res = traceback.format_exc() - - out, embed = self._format(code, res) - await ctx.send(f"```py\n{out}```", embed=embed) - - @group(name='internal', aliases=('int',)) - @with_role(Roles.owners, Roles.admins) - async def internal_group(self, ctx: Context) -> None: - """Internal commands. Top secret!""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @internal_group.command(name='eval', aliases=('e',)) - @with_role(Roles.admins, Roles.owners) - async def eval(self, ctx: Context, *, code: str) -> None: - """Run eval in a REPL-like format.""" - code = code.strip("`") - if re.match('py(thon)?\n', code): - code = "\n".join(code.split("\n")[1:]) - - if not re.search( # Check if it's an expression - r"^(return|import|for|while|def|class|" - r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( - code.split("\n")) == 1: - code = "_ = " + code - - await self._eval(ctx, code) - - -def setup(bot: Bot) -> None: - """Load the CodeEval cog.""" - bot.add_cog(CodeEval(bot)) diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py deleted file mode 100644 index 365f198ff..000000000 --- a/bot/cogs/extensions.py +++ /dev/null @@ -1,236 +0,0 @@ -import functools -import logging -import typing as t -from enum import Enum -from pkgutil import iter_modules - -from discord import Colour, Embed -from discord.ext import commands -from discord.ext.commands import Context, group - -from bot.bot import Bot -from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - -UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} -EXTENSIONS = frozenset( - ext.name - for ext in iter_modules(("bot/cogs",), "bot.cogs.") - if ext.name[-1] != "_" -) - - -class Action(Enum): - """Represents an action to perform on an extension.""" - - # Need to be partial otherwise they are considered to be function definitions. - LOAD = functools.partial(Bot.load_extension) - UNLOAD = functools.partial(Bot.unload_extension) - RELOAD = functools.partial(Bot.reload_extension) - - -class Extension(commands.Converter): - """ - Fully qualify the name of an extension and ensure it exists. - - The * and ** values bypass this when used with the reload command. - """ - - async def convert(self, ctx: Context, argument: str) -> str: - """Fully qualify the name of an extension and ensure it exists.""" - # Special values to reload all extensions - if argument == "*" or argument == "**": - return argument - - argument = argument.lower() - - if "." not in argument: - argument = f"bot.cogs.{argument}" - - if argument in EXTENSIONS: - return argument - else: - raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - - -class Extensions(commands.Cog): - """Extension management commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) - async def extensions_group(self, ctx: Context) -> None: - """Load, unload, reload, and list loaded extensions.""" - await ctx.send_help(ctx.command) - - @extensions_group.command(name="load", aliases=("l",)) - async def load_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Load extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "*" in extensions or "**" in extensions: - extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) - - msg = self.batch_manage(Action.LOAD, *extensions) - await ctx.send(msg) - - @extensions_group.command(name="unload", aliases=("ul",)) - async def unload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Unload currently loaded extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) - - if blacklisted: - msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" - else: - if "*" in extensions or "**" in extensions: - extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST - - msg = self.batch_manage(Action.UNLOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="reload", aliases=("r",)) - async def reload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Reload extensions given their fully qualified or unqualified names. - - If an extension fails to be reloaded, it will be rolled-back to the prior working state. - - If '\*' is given as the name, all currently loaded extensions will be reloaded. - If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "**" in extensions: - extensions = EXTENSIONS - elif "*" in extensions: - extensions = set(self.bot.extensions.keys()) | set(extensions) - extensions.remove("*") - - msg = self.batch_manage(Action.RELOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="list", aliases=("all",)) - async def list_command(self, ctx: Context) -> None: - """ - Get a list of all extensions, including their loaded status. - - Grey indicates that the extension is unloaded. - Green indicates that the extension is currently loaded. - """ - embed = Embed() - lines = [] - - embed.colour = Colour.blurple() - embed.set_author( - name="Extensions List", - url=URLs.github_bot_repo, - icon_url=URLs.bot_avatar - ) - - for ext in sorted(list(EXTENSIONS)): - if ext in self.bot.extensions: - status = Emojis.status_online - else: - status = Emojis.status_offline - - ext = ext.rsplit(".", 1)[1] - lines.append(f"{status} {ext}") - - log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") - await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) - - def batch_manage(self, action: Action, *extensions: str) -> str: - """ - Apply an action to multiple extensions and return a message with the results. - - If only one extension is given, it is deferred to `manage()`. - """ - if len(extensions) == 1: - msg, _ = self.manage(action, extensions[0]) - return msg - - verb = action.name.lower() - failures = {} - - for extension in extensions: - _, error = self.manage(action, extension) - if error: - failures[extension] = error - - emoji = ":x:" if failures else ":ok_hand:" - msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." - - if failures: - failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) - msg += f"\nFailures:```{failures}```" - - log.debug(f"Batch {verb}ed extensions.") - - return msg - - def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: - """Apply an action to an extension and return the status message and any error message.""" - verb = action.name.lower() - error_msg = None - - try: - action.value(self.bot, ext) - except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): - if action is Action.RELOAD: - # When reloading, just load the extension if it was not loaded. - return self.manage(Action.LOAD, ext) - - msg = f":x: Extension `{ext}` is already {verb}ed." - log.debug(msg[4:]) - except Exception as e: - if hasattr(e, "original"): - e = e.original - - log.exception(f"Extension '{ext}' failed to {verb}.") - - error_msg = f"{e.__class__.__name__}: {e}" - msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" - else: - msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." - log.debug(msg[10:]) - - return msg, error_msg - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators and core developers to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Handle BadArgument errors locally to prevent the help command from showing.""" - if isinstance(error, commands.BadArgument): - await ctx.send(str(error)) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the Extensions cog.""" - bot.add_cog(Extensions(bot)) diff --git a/bot/cogs/filter_lists.py b/bot/cogs/filter_lists.py deleted file mode 100644 index c15adc461..000000000 --- a/bot/cogs/filter_lists.py +++ /dev/null @@ -1,273 +0,0 @@ -import logging -from typing import Optional - -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.converters import ValidDiscordServerInvite, ValidFilterListType -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -class FilterLists(Cog): - """Commands for blacklisting and whitelisting things.""" - - methods_with_filterlist_types = [ - "allow_add", - "allow_delete", - "allow_get", - "deny_add", - "deny_delete", - "deny_get", - ] - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.bot.loop.create_task(self._amend_docstrings()) - - async def _amend_docstrings(self) -> None: - """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" - await self.bot.wait_until_guild_available() - - # Add valid filterlist types to the docstrings - valid_types = await ValidFilterListType.get_valid_types(self.bot) - valid_types = [f"`{type_.lower()}`" for type_ in valid_types] - - for method_name in self.methods_with_filterlist_types: - command = getattr(self, method_name) - command.help = ( - f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." - ) - - async def _add_data( - self, - ctx: Context, - allowed: bool, - list_type: ValidFilterListType, - content: str, - comment: Optional[str] = None, - ) -> None: - """Add an item to a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we gotta validate it. - if list_type == "GUILD_INVITE": - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # Unless the user has specified another comment, let's - # use the server name as the comment so that the list - # of guild IDs will be more easily readable when we - # display it. - if not comment: - comment = guild_data.get("name") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Try to add the item to the database - log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") - payload = { - "allowed": allowed, - "type": list_type, - "content": content, - "comment": comment, - } - - try: - item = await self.bot.api_client.post( - "bot/filter-lists", - json=payload - ) - except ResponseCodeError as e: - if e.status == 400: - await ctx.message.add_reaction("❌") - log.debug( - f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " - "probably because the request violated the UniqueConstraint." - ) - raise BadArgument( - f"Unable to add the item to the {allow_type}. " - "The item probably already exists. Keep in mind that a " - "blacklist and a whitelist for the same item cannot co-exist, " - "and we do not permit any duplicates." - ) - raise - - # Insert the item into the cache - self.bot.insert_item_into_filter_list_cache(item) - await ctx.message.add_reaction("✅") - - async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we need to convert it. - if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Find the content and delete it. - log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) - - if item is not None: - try: - await self.bot.api_client.delete( - f"bot/filter-lists/{item['id']}" - ) - del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to delete an item with the id {item['id']}, but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("❌") - - async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: - """Paginate and display all items in a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] - - # Build a list of lines we want to show in the paginator - lines = [] - for content, metadata in result.items(): - line = f"• `{content}`" - - if comment := metadata.get("comment"): - line += f" - {comment}" - - lines.append(line) - lines = sorted(lines) - - # Build the embed - list_type_plural = list_type.lower().replace("_", " ").title() + "s" - embed = Embed( - title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", - colour=Colour.blue() - ) - log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - - if result: - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - await ctx.message.add_reaction("❌") - - async def _sync_data(self, ctx: Context) -> None: - """Syncs the filterlists with the API.""" - try: - log.trace("Attempting to sync FilterList cache with data from the API.") - await self.bot.cache_filter_list_data() - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to sync FilterList cache data but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - - @staticmethod - async def _validate_guild_invite(ctx: Context, invite: str) -> dict: - """ - Validates a guild invite, and returns the guild info as a dict. - - Will raise a BadArgument if the guild invite is invalid. - """ - log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, invite) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's return a dict of guild information. - log.trace(f"{invite} validated as server invite. Converting to ID.") - return guild_data - - @group(aliases=("allowlist", "allow", "al", "wl")) - async def whitelist(self, ctx: Context) -> None: - """Group for whitelisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @group(aliases=("denylist", "deny", "bl", "dl")) - async def blacklist(self, ctx: Context) -> None: - """Group for blacklisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @whitelist.command(name="add", aliases=("a", "set")) - async def allow_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified allowlist.""" - await self._add_data(ctx, True, list_type, content, comment) - - @blacklist.command(name="add", aliases=("a", "set")) - async def deny_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified denylist.""" - await self._add_data(ctx, False, list_type, content, comment) - - @whitelist.command(name="remove", aliases=("delete", "rm",)) - async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified allowlist.""" - await self._delete_data(ctx, True, list_type, content) - - @blacklist.command(name="remove", aliases=("delete", "rm",)) - async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified denylist.""" - await self._delete_data(ctx, False, list_type, content) - - @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified allowlist.""" - await self._list_all_data(ctx, True, list_type) - - @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified denylist.""" - await self._list_all_data(ctx, False, list_type) - - @whitelist.command(name="sync", aliases=("s",)) - async def allow_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - @blacklist.command(name="sync", aliases=("s",)) - async def deny_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the FilterLists cog.""" - bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py deleted file mode 100644 index 93cc1c655..000000000 --- a/bot/cogs/filtering.py +++ /dev/null @@ -1,575 +0,0 @@ -import asyncio -import logging -import re -from datetime import datetime, timedelta -from typing import List, Mapping, Optional, Tuple, Union - -import dateutil -import discord.errors -from dateutil.relativedelta import relativedelta -from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel -from discord.ext.commands import Cog -from discord.utils import escape_markdown - -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import ( - Channels, Colours, - Filter, Icons, URLs -) -from bot.utils.redis_cache import RedisCache -from bot.utils.regex import INVITE_RE -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - -# Regular expressions -SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) -URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) -ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") - -# Other constants. -DAYS_BETWEEN_ALERTS = 3 -OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) - - -class Filtering(Cog): - """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" - - # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent - name_alerts = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - self.name_lock = asyncio.Lock() - - staff_mistake_str = "If you believe this was a mistake, please let staff know!" - self.filters = { - "filter_zalgo": { - "enabled": Filter.filter_zalgo, - "function": self._has_zalgo, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_zalgo, - "notification_msg": ( - "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " - f"{staff_mistake_str}" - ), - "schedule_deletion": False - }, - "filter_invites": { - "enabled": Filter.filter_invites, - "function": self._has_invites, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_invites, - "notification_msg": ( - f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" - r"Our server rules can be found here: " - ), - "schedule_deletion": False - }, - "filter_domains": { - "enabled": Filter.filter_domains, - "function": self._has_urls, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_domains, - "notification_msg": ( - f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" - ), - "schedule_deletion": False - }, - "watch_regex": { - "enabled": Filter.watch_regex, - "function": self._has_watch_regex_match, - "type": "watchlist", - "content_only": True, - "schedule_deletion": True - }, - "watch_rich_embeds": { - "enabled": Filter.watch_rich_embeds, - "function": self._has_rich_embed, - "type": "watchlist", - "content_only": False, - "schedule_deletion": False - } - } - - self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: - """Fetch items from the filter_list_cache.""" - return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() - - @staticmethod - def _expand_spoilers(text: str) -> str: - """Return a string containing all interpretations of a spoilered message.""" - split_text = SPOILER_RE.split(text) - return ''.join( - split_text[0::2] + split_text[1::2] + split_text - ) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Invoke message filter for new messages.""" - await self._filter_message(msg) - - # Ignore webhook messages. - if msg.webhook_id is None: - await self.check_bad_words_in_name(msg.author) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Invoke message filter for message edits. - - If there have been multiple edits, calculate the time delta from the previous edit. - """ - if not before.edited_at: - delta = relativedelta(after.edited_at, before.created_at).microseconds - else: - delta = relativedelta(after.edited_at, before.edited_at).microseconds - await self._filter_message(after, delta) - - def get_name_matches(self, name: str) -> List[re.Match]: - """Check bad words from passed string (name). Return list of matches.""" - matches = [] - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - if match := re.search(pattern, name, flags=re.IGNORECASE): - matches.append(match) - return matches - - async def check_send_alert(self, member: Member) -> bool: - """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" - if last_alert := await self.name_alerts.get(member.id): - last_alert = datetime.utcfromtimestamp(last_alert) - if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: - log.trace(f"Last alert was too recent for {member}'s nickname.") - return False - - return True - - async def check_bad_words_in_name(self, member: Member) -> None: - """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" - # Use lock to avoid race conditions - async with self.name_lock: - # Check whether the users display name contains any words in our blacklist - matches = self.get_name_matches(member.display_name) - - if not matches or not await self.check_send_alert(member): - return - - log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") - - log_string = ( - f"**User:** {member.mention} (`{member.id}`)\n" - f"**Display Name:** {member.display_name}\n" - f"**Bad Matches:** {', '.join(match.group() for match in matches)}" - ) - - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colours.soft_red, - title="Username filtering alert", - text=log_string, - channel_id=Channels.mod_alerts, - thumbnail=member.avatar_url - ) - - # Update time when alert sent - await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) - - async def filter_eval(self, result: str, msg: Message) -> bool: - """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. - - Also requires the original message, to check whether to filter and for mod logs. - Returns whether a filter was triggered or not. - """ - filter_triggered = False - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - # We also do not need to worry about filters that take the full message, - # since all we have is an arbitrary string. - if _filter["enabled"] and _filter["content_only"]: - match = await _filter["function"](result) - - if match: - # If this is a filter (not a watchlist), we set the variable so we know - # that it has been triggered - if _filter["type"] == "filter": - filter_triggered = True - - # We do not have to check against DM channels since !eval cannot be used there. - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, result - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} using !eval with " - f"[the following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - return filter_triggered - - async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: - """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - if _filter["enabled"]: - # Double trigger check for the embeds filter - if filter_name == "watch_rich_embeds": - # If the edit delta is less than 0.001 seconds, then we're probably dealing - # with a double filter trigger. - if delta is not None and delta < 100: - continue - - # Does the filter only need the message content or the full message? - if _filter["content_only"]: - match = await _filter["function"](msg.content) - else: - match = await _filter["function"](msg) - - if match: - is_private = msg.channel.type is discord.ChannelType.private - - # If this is a filter (not a watchlist) and not in a DM, delete the message. - if _filter["type"] == "filter" and not is_private: - try: - # Embeds (can?) trigger both the `on_message` and `on_message_edit` - # event handlers, triggering filtering twice for the same message. - # - # If `on_message`-triggered filtering already deleted the message - # then `on_message_edit`-triggered filtering will raise exception - # since the message no longer exists. - # - # In addition, to avoid sending two notifications to the user, the - # logs, and mod_alert, we return if the message no longer exists. - await msg.delete() - except discord.errors.NotFound: - return - - # Notify the user if the filter specifies - if _filter["user_notification"]: - await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) - - # If the message is classed as offensive, we store it in the site db and - # it will be deleted it after one week. - if _filter["schedule_deletion"] and not is_private: - delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() - data = { - 'id': msg.id, - 'channel_id': msg.channel.id, - 'delete_date': delete_date - } - - await self.bot.api_client.post('bot/offensive-messages', json=data) - self.schedule_msg_delete(data) - log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") - - if is_private: - channel_str = "via DM" - else: - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, msg.content - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} with [the " - f"following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone if not is_private else False, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ - str, Optional[List[discord.Embed]], Optional[str] - ]: - """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" - # Word and match stats for watch_regex - if name == "watch_regex": - surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] - message_content = ( - f"**Match:** '{match[0]}'\n" - f"**Location:** '...{escape_markdown(surroundings)}...'\n" - f"\n**Original Message:**\n{escape_markdown(content)}" - ) - else: # Use original content - message_content = content - - additional_embeds = None - additional_embeds_msg = None - - self.bot.stats.incr(f"filters.{name}") - - # The function returns True for invalid invites. - # They have no data so additional embeds can't be created for them. - if name == "filter_invites" and match is not True: - additional_embeds = [] - for _, data in match.items(): - embed = discord.Embed(description=( - f"**Members:**\n{data['members']}\n" - f"**Active:**\n{data['active']}" - )) - embed.set_author(name=data["name"]) - embed.set_thumbnail(url=data["icon"]) - embed.set_footer(text=f"Guild ID: {data['id']}") - additional_embeds.append(embed) - additional_embeds_msg = "For the following guild(s):" - - elif name == "watch_rich_embeds": - additional_embeds = match - additional_embeds_msg = "With the following embed(s):" - - return message_content, additional_embeds, additional_embeds_msg - - @staticmethod - def _check_filter(msg: Message) -> bool: - """Check whitelists to see if we should filter this message.""" - role_whitelisted = False - - if type(msg.author) is Member: # Only Member has roles, not User. - for role in msg.author.roles: - if role.id in Filter.role_whitelist: - role_whitelisted = True - - return ( - msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist - and not role_whitelisted # Role not in whitelist - and not msg.author.bot # Author not a bot - ) - - async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: - """ - Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. - - `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is - matched as-is. Spoilers are expanded, if any, and URLs are ignored. - """ - if SPOILER_RE.search(text): - text = self._expand_spoilers(text) - - # Make sure it's not a URL - if URL_RE.search(text): - return False - - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - match = re.search(pattern, text, flags=re.IGNORECASE) - if match: - return match - - async def _has_urls(self, text: str) -> bool: - """Returns True if the text contains one of the blacklisted URLs from the config file.""" - if not URL_RE.search(text): - return False - - text = text.lower() - domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) - - for url in domain_blacklist: - if url.lower() in text: - return True - - return False - - @staticmethod - async def _has_zalgo(text: str) -> bool: - """ - Returns True if the text contains zalgo characters. - - Zalgo range is \u0300 – \u036F and \u0489. - """ - return bool(ZALGO_RE.search(text)) - - async def _has_invites(self, text: str) -> Union[dict, bool]: - """ - Checks if there's any invites in the text content that aren't in the guild whitelist. - - If any are detected, a dictionary of invite data is returned, with a key per invite. - If none are detected, False is returned. - - Attempts to catch some of common ways to try to cheat the system. - """ - # Remove backslashes to prevent escape character aroundfuckery like - # discord\.gg/gdudes-pony-farm - text = text.replace("\\", "") - - invites = INVITE_RE.findall(text) - invite_data = dict() - for invite in invites: - if invite in invite_data: - continue - - response = await self.bot.http_session.get( - f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} - ) - response = await response.json() - guild = response.get("guild") - if guild is None: - # Lack of a "guild" key in the JSON response indicates either an group DM invite, an - # expired invite, or an invalid invite. The API does not currently differentiate - # between invalid and expired invites - return True - - guild_id = guild.get("id") - guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) - guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) - - # Is this invite allowed? - guild_partnered_or_verified = ( - 'PARTNERED' in guild.get("features", []) - or 'VERIFIED' in guild.get("features", []) - ) - invite_not_allowed = ( - guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. - or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. - and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. - ) - - if invite_not_allowed: - guild_icon_hash = guild["icon"] - guild_icon = ( - "https://cdn.discordapp.com/icons/" - f"{guild_id}/{guild_icon_hash}.png?size=512" - ) - - invite_data[invite] = { - "name": guild["name"], - "id": guild['id'], - "icon": guild_icon, - "members": response["approximate_member_count"], - "active": response["approximate_presence_count"] - } - - return invite_data if invite_data else False - - @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: - """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" - if msg.embeds: - for embed in msg.embeds: - if embed.type == "rich": - urls = URL_RE.findall(msg.content) - if not embed.url or embed.url not in urls: - # If `embed.url` does not exist or if `embed.url` is not part of the content - # of the message, it's unlikely to be an auto-generated embed by Discord. - return msg.embeds - else: - log.trace( - "Found a rich embed sent by a regular user account, " - "but it was likely just an automatic URL embed." - ) - return False - return False - - async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: - """ - Notify filtered_member about a moderation action with the reason str. - - First attempts to DM the user, fall back to in-channel notification if user has DMs disabled - """ - try: - await filtered_member.send(reason) - except discord.errors.Forbidden: - await channel.send(f"{filtered_member.mention} {reason}") - - def schedule_msg_delete(self, msg: dict) -> None: - """Delete an offensive message once its deletion date is reached.""" - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) - - async def reschedule_offensive_msg_deletion(self) -> None: - """Get all the pending message deletion from the API and reschedule them.""" - await self.bot.wait_until_ready() - response = await self.bot.api_client.get('bot/offensive-messages',) - - now = datetime.utcnow() - - for msg in response: - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - - if delete_at < now: - await self.delete_offensive_msg(msg) - else: - self.schedule_msg_delete(msg) - - async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: - """Delete an offensive message, and then delete it from the db.""" - try: - channel = self.bot.get_channel(msg['channel_id']) - if channel: - msg_obj = await channel.fetch_message(msg['id']) - await msg_obj.delete() - except NotFound: - log.info( - f"Tried to delete message {msg['id']}, but the message can't be found " - f"(it has been probably already deleted)." - ) - except HTTPException as e: - log.warning(f"Failed to delete message {msg['id']}: status {e.status}") - - await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') - log.info(f"Deleted the offensive message with id {msg['id']}.") - - -def setup(bot: Bot) -> None: - """Load the Filtering cog.""" - bot.add_cog(Filtering(bot)) diff --git a/bot/cogs/filters/__init__.py b/bot/cogs/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/filters/antimalware.py b/bot/cogs/filters/antimalware.py new file mode 100644 index 000000000..c76bd2c60 --- /dev/null +++ b/bot/cogs/filters/antimalware.py @@ -0,0 +1,98 @@ +import logging +import typing as t +from os.path import splitext + +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, STAFF_ROLES, URLs + +log = logging.getLogger(__name__) + +PY_EMBED_DESCRIPTION = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + "{cmd_channel_mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( + "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + "We currently allow the following file types: **{joined_whitelist}**.\n\n" + "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + + +class AntiMalware(Cog): + """Delete messages which contain attachments with non-whitelisted file extensions.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_whitelisted_file_formats(self) -> list: + """Get the file formats currently on the whitelist.""" + return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() + + def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) + return extensions_blocked + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Identify messages with prohibited attachments.""" + # Return when message don't have attachment and don't moderate DMs + if not message.attachments or not message.guild: + return + + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): + return + + embed = Embed() + extensions_blocked = self._get_disallowed_extensions(message) + blocked_extensions_str = ', '.join(extensions_blocked) + if ".py" in extensions_blocked: + # Short-circuit on *.py files to provide a pastebin link + embed.description = PY_EMBED_DESCRIPTION + elif ".txt" in extensions_blocked: + # Work around Discord AutoConversion of messages longer than 2000 chars to .txt + cmd_channel = self.bot.get_channel(Channels.bot_commands) + embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) + elif extensions_blocked: + meta_channel = self.bot.get_channel(Channels.meta) + embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=', '.join(self._get_whitelisted_file_formats()), + blocked_extensions_str=blocked_extensions_str, + meta_channel_mention=meta_channel.mention, + ) + + if embed.description: + log.info( + f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", + extra={"attachment_list": [attachment.filename for attachment in message.attachments]} + ) + + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) + + # Delete the offending message: + try: + await message.delete() + except NotFound: + log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + + +def setup(bot: Bot) -> None: + """Load the AntiMalware cog.""" + bot.add_cog(AntiMalware(bot)) diff --git a/bot/cogs/filters/antispam.py b/bot/cogs/filters/antispam.py new file mode 100644 index 000000000..0bcca578d --- /dev/null +++ b/bot/cogs/filters/antispam.py @@ -0,0 +1,288 @@ +import asyncio +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from operator import itemgetter +from typing import Dict, Iterable, List, Set + +from discord import Colour, Member, Message, NotFound, Object, TextChannel +from discord.ext.commands import Cog + +from bot import rules +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import ( + AntiSpam as AntiSpamConfig, Channels, + Colours, DEBUG_MODE, Event, Filter, + Guild as GuildConfig, Icons, + STAFF_ROLES, +) +from bot.converters import Duration +from bot.utils.messages import send_attachments + + +log = logging.getLogger(__name__) + +RULE_FUNCTION_MAPPING = { + 'attachments': rules.apply_attachments, + 'burst': rules.apply_burst, + 'burst_shared': rules.apply_burst_shared, + 'chars': rules.apply_chars, + 'discord_emojis': rules.apply_discord_emojis, + 'duplicates': rules.apply_duplicates, + 'links': rules.apply_links, + 'mentions': rules.apply_mentions, + 'newlines': rules.apply_newlines, + 'role_mentions': rules.apply_role_mentions +} + + +@dataclass +class DeletionContext: + """Represents a Deletion Context for a single spam event.""" + + channel: TextChannel + members: Dict[int, Member] = field(default_factory=dict) + rules: Set[str] = field(default_factory=set) + messages: Dict[int, Message] = field(default_factory=dict) + attachments: List[List[str]] = field(default_factory=list) + + async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: + """Adds new rule violation events to the deletion context.""" + self.rules.add(rule_name) + + for member in members: + if member.id not in self.members: + self.members[member.id] = member + + for message in messages: + if message.id not in self.messages: + self.messages[message.id] = message + + # Re-upload attachments + destination = message.guild.get_channel(Channels.attachment_log) + urls = await send_attachments(message, destination, link_large=False) + self.attachments.append(urls) + + async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: + """Method that takes care of uploading the queue and posting modlog alert.""" + triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) + + mod_alert_message = ( + f"**Triggered by:** {triggered_by_users}\n" + f"**Channel:** {self.channel.mention}\n" + f"**Rules:** {', '.join(rule for rule in self.rules)}\n" + ) + + # For multiple messages or those with excessive newlines, use the logs API + if len(self.messages) > 1 or 'newlines' in self.rules: + url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) + mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" + else: + mod_alert_message += "Message:\n" + [message] = self.messages.values() + content = message.clean_content + remaining_chars = 2040 - len(mod_alert_message) + + if len(content) > remaining_chars: + content = content[:remaining_chars] + "..." + + mod_alert_message += f"{content}" + + *_, last_message = self.messages.values() + await modlog.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title="Spam detected!", + text=mod_alert_message, + thumbnail=last_message.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=AntiSpamConfig.ping_everyone + ) + + +class AntiSpam(Cog): + """Cog that controls our anti-spam measures.""" + + def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: + self.bot = bot + self.validation_errors = validation_errors + role_id = AntiSpamConfig.punishment['role_id'] + self.muted_role = Object(role_id) + self.expiration_date_converter = Duration() + + self.message_deletion_queue = dict() + + self.bot.loop.create_task(self.alert_on_validation_error()) + + @property + def mod_log(self) -> ModLog: + """Allows for easy access of the ModLog cog.""" + return self.bot.get_cog("ModLog") + + async def alert_on_validation_error(self) -> None: + """Unloads the cog and alerts admins if configuration validation failed.""" + await self.bot.wait_until_guild_available() + if self.validation_errors: + body = "**The following errors were encountered:**\n" + body += "\n".join(f"- {error}" for error in self.validation_errors.values()) + body += "\n\n**The cog has been unloaded.**" + + await self.mod_log.send_log_message( + title="Error: AntiSpam configuration validation failed!", + text=body, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Colour.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Applies the antispam rules to each received message.""" + if ( + not message.guild + or message.guild.id != GuildConfig.id + or message.author.bot + or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) + or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) + ): + return + + # Fetch the rule configuration with the highest rule interval. + max_interval_config = max( + AntiSpamConfig.rules.values(), + key=itemgetter('interval') + ) + max_interval = max_interval_config['interval'] + + # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. + earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) + relevant_messages = [ + msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) + if not msg.author.bot + ] + + for rule_name in AntiSpamConfig.rules: + rule_config = AntiSpamConfig.rules[rule_name] + rule_function = RULE_FUNCTION_MAPPING[rule_name] + + # Create a list of messages that were sent in the interval that the rule cares about. + latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) + messages_for_rule = [ + msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp + ] + result = await rule_function(message, messages_for_rule, rule_config) + + # If the rule returns `None`, that means the message didn't violate it. + # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` + # which contains the reason for why the message violated the rule and + # an iterable of all members that violated the rule. + if result is not None: + self.bot.stats.incr(f"mod_alerts.{rule_name}") + reason, members, relevant_messages = result + full_reason = f"`{rule_name}` rule: {reason}" + + # If there's no spam event going on for this channel, start a new Message Deletion Context + channel = message.channel + if channel.id not in self.message_deletion_queue: + log.trace(f"Creating queue for channel `{channel.id}`") + self.message_deletion_queue[message.channel.id] = DeletionContext(channel) + self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) + + # Add the relevant of this trigger to the Deletion Context + await self.message_deletion_queue[message.channel.id].add( + rule_name=rule_name, + members=members, + messages=relevant_messages + ) + + for member in members: + + # Fire it off as a background task to ensure + # that the sleep doesn't block further tasks + self.bot.loop.create_task( + self.punish(message, member, full_reason) + ) + + await self.maybe_delete_messages(channel, relevant_messages) + break + + async def punish(self, msg: Message, member: Member, reason: str) -> None: + """Punishes the given member for triggering an antispam rule.""" + if not any(role.id == self.muted_role.id for role in member.roles): + remove_role_after = AntiSpamConfig.punishment['remove_after'] + + # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes + context = await self.bot.get_context(msg) + context.author = self.bot.user + context.message.author = self.bot.user + + # Since we're going to invoke the tempmute command directly, we need to manually call the converter. + dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") + await context.invoke( + self.bot.get_command('tempmute'), + member, + dt_remove_role_after, + reason=reason + ) + + async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: + """Cleans the messages if cleaning is configured.""" + if AntiSpamConfig.clean_offending: + # If we have more than one message, we can use bulk delete. + if len(messages) > 1: + message_ids = [message.id for message in messages] + self.mod_log.ignore(Event.message_delete, *message_ids) + await channel.delete_messages(messages) + + # Otherwise, the bulk delete endpoint will throw up. + # Delete the message directly instead. + else: + self.mod_log.ignore(Event.message_delete, messages[0].id) + try: + await messages[0].delete() + except NotFound: + log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") + + async def _process_deletion_context(self, context_id: int) -> None: + """Processes the Deletion Context queue.""" + log.trace("Sleeping before processing message deletion queue.") + await asyncio.sleep(10) + + if context_id not in self.message_deletion_queue: + log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") + return + + deletion_context = self.message_deletion_queue.pop(context_id) + await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + + +def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: + """Validates the antispam configs.""" + validation_errors = {} + for name, config in rules_.items(): + if name not in RULE_FUNCTION_MAPPING: + log.error( + f"Unrecognized antispam rule `{name}`. " + f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" + ) + validation_errors[name] = f"`{name}` is not recognized as an antispam rule." + continue + for required_key in ('interval', 'max'): + if required_key not in config: + log.error( + f"`{required_key}` is required but was not " + f"set in rule `{name}`'s configuration." + ) + validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" + return validation_errors + + +def setup(bot: Bot) -> None: + """Validate the AntiSpam configs and load the AntiSpam cog.""" + validation_errors = validate_config() + bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/cogs/filters/filter_lists.py b/bot/cogs/filters/filter_lists.py new file mode 100644 index 000000000..c15adc461 --- /dev/null +++ b/bot/cogs/filters/filter_lists.py @@ -0,0 +1,273 @@ +import logging +from typing import Optional + +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.converters import ValidDiscordServerInvite, ValidFilterListType +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +class FilterLists(Cog): + """Commands for blacklisting and whitelisting things.""" + + methods_with_filterlist_types = [ + "allow_add", + "allow_delete", + "allow_get", + "deny_add", + "deny_delete", + "deny_get", + ] + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.bot.loop.create_task(self._amend_docstrings()) + + async def _amend_docstrings(self) -> None: + """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" + await self.bot.wait_until_guild_available() + + # Add valid filterlist types to the docstrings + valid_types = await ValidFilterListType.get_valid_types(self.bot) + valid_types = [f"`{type_.lower()}`" for type_ in valid_types] + + for method_name in self.methods_with_filterlist_types: + command = getattr(self, method_name) + command.help = ( + f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." + ) + + async def _add_data( + self, + ctx: Context, + allowed: bool, + list_type: ValidFilterListType, + content: str, + comment: Optional[str] = None, + ) -> None: + """Add an item to a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we gotta validate it. + if list_type == "GUILD_INVITE": + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # Unless the user has specified another comment, let's + # use the server name as the comment so that the list + # of guild IDs will be more easily readable when we + # display it. + if not comment: + comment = guild_data.get("name") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Try to add the item to the database + log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") + payload = { + "allowed": allowed, + "type": list_type, + "content": content, + "comment": comment, + } + + try: + item = await self.bot.api_client.post( + "bot/filter-lists", + json=payload + ) + except ResponseCodeError as e: + if e.status == 400: + await ctx.message.add_reaction("❌") + log.debug( + f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " + "probably because the request violated the UniqueConstraint." + ) + raise BadArgument( + f"Unable to add the item to the {allow_type}. " + "The item probably already exists. Keep in mind that a " + "blacklist and a whitelist for the same item cannot co-exist, " + "and we do not permit any duplicates." + ) + raise + + # Insert the item into the cache + self.bot.insert_item_into_filter_list_cache(item) + await ctx.message.add_reaction("✅") + + async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we need to convert it. + if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Find the content and delete it. + log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") + item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) + + if item is not None: + try: + await self.bot.api_client.delete( + f"bot/filter-lists/{item['id']}" + ) + del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to delete an item with the id {item['id']}, but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("❌") + + async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: + """Paginate and display all items in a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] + + # Build a list of lines we want to show in the paginator + lines = [] + for content, metadata in result.items(): + line = f"• `{content}`" + + if comment := metadata.get("comment"): + line += f" - {comment}" + + lines.append(line) + lines = sorted(lines) + + # Build the embed + list_type_plural = list_type.lower().replace("_", " ").title() + "s" + embed = Embed( + title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", + colour=Colour.blue() + ) + log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") + + if result: + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + await ctx.message.add_reaction("❌") + + async def _sync_data(self, ctx: Context) -> None: + """Syncs the filterlists with the API.""" + try: + log.trace("Attempting to sync FilterList cache with data from the API.") + await self.bot.cache_filter_list_data() + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to sync FilterList cache data but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + + @staticmethod + async def _validate_guild_invite(ctx: Context, invite: str) -> dict: + """ + Validates a guild invite, and returns the guild info as a dict. + + Will raise a BadArgument if the guild invite is invalid. + """ + log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, invite) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's return a dict of guild information. + log.trace(f"{invite} validated as server invite. Converting to ID.") + return guild_data + + @group(aliases=("allowlist", "allow", "al", "wl")) + async def whitelist(self, ctx: Context) -> None: + """Group for whitelisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @group(aliases=("denylist", "deny", "bl", "dl")) + async def blacklist(self, ctx: Context) -> None: + """Group for blacklisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @whitelist.command(name="add", aliases=("a", "set")) + async def allow_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified allowlist.""" + await self._add_data(ctx, True, list_type, content, comment) + + @blacklist.command(name="add", aliases=("a", "set")) + async def deny_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified denylist.""" + await self._add_data(ctx, False, list_type, content, comment) + + @whitelist.command(name="remove", aliases=("delete", "rm",)) + async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified allowlist.""" + await self._delete_data(ctx, True, list_type, content) + + @blacklist.command(name="remove", aliases=("delete", "rm",)) + async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified denylist.""" + await self._delete_data(ctx, False, list_type, content) + + @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified allowlist.""" + await self._list_all_data(ctx, True, list_type) + + @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified denylist.""" + await self._list_all_data(ctx, False, list_type) + + @whitelist.command(name="sync", aliases=("s",)) + async def allow_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + @blacklist.command(name="sync", aliases=("s",)) + async def deny_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the FilterLists cog.""" + bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filters/filtering.py b/bot/cogs/filters/filtering.py new file mode 100644 index 000000000..93cc1c655 --- /dev/null +++ b/bot/cogs/filters/filtering.py @@ -0,0 +1,575 @@ +import asyncio +import logging +import re +from datetime import datetime, timedelta +from typing import List, Mapping, Optional, Tuple, Union + +import dateutil +import discord.errors +from dateutil.relativedelta import relativedelta +from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel +from discord.ext.commands import Cog +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import ( + Channels, Colours, + Filter, Icons, URLs +) +from bot.utils.redis_cache import RedisCache +from bot.utils.regex import INVITE_RE +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + +# Regular expressions +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) +URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) +ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") + +# Other constants. +DAYS_BETWEEN_ALERTS = 3 +OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) + + +class Filtering(Cog): + """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" + + # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent + name_alerts = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.name_lock = asyncio.Lock() + + staff_mistake_str = "If you believe this was a mistake, please let staff know!" + self.filters = { + "filter_zalgo": { + "enabled": Filter.filter_zalgo, + "function": self._has_zalgo, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_zalgo, + "notification_msg": ( + "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " + f"{staff_mistake_str}" + ), + "schedule_deletion": False + }, + "filter_invites": { + "enabled": Filter.filter_invites, + "function": self._has_invites, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_invites, + "notification_msg": ( + f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" + r"Our server rules can be found here: " + ), + "schedule_deletion": False + }, + "filter_domains": { + "enabled": Filter.filter_domains, + "function": self._has_urls, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_domains, + "notification_msg": ( + f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" + ), + "schedule_deletion": False + }, + "watch_regex": { + "enabled": Filter.watch_regex, + "function": self._has_watch_regex_match, + "type": "watchlist", + "content_only": True, + "schedule_deletion": True + }, + "watch_rich_embeds": { + "enabled": Filter.watch_rich_embeds, + "function": self._has_rich_embed, + "type": "watchlist", + "content_only": False, + "schedule_deletion": False + } + } + + self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: + """Fetch items from the filter_list_cache.""" + return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() + + @staticmethod + def _expand_spoilers(text: str) -> str: + """Return a string containing all interpretations of a spoilered message.""" + split_text = SPOILER_RE.split(text) + return ''.join( + split_text[0::2] + split_text[1::2] + split_text + ) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Invoke message filter for new messages.""" + await self._filter_message(msg) + + # Ignore webhook messages. + if msg.webhook_id is None: + await self.check_bad_words_in_name(msg.author) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Invoke message filter for message edits. + + If there have been multiple edits, calculate the time delta from the previous edit. + """ + if not before.edited_at: + delta = relativedelta(after.edited_at, before.created_at).microseconds + else: + delta = relativedelta(after.edited_at, before.edited_at).microseconds + await self._filter_message(after, delta) + + def get_name_matches(self, name: str) -> List[re.Match]: + """Check bad words from passed string (name). Return list of matches.""" + matches = [] + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + if match := re.search(pattern, name, flags=re.IGNORECASE): + matches.append(match) + return matches + + async def check_send_alert(self, member: Member) -> bool: + """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" + if last_alert := await self.name_alerts.get(member.id): + last_alert = datetime.utcfromtimestamp(last_alert) + if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: + log.trace(f"Last alert was too recent for {member}'s nickname.") + return False + + return True + + async def check_bad_words_in_name(self, member: Member) -> None: + """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" + # Use lock to avoid race conditions + async with self.name_lock: + # Check whether the users display name contains any words in our blacklist + matches = self.get_name_matches(member.display_name) + + if not matches or not await self.check_send_alert(member): + return + + log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") + + log_string = ( + f"**User:** {member.mention} (`{member.id}`)\n" + f"**Display Name:** {member.display_name}\n" + f"**Bad Matches:** {', '.join(match.group() for match in matches)}" + ) + + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colours.soft_red, + title="Username filtering alert", + text=log_string, + channel_id=Channels.mod_alerts, + thumbnail=member.avatar_url + ) + + # Update time when alert sent + await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) + + async def filter_eval(self, result: str, msg: Message) -> bool: + """ + Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + + Also requires the original message, to check whether to filter and for mod logs. + Returns whether a filter was triggered or not. + """ + filter_triggered = False + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + # We also do not need to worry about filters that take the full message, + # since all we have is an arbitrary string. + if _filter["enabled"] and _filter["content_only"]: + match = await _filter["function"](result) + + if match: + # If this is a filter (not a watchlist), we set the variable so we know + # that it has been triggered + if _filter["type"] == "filter": + filter_triggered = True + + # We do not have to check against DM channels since !eval cannot be used there. + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, result + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} using !eval with " + f"[the following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + return filter_triggered + + async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: + """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + if _filter["enabled"]: + # Double trigger check for the embeds filter + if filter_name == "watch_rich_embeds": + # If the edit delta is less than 0.001 seconds, then we're probably dealing + # with a double filter trigger. + if delta is not None and delta < 100: + continue + + # Does the filter only need the message content or the full message? + if _filter["content_only"]: + match = await _filter["function"](msg.content) + else: + match = await _filter["function"](msg) + + if match: + is_private = msg.channel.type is discord.ChannelType.private + + # If this is a filter (not a watchlist) and not in a DM, delete the message. + if _filter["type"] == "filter" and not is_private: + try: + # Embeds (can?) trigger both the `on_message` and `on_message_edit` + # event handlers, triggering filtering twice for the same message. + # + # If `on_message`-triggered filtering already deleted the message + # then `on_message_edit`-triggered filtering will raise exception + # since the message no longer exists. + # + # In addition, to avoid sending two notifications to the user, the + # logs, and mod_alert, we return if the message no longer exists. + await msg.delete() + except discord.errors.NotFound: + return + + # Notify the user if the filter specifies + if _filter["user_notification"]: + await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) + + # If the message is classed as offensive, we store it in the site db and + # it will be deleted it after one week. + if _filter["schedule_deletion"] and not is_private: + delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() + data = { + 'id': msg.id, + 'channel_id': msg.channel.id, + 'delete_date': delete_date + } + + await self.bot.api_client.post('bot/offensive-messages', json=data) + self.schedule_msg_delete(data) + log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") + + if is_private: + channel_str = "via DM" + else: + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, msg.content + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} with [the " + f"following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone if not is_private else False, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ + str, Optional[List[discord.Embed]], Optional[str] + ]: + """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" + # Word and match stats for watch_regex + if name == "watch_regex": + surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] + message_content = ( + f"**Match:** '{match[0]}'\n" + f"**Location:** '...{escape_markdown(surroundings)}...'\n" + f"\n**Original Message:**\n{escape_markdown(content)}" + ) + else: # Use original content + message_content = content + + additional_embeds = None + additional_embeds_msg = None + + self.bot.stats.incr(f"filters.{name}") + + # The function returns True for invalid invites. + # They have no data so additional embeds can't be created for them. + if name == "filter_invites" and match is not True: + additional_embeds = [] + for _, data in match.items(): + embed = discord.Embed(description=( + f"**Members:**\n{data['members']}\n" + f"**Active:**\n{data['active']}" + )) + embed.set_author(name=data["name"]) + embed.set_thumbnail(url=data["icon"]) + embed.set_footer(text=f"Guild ID: {data['id']}") + additional_embeds.append(embed) + additional_embeds_msg = "For the following guild(s):" + + elif name == "watch_rich_embeds": + additional_embeds = match + additional_embeds_msg = "With the following embed(s):" + + return message_content, additional_embeds, additional_embeds_msg + + @staticmethod + def _check_filter(msg: Message) -> bool: + """Check whitelists to see if we should filter this message.""" + role_whitelisted = False + + if type(msg.author) is Member: # Only Member has roles, not User. + for role in msg.author.roles: + if role.id in Filter.role_whitelist: + role_whitelisted = True + + return ( + msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist + and not role_whitelisted # Role not in whitelist + and not msg.author.bot # Author not a bot + ) + + async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: + """ + Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. + + `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is + matched as-is. Spoilers are expanded, if any, and URLs are ignored. + """ + if SPOILER_RE.search(text): + text = self._expand_spoilers(text) + + # Make sure it's not a URL + if URL_RE.search(text): + return False + + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + match = re.search(pattern, text, flags=re.IGNORECASE) + if match: + return match + + async def _has_urls(self, text: str) -> bool: + """Returns True if the text contains one of the blacklisted URLs from the config file.""" + if not URL_RE.search(text): + return False + + text = text.lower() + domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) + + for url in domain_blacklist: + if url.lower() in text: + return True + + return False + + @staticmethod + async def _has_zalgo(text: str) -> bool: + """ + Returns True if the text contains zalgo characters. + + Zalgo range is \u0300 – \u036F and \u0489. + """ + return bool(ZALGO_RE.search(text)) + + async def _has_invites(self, text: str) -> Union[dict, bool]: + """ + Checks if there's any invites in the text content that aren't in the guild whitelist. + + If any are detected, a dictionary of invite data is returned, with a key per invite. + If none are detected, False is returned. + + Attempts to catch some of common ways to try to cheat the system. + """ + # Remove backslashes to prevent escape character aroundfuckery like + # discord\.gg/gdudes-pony-farm + text = text.replace("\\", "") + + invites = INVITE_RE.findall(text) + invite_data = dict() + for invite in invites: + if invite in invite_data: + continue + + response = await self.bot.http_session.get( + f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} + ) + response = await response.json() + guild = response.get("guild") + if guild is None: + # Lack of a "guild" key in the JSON response indicates either an group DM invite, an + # expired invite, or an invalid invite. The API does not currently differentiate + # between invalid and expired invites + return True + + guild_id = guild.get("id") + guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) + guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) + + # Is this invite allowed? + guild_partnered_or_verified = ( + 'PARTNERED' in guild.get("features", []) + or 'VERIFIED' in guild.get("features", []) + ) + invite_not_allowed = ( + guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. + or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. + and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. + ) + + if invite_not_allowed: + guild_icon_hash = guild["icon"] + guild_icon = ( + "https://cdn.discordapp.com/icons/" + f"{guild_id}/{guild_icon_hash}.png?size=512" + ) + + invite_data[invite] = { + "name": guild["name"], + "id": guild['id'], + "icon": guild_icon, + "members": response["approximate_member_count"], + "active": response["approximate_presence_count"] + } + + return invite_data if invite_data else False + + @staticmethod + async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: + """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" + if msg.embeds: + for embed in msg.embeds: + if embed.type == "rich": + urls = URL_RE.findall(msg.content) + if not embed.url or embed.url not in urls: + # If `embed.url` does not exist or if `embed.url` is not part of the content + # of the message, it's unlikely to be an auto-generated embed by Discord. + return msg.embeds + else: + log.trace( + "Found a rich embed sent by a regular user account, " + "but it was likely just an automatic URL embed." + ) + return False + return False + + async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: + """ + Notify filtered_member about a moderation action with the reason str. + + First attempts to DM the user, fall back to in-channel notification if user has DMs disabled + """ + try: + await filtered_member.send(reason) + except discord.errors.Forbidden: + await channel.send(f"{filtered_member.mention} {reason}") + + def schedule_msg_delete(self, msg: dict) -> None: + """Delete an offensive message once its deletion date is reached.""" + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) + + async def reschedule_offensive_msg_deletion(self) -> None: + """Get all the pending message deletion from the API and reschedule them.""" + await self.bot.wait_until_ready() + response = await self.bot.api_client.get('bot/offensive-messages',) + + now = datetime.utcnow() + + for msg in response: + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + + if delete_at < now: + await self.delete_offensive_msg(msg) + else: + self.schedule_msg_delete(msg) + + async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: + """Delete an offensive message, and then delete it from the db.""" + try: + channel = self.bot.get_channel(msg['channel_id']) + if channel: + msg_obj = await channel.fetch_message(msg['id']) + await msg_obj.delete() + except NotFound: + log.info( + f"Tried to delete message {msg['id']}, but the message can't be found " + f"(it has been probably already deleted)." + ) + except HTTPException as e: + log.warning(f"Failed to delete message {msg['id']}: status {e.status}") + + await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') + log.info(f"Deleted the offensive message with id {msg['id']}.") + + +def setup(bot: Bot) -> None: + """Load the Filtering cog.""" + bot.add_cog(Filtering(bot)) diff --git a/bot/cogs/filters/security.py b/bot/cogs/filters/security.py new file mode 100644 index 000000000..c680c5e27 --- /dev/null +++ b/bot/cogs/filters/security.py @@ -0,0 +1,31 @@ +import logging + +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot + +log = logging.getLogger(__name__) + + +class Security(Cog): + """Security-related helpers.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all + self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM + + def check_not_bot(self, ctx: Context) -> bool: + """Check if the context is a bot user.""" + return not ctx.author.bot + + def check_on_guild(self, ctx: Context) -> bool: + """Check if the context is in a guild.""" + if ctx.guild is None: + raise NoPrivateMessage("This command cannot be used in private messages.") + return True + + +def setup(bot: Bot) -> None: + """Load the Security cog.""" + bot.add_cog(Security(bot)) diff --git a/bot/cogs/filters/token_remover.py b/bot/cogs/filters/token_remover.py new file mode 100644 index 000000000..ef979f222 --- /dev/null +++ b/bot/cogs/filters/token_remover.py @@ -0,0 +1,182 @@ +import base64 +import binascii +import logging +import re +import typing as t + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot import utils +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import Channels, Colours, Event, Icons + +log = logging.getLogger(__name__) + +LOG_MESSAGE = ( + "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " + "token was `{user_id}.{timestamp}.{hmac}`" +) +DELETION_MESSAGE_TEMPLATE = ( + "Hey {mention}! I noticed you posted a seemingly valid Discord API " + "token in your message and have removed your message. " + "This means that your token has been **compromised**. " + "Please change your token **immediately** at: " + "\n\n" + "Feel free to re-post it with the token removed. " + "If you believe this was a mistake, please let us know!" +) +DISCORD_EPOCH = 1_420_070_400 +TOKEN_EPOCH = 1_293_840_000 + +# Three parts delimited by dots: user ID, creation timestamp, HMAC. +# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. +# Each part only matches base64 URL-safe characters. +# Padding has never been observed, but the padding character '=' is matched just in case. +TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) + + +class Token(t.NamedTuple): + """A Discord Bot token.""" + + user_id: str + timestamp: str + hmac: str + + +class TokenRemover(Cog): + """Scans messages for potential discord.py bot tokens and removes them.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Check each message for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + found_token = self.find_token_in_message(msg) + if found_token: + await self.take_action(msg, found_token) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + await self.on_message(after) + + async def take_action(self, msg: Message, found_token: Token) -> None: + """Remove the `msg` containing the `found_token` and send a mod log message.""" + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") + return + + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + log_message = self.format_log_message(msg, found_token) + log.debug(log_message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=log_message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + self.bot.stats.incr("tokens.removed_tokens") + + @staticmethod + def format_log_message(msg: Message, token: Token) -> str: + """Return the log message to send for `token` being censored in `msg`.""" + return LOG_MESSAGE.format( + author=msg.author, + author_id=msg.author.id, + channel=msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac='x' * len(token.hmac), + ) + + @classmethod + def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: + """Return a seemingly valid token found in `msg` or `None` if no token is found.""" + # Use finditer rather than search to guard against method calls prematurely returning the + # token check (e.g. `message.channel.send` also matches our token pattern) + for match in TOKEN_RE.finditer(msg.content): + token = Token(*match.groups()) + if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): + # Short-circuit on first match + return token + + # No matching substring + return + + @staticmethod + def is_valid_user_id(b64_content: str) -> bool: + """ + Check potential token to see if it contains a valid Discord user ID. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + string = decoded_bytes.decode('utf-8') + + # isdigit on its own would match a lot of other Unicode characters, hence the isascii. + return string.isascii() and string.isdigit() + except (binascii.Error, ValueError): + return False + + @staticmethod + def is_valid_timestamp(b64_content: str) -> bool: + """ + Return True if `b64_content` decodes to a valid timestamp. + + If the timestamp is greater than the Discord epoch, it's probably valid. + See: https://i.imgur.com/7WdehGn.png + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + timestamp = int.from_bytes(decoded_bytes, byteorder="big") + except (binascii.Error, ValueError) as e: + log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") + return False + + # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound + # is not checked. + if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: + return True + else: + log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") + return False + + +def setup(bot: Bot) -> None: + """Load the TokenRemover cog.""" + bot.add_cog(TokenRemover(bot)) diff --git a/bot/cogs/filters/webhook_remover.py b/bot/cogs/filters/webhook_remover.py new file mode 100644 index 000000000..5812da87c --- /dev/null +++ b/bot/cogs/filters/webhook_remover.py @@ -0,0 +1,84 @@ +import logging +import re + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.constants import Channels, Colours, Event, Icons + +WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) + +ALERT_MESSAGE_TEMPLATE = ( + "{user}, looks like you posted a Discord webhook URL. Therefore, your " + "message has been removed. Your webhook may have been **compromised** so " + "please re-create the webhook **immediately**. If you believe this was " + "mistake, please let us know." +) + +log = logging.getLogger(__name__) + + +class WebhookRemover(Cog): + """Scan messages to detect Discord webhooks links.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get current instance of `ModLog`.""" + return self.bot.get_cog("ModLog") + + async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: + """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" + # Don't log this, due internal delete, not by user. Will make different entry. + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") + return + + await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) + + message = ( + f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " + f"to #{msg.channel}. Webhook URL was `{redacted_url}`" + ) + log.debug(message) + + # Send entry to moderation alerts. + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Discord webhook URL removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts + ) + + self.bot.stats.incr("tokens.removed_webhooks") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Check if a Discord webhook URL is in `message`.""" + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + matches = WEBHOOK_URL_RE.search(msg.content) + if matches: + await self.delete_and_respond(msg, matches[1] + "xxx") + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Check if a Discord webhook URL is in the edited message `after`.""" + await self.on_message(after) + + +def setup(bot: Bot) -> None: + """Load `WebhookRemover` cog.""" + bot.add_cog(WebhookRemover(bot)) diff --git a/bot/cogs/help.py b/bot/cogs/help.py deleted file mode 100644 index 3d1d6fd10..000000000 --- a/bot/cogs/help.py +++ /dev/null @@ -1,375 +0,0 @@ -import itertools -import logging -from asyncio import TimeoutError -from collections import namedtuple -from contextlib import suppress -from typing import List, Union - -from discord import Colour, Embed, Member, Message, NotFound, Reaction, User -from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand -from fuzzywuzzy import fuzz, process -from fuzzywuzzy.utils import full_process - -from bot import constants -from bot.constants import Channels, Emojis, STAFF_ROLES -from bot.decorators import redirect_output -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -COMMANDS_PER_PAGE = 8 -DELETE_EMOJI = Emojis.trashcan -PREFIX = constants.Bot.prefix - -Category = namedtuple("Category", ["name", "description", "cogs"]) - - -async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: - """ - Runs the cleanup for the help command. - - Adds the :trashcan: reaction that, when clicked, will delete the help message. - After a 300 second timeout, the reaction will be removed. - """ - def check(reaction: Reaction, user: User) -> bool: - """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" - return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id - - await message.add_reaction(DELETE_EMOJI) - - with suppress(NotFound): - try: - await bot.wait_for("reaction_add", check=check, timeout=300) - await message.delete() - except TimeoutError: - await message.remove_reaction(DELETE_EMOJI, bot.user) - - -class HelpQueryNotFound(ValueError): - """ - Raised when a HelpSession Query doesn't match a command or cog. - - Contains the custom attribute of ``possible_matches``. - - Instances of this object contain a dictionary of any command(s) that were close to matching the - query, where keys are the possible matched command names and values are the likeness match scores. - """ - - def __init__(self, arg: str, possible_matches: dict = None): - super().__init__(arg) - self.possible_matches = possible_matches - - -class CustomHelpCommand(HelpCommand): - """ - An interactive instance for the bot help command. - - Cogs can be grouped into custom categories. All cogs with the same category will be displayed - under a single category name in the help output. Custom categories are defined inside the cogs - as a class attribute named `category`. A description can also be specified with the attribute - `category_description`. If a description is not found in at least one cog, the default will be - the regular description (class docstring) of the first cog found in the category. - """ - - def __init__(self): - super().__init__(command_attrs={"help": "Shows help for bot commands"}) - - @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) - async def command_callback(self, ctx: Context, *, command: str = None) -> None: - """Attempts to match the provided query with a valid command or cog.""" - # the only reason we need to tamper with this is because d.py does not support "categories", - # so we need to deal with them ourselves. - - bot = ctx.bot - - if command is None: - # quick and easy, send bot help if command is none - mapping = self.get_bot_mapping() - await self.send_bot_help(mapping) - return - - cog_matches = [] - description = None - for cog in bot.cogs.values(): - if hasattr(cog, "category") and cog.category == command: - cog_matches.append(cog) - if hasattr(cog, "category_description"): - description = cog.category_description - - if cog_matches: - category = Category(name=command, description=description, cogs=cog_matches) - await self.send_category_help(category) - return - - # it's either a cog, group, command or subcommand; let the parent class deal with it - await super().command_callback(ctx, command=command) - - async def get_all_help_choices(self) -> set: - """ - Get all the possible options for getting help in the bot. - - This will only display commands the author has permission to run. - - These include: - - Category names - - Cog names - - Group command names (and aliases) - - Command names (and aliases) - - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) - - Options and choices are case sensitive. - """ - # first get all commands including subcommands and full command name aliases - choices = set() - for command in await self.filter_commands(self.context.bot.walk_commands()): - # the the command or group name - choices.add(str(command)) - - if isinstance(command, Command): - # all aliases if it's just a command - choices.update(command.aliases) - else: - # otherwise we need to add the parent name in - choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) - - # all cog names - choices.update(self.context.bot.cogs) - - # all category names - choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) - return choices - - async def command_not_found(self, string: str) -> "HelpQueryNotFound": - """ - Handles when a query does not match a valid command, group, cog or category. - - Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. - """ - choices = await self.get_all_help_choices() - - # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty - # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters - if (processed := full_process(string)): - result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) - else: - result = [] - - return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) - - async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": - """ - Redirects the error to `command_not_found`. - - `command_not_found` deals with searching and getting best choices for both commands and subcommands. - """ - return await self.command_not_found(f"{command.qualified_name} {string}") - - async def send_error_message(self, error: HelpQueryNotFound) -> None: - """Send the error message to the channel.""" - embed = Embed(colour=Colour.red(), title=str(error)) - - if getattr(error, "possible_matches", None): - matches = "\n".join(f"`{match}`" for match in error.possible_matches) - embed.description = f"**Did you mean:**\n{matches}" - - await self.context.send(embed=embed) - - async def command_formatting(self, command: Command) -> Embed: - """ - Takes a command and turns it into an embed. - - It will add an author, command signature + help, aliases and a note if the user can't run the command. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - parent = command.full_parent_name - - name = str(command) if not parent else f"{parent} {command.name}" - command_details = f"**```{PREFIX}{name} {command.signature}```**\n" - - # show command aliases - aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) - if aliases: - command_details += f"**Can also use:** {aliases}\n\n" - - # check if the user is allowed to run this command - if not await command.can_run(self.context): - command_details += "***You cannot run this command.***\n\n" - - command_details += f"*{command.help or 'No details provided.'}*\n" - embed.description = command_details - - return embed - - async def send_command_help(self, command: Command) -> None: - """Send help for a single command.""" - embed = await self.command_formatting(command) - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: - """ - Formats the prefix, command name and signature, and short doc for an iterable of commands. - - return_as_list is helpful for passing these command details into the paginator as a list of command details. - """ - details = [] - for command in commands_: - signature = f" {command.signature}" if command.signature else "" - details.append( - f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" - ) - if return_as_list: - return details - else: - return "".join(details) - - async def send_group_help(self, group: Group) -> None: - """Sends help for a group command.""" - subcommands = group.commands - - if len(subcommands) == 0: - # no subcommands, just treat it like a regular command - await self.send_command_help(group) - return - - # remove commands that the user can't run and are hidden, and sort by name - commands_ = await self.filter_commands(subcommands, sort=True) - - embed = await self.command_formatting(group) - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n**Subcommands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - async def send_cog_help(self, cog: Cog) -> None: - """Send help for a cog.""" - # sort commands by name, and remove any the user cant run or are hidden. - commands_ = await self.filter_commands(cog.get_commands(), sort=True) - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n\n**Commands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def _category_key(command: Command) -> str: - """ - Returns a cog name of a given command for use as a key for `sorted` and `groupby`. - - A zero width space is used as a prefix for results with no cogs to force them last in ordering. - """ - if command.cog: - with suppress(AttributeError): - if command.cog.category: - return f"**{command.cog.category}**" - return f"**{command.cog_name}**" - else: - return "**\u200bNo Category:**" - - async def send_category_help(self, category: Category) -> None: - """ - Sends help for a bot category. - - This sends a brief help for all commands in all cogs registered to the category. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - all_commands = [] - for cog in category.cogs: - all_commands.extend(cog.get_commands()) - - filtered_commands = await self.filter_commands(all_commands, sort=True) - - command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) - description = f"**{category.name}**\n*{category.description}*" - - if command_detail_lines: - description += "\n\n**Commands:**" - - await LinePaginator.paginate( - command_detail_lines, - self.context, - embed, - prefix=description, - max_lines=COMMANDS_PER_PAGE, - max_size=2000, - ) - - async def send_bot_help(self, mapping: dict) -> None: - """Sends help for all bot commands and cogs.""" - bot = self.context.bot - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) - - cog_or_category_pages = [] - - for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): - sorted_commands = sorted(_commands, key=lambda c: c.name) - - if len(sorted_commands) == 0: - continue - - command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) - - # Split cogs or categories which have too many commands to fit in one page. - # The length of commands is included for later use when aggregating into pages for the paginator. - for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): - truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] - joined_lines = "".join(truncated_lines) - cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) - - pages = [] - counter = 0 - page = "" - for page_details, length in cog_or_category_pages: - counter += length - if counter > COMMANDS_PER_PAGE: - # force a new page on paginator even if it falls short of the max pages - # since we still want to group categories/cogs. - counter = length - pages.append(page) - page = f"{page_details}\n\n" - else: - page += f"{page_details}\n\n" - - if page: - # add any remaining command help that didn't get added in the last iteration above. - pages.append(page) - - await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) - - -class Help(Cog): - """Custom Embed Pagination Help feature.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.old_help_command = bot.help_command - bot.help_command = CustomHelpCommand() - bot.help_command.cog = self - - def cog_unload(self) -> None: - """Reset the help command when the cog is unloaded.""" - self.bot.help_command = self.old_help_command - - -def setup(bot: Bot) -> None: - """Load the Help cog.""" - bot.add_cog(Help(bot)) - log.info("Cog loaded: Help") diff --git a/bot/cogs/info/__init__.py b/bot/cogs/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/info/doc.py b/bot/cogs/info/doc.py new file mode 100644 index 000000000..204cffb37 --- /dev/null +++ b/bot/cogs/info/doc.py @@ -0,0 +1,511 @@ +import asyncio +import functools +import logging +import re +import textwrap +from collections import OrderedDict +from contextlib import suppress +from types import SimpleNamespace +from typing import Any, Callable, Optional, Tuple + +import discord +from bs4 import BeautifulSoup +from bs4.element import PageElement, Tag +from discord.errors import NotFound +from discord.ext import commands +from markdownify import MarkdownConverter +from requests import ConnectTimeout, ConnectionError, HTTPError +from sphinx.ext import intersphinx +from urllib3.exceptions import ProtocolError + +from bot.bot import Bot +from bot.constants import MODERATION_ROLES, RedirectOutput +from bot.converters import ValidPythonIdentifier, ValidURL +from bot.decorators import with_role +from bot.pagination import LinePaginator + + +log = logging.getLogger(__name__) +logging.getLogger('urllib3').setLevel(logging.WARNING) + +# Since Intersphinx is intended to be used with Sphinx, +# we need to mock its configuration. +SPHINX_MOCK_APP = SimpleNamespace( + config=SimpleNamespace( + intersphinx_timeout=3, + tls_verify=True, + user_agent="python3:python-discord/bot:1.0.0" + ) +) + +NO_OVERRIDE_GROUPS = ( + "2to3fixer", + "token", + "label", + "pdbcommand", + "term", +) +NO_OVERRIDE_PACKAGES = ( + "python", +) + +SEARCH_END_TAG_ATTRS = ( + "data", + "function", + "class", + "exception", + "seealso", + "section", + "rubric", + "sphinxsidebar", +) +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") +WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") + +FAILED_REQUEST_RETRY_AMOUNT = 3 +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay + + +def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + # Assign the cache to the function itself so we can clear it from outside. + async_cache.cache = OrderedDict() + + def decorator(function: Callable) -> Callable: + """Define the async_cache decorator.""" + @functools.wraps(function) + async def wrapper(*args) -> Any: + """Decorator wrapper for the caching logic.""" + key = ':'.join(args[arg_offset:]) + + value = async_cache.cache.get(key) + if value is None: + if len(async_cache.cache) > max_size: + async_cache.cache.popitem(last=False) + + async_cache.cache[key] = await function(*args) + return async_cache.cache[key] + return wrapper + return decorator + + +class DocMarkdownConverter(MarkdownConverter): + """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" + + def convert_code(self, el: PageElement, text: str) -> str: + """Undo `markdownify`s underscore escaping.""" + return f"`{text}`".replace('\\', '') + + def convert_pre(self, el: PageElement, text: str) -> str: + """Wrap any codeblocks in `py` for syntax highlighting.""" + code = ''.join(el.strings) + return f"```py\n{code}```" + + +def markdownify(html: str) -> DocMarkdownConverter: + """Create a DocMarkdownConverter object from the input html.""" + return DocMarkdownConverter(bullets='•').convert(html) + + +class InventoryURL(commands.Converter): + """ + Represents an Intersphinx inventory URL. + + This converter checks whether intersphinx accepts the given inventory URL, and raises + `BadArgument` if that is not the case. + + Otherwise, it simply passes through the given URL. + """ + + @staticmethod + async def convert(ctx: commands.Context, url: str) -> str: + """Convert url to Intersphinx inventory URL.""" + try: + intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) + except AttributeError: + raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") + except ConnectionError: + if url.startswith('https'): + raise commands.BadArgument( + f"Cannot establish a connection to `{url}`. Does it support HTTPS?" + ) + raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") + except ValueError: + raise commands.BadArgument( + f"Failed to read Intersphinx inventory from URL `{url}`. " + "Are you sure that it's a valid inventory file?" + ) + return url + + +class Doc(commands.Cog): + """A set of commands for querying & displaying documentation.""" + + def __init__(self, bot: Bot): + self.base_urls = {} + self.bot = bot + self.inventories = {} + self.renamed_symbols = set() + + self.bot.loop.create_task(self.init_refresh_inventory()) + + async def init_refresh_inventory(self) -> None: + """Refresh documentation inventory on cog initialization.""" + await self.bot.wait_until_guild_available() + await self.refresh_inventory() + + async def update_single( + self, package_name: str, base_url: str, inventory_url: str + ) -> None: + """ + Rebuild the inventory for a single package. + + Where: + * `package_name` is the package name to use, appears in the log + * `base_url` is the root documentation URL for the specified package, used to build + absolute paths that link to specific symbols + * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running + `intersphinx.fetch_inventory` in an executor on the bot's event loop + """ + self.base_urls[package_name] = base_url + + package = await self._fetch_inventory(inventory_url) + if not package: + return None + + for group, value in package.items(): + for symbol, (package_name, _version, relative_doc_url, _) in value.items(): + absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + group_name = group.split(":")[1] + symbol_base_url = self.inventories[symbol].split("/", 3)[2] + if ( + group_name in NO_OVERRIDE_GROUPS + or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) + ): + + symbol = f"{group_name}.{symbol}" + # If renamed `symbol` already exists, add library name in front to differentiate between them. + if symbol in self.renamed_symbols: + # Split `package_name` because of packages like Pillow that have spaces in them. + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + + self.inventories[symbol] = absolute_doc_url + + log.trace(f"Fetched inventory for {package_name}.") + + async def refresh_inventory(self) -> None: + """Refresh internal documentation inventory.""" + log.debug("Refreshing documentation inventory...") + + # Clear the old base URLS and inventories to ensure + # that we start from a fresh local dataset. + # Also, reset the cache used for fetching documentation. + self.base_urls.clear() + self.inventories.clear() + self.renamed_symbols.clear() + async_cache.cache = OrderedDict() + + # Run all coroutines concurrently - since each of them performs a HTTP + # request, this speeds up fetching the inventory data heavily. + coros = [ + self.update_single( + package["package"], package["base_url"], package["inventory_url"] + ) for package in await self.bot.api_client.get('bot/documentation-links') + ] + await asyncio.gather(*coros) + + async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: + """ + Given a Python symbol, return its signature and description. + + The first tuple element is the signature of the given symbol as a markup-free string, and + the second tuple element is the description of the given symbol with HTML markup included. + + If the given symbol is a module, returns a tuple `(None, str)` + else if the symbol could not be found, returns `None`. + """ + url = self.inventories.get(symbol) + if url is None: + return None + + async with self.bot.http_session.get(url) as response: + html = await response.text(encoding='utf-8') + + # Find the signature header and parse the relevant parts. + symbol_id = url.split('#')[-1] + soup = BeautifulSoup(html, 'lxml') + symbol_heading = soup.find(id=symbol_id) + search_html = str(soup) + + if symbol_heading is None: + return None + + if symbol_id == f"module-{symbol}": + # Get page content from the module headerlink to the + # first tag that has its class in `SEARCH_END_TAG_ATTRS` + start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) + if start_tag is None: + return [], "" + + end_tag = start_tag.find_next(self._match_end_tag) + if end_tag is None: + return [], "" + + description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) + description_end_index = search_html.find(str(end_tag)) + description = search_html[description_start_index:description_end_index] + signatures = None + + else: + signatures = [] + description = str(symbol_heading.find_next_sibling("dd")) + description_pos = search_html.find(description) + # Get text of up to 3 signatures, remove unwanted symbols + for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): + signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + if signature and search_html.find(str(element)) < description_pos: + signatures.append(signature) + + return signatures, description.replace('¶', '') + + @async_cache(arg_offset=1) + async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: + """ + Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. + + If the symbol is known, an Embed with documentation about it is returned. + """ + scraped_html = await self.get_symbol_html(symbol) + if scraped_html is None: + return None + + signatures = scraped_html[0] + permalink = self.inventories[symbol] + description = markdownify(scraped_html[1]) + + # Truncate the description of the embed to the last occurrence + # of a double newline (interpreted as a paragraph) before index 1000. + if len(description) > 1000: + shortened = description[:1000] + description_cutoff = shortened.rfind('\n\n', 100) + if description_cutoff == -1: + # Search the shortened version for cutoff points in decreasing desirability, + # cutoff at 1000 if none are found. + for string in (". ", ", ", ",", " "): + description_cutoff = shortened.rfind(string) + if description_cutoff != -1: + break + else: + description_cutoff = 1000 + description = description[:description_cutoff] + + # If there is an incomplete code block, cut it out + if description.count("```") % 2: + codeblock_start = description.rfind('```py') + description = description[:codeblock_start].rstrip() + description += f"... [read more]({permalink})" + + description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) + if signatures is None: + # If symbol is a module, don't show signature. + embed_description = description + + elif not signatures: + # It's some "meta-page", for example: + # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views + embed_description = "This appears to be a generic page not tied to a specific symbol." + + else: + embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) + embed_description += f"\n{description}" + + embed = discord.Embed( + title=f'`{symbol}`', + url=permalink, + description=embed_description + ) + # Show all symbols with the same name that were renamed in the footer. + embed.set_footer( + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) + ) + return embed + + @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) + async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """Lookup documentation for Python symbols.""" + await ctx.invoke(self.get_command, symbol) + + @docs_group.command(name='get', aliases=('g',)) + async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """ + Return a documentation embed for a given symbol. + + If no symbol is given, return a list of all available inventories. + + Examples: + !docs + !docs aiohttp + !docs aiohttp.ClientSession + !docs get aiohttp.ClientSession + """ + if symbol is None: + inventory_embed = discord.Embed( + title=f"All inventories (`{len(self.base_urls)}` total)", + colour=discord.Colour.blue() + ) + + lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) + if self.base_urls: + await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) + + else: + inventory_embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=inventory_embed) + + else: + # Fetching documentation for a symbol (at least for the first time, since + # caching is used) takes quite some time, so let's send typing to indicate + # that we got the command, but are still working on it. + async with ctx.typing(): + doc_embed = await self.get_symbol_embed(symbol) + + if doc_embed is None: + error_embed = discord.Embed( + description=f"Sorry, I could not find any documentation for `{symbol}`.", + colour=discord.Colour.red() + ) + error_message = await ctx.send(embed=error_embed) + with suppress(NotFound): + await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) + await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) + else: + await ctx.send(embed=doc_embed) + + @docs_group.command(name='set', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def set_command( + self, ctx: commands.Context, package_name: ValidPythonIdentifier, + base_url: ValidURL, inventory_url: InventoryURL + ) -> None: + """ + Adds a new documentation metadata object to the site's database. + + The database will update the object, should an existing item with the specified `package_name` already exist. + + Example: + !docs set \ + python \ + https://docs.python.org/3/ \ + https://docs.python.org/3/objects.inv + """ + body = { + 'package': package_name, + 'base_url': base_url, + 'inventory_url': inventory_url + } + await self.bot.api_client.post('bot/documentation-links', json=body) + + log.info( + f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" + f"Package name: {package_name}\n" + f"Base url: {base_url}\n" + f"Inventory URL: {inventory_url}" + ) + + # Rebuilding the inventory can take some time, so lets send out a + # typing event to show that the Bot is still working. + async with ctx.typing(): + await self.refresh_inventory() + await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") + + @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: + """ + Removes the specified package from the database. + + Examples: + !docs delete aiohttp + """ + await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') + + async with ctx.typing(): + # Rebuild the inventory to ensure that everything + # that was from this package is properly deleted. + await self.refresh_inventory() + await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + + @docs_group.command(name="refresh", aliases=("rfsh", "r")) + @with_role(*MODERATION_ROLES) + async def refresh_command(self, ctx: commands.Context) -> None: + """Refresh inventories and send differences to channel.""" + old_inventories = set(self.base_urls) + with ctx.typing(): + await self.refresh_inventory() + # Get differences of added and removed inventories + added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) + if added: + added = f"+ {added}" + + removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) + if removed: + removed = f"- {removed}" + + embed = discord.Embed( + title="Inventories refreshed", + description=f"```diff\n{added}\n{removed}```" if added or removed else "" + ) + await ctx.send(embed=embed) + + async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: + """Get and return inventory from `inventory_url`. If fetching fails, return None.""" + fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) + for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): + try: + package = await self.bot.loop.run_in_executor(None, fetch_func) + except ConnectTimeout: + log.error( + f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except ProtocolError: + log.error( + f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except HTTPError as e: + log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") + return None + except ConnectionError: + log.error(f"Couldn't establish connection to inventory {inventory_url}.") + return None + else: + return package + log.error(f"Fetching of inventory {inventory_url} failed.") + return None + + @staticmethod + def _match_end_tag(tag: Tag) -> bool: + """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" + for attr in SEARCH_END_TAG_ATTRS: + if attr in tag.get("class", ()): + return True + + return tag.name == "table" + + +def setup(bot: Bot) -> None: + """Load the Doc cog.""" + bot.add_cog(Doc(bot)) diff --git a/bot/cogs/info/help.py b/bot/cogs/info/help.py new file mode 100644 index 000000000..3d1d6fd10 --- /dev/null +++ b/bot/cogs/info/help.py @@ -0,0 +1,375 @@ +import itertools +import logging +from asyncio import TimeoutError +from collections import namedtuple +from contextlib import suppress +from typing import List, Union + +from discord import Colour, Embed, Member, Message, NotFound, Reaction, User +from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand +from fuzzywuzzy import fuzz, process +from fuzzywuzzy.utils import full_process + +from bot import constants +from bot.constants import Channels, Emojis, STAFF_ROLES +from bot.decorators import redirect_output +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +COMMANDS_PER_PAGE = 8 +DELETE_EMOJI = Emojis.trashcan +PREFIX = constants.Bot.prefix + +Category = namedtuple("Category", ["name", "description", "cogs"]) + + +async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: + """ + Runs the cleanup for the help command. + + Adds the :trashcan: reaction that, when clicked, will delete the help message. + After a 300 second timeout, the reaction will be removed. + """ + def check(reaction: Reaction, user: User) -> bool: + """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" + return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id + + await message.add_reaction(DELETE_EMOJI) + + with suppress(NotFound): + try: + await bot.wait_for("reaction_add", check=check, timeout=300) + await message.delete() + except TimeoutError: + await message.remove_reaction(DELETE_EMOJI, bot.user) + + +class HelpQueryNotFound(ValueError): + """ + Raised when a HelpSession Query doesn't match a command or cog. + + Contains the custom attribute of ``possible_matches``. + + Instances of this object contain a dictionary of any command(s) that were close to matching the + query, where keys are the possible matched command names and values are the likeness match scores. + """ + + def __init__(self, arg: str, possible_matches: dict = None): + super().__init__(arg) + self.possible_matches = possible_matches + + +class CustomHelpCommand(HelpCommand): + """ + An interactive instance for the bot help command. + + Cogs can be grouped into custom categories. All cogs with the same category will be displayed + under a single category name in the help output. Custom categories are defined inside the cogs + as a class attribute named `category`. A description can also be specified with the attribute + `category_description`. If a description is not found in at least one cog, the default will be + the regular description (class docstring) of the first cog found in the category. + """ + + def __init__(self): + super().__init__(command_attrs={"help": "Shows help for bot commands"}) + + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) + async def command_callback(self, ctx: Context, *, command: str = None) -> None: + """Attempts to match the provided query with a valid command or cog.""" + # the only reason we need to tamper with this is because d.py does not support "categories", + # so we need to deal with them ourselves. + + bot = ctx.bot + + if command is None: + # quick and easy, send bot help if command is none + mapping = self.get_bot_mapping() + await self.send_bot_help(mapping) + return + + cog_matches = [] + description = None + for cog in bot.cogs.values(): + if hasattr(cog, "category") and cog.category == command: + cog_matches.append(cog) + if hasattr(cog, "category_description"): + description = cog.category_description + + if cog_matches: + category = Category(name=command, description=description, cogs=cog_matches) + await self.send_category_help(category) + return + + # it's either a cog, group, command or subcommand; let the parent class deal with it + await super().command_callback(ctx, command=command) + + async def get_all_help_choices(self) -> set: + """ + Get all the possible options for getting help in the bot. + + This will only display commands the author has permission to run. + + These include: + - Category names + - Cog names + - Group command names (and aliases) + - Command names (and aliases) + - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) + + Options and choices are case sensitive. + """ + # first get all commands including subcommands and full command name aliases + choices = set() + for command in await self.filter_commands(self.context.bot.walk_commands()): + # the the command or group name + choices.add(str(command)) + + if isinstance(command, Command): + # all aliases if it's just a command + choices.update(command.aliases) + else: + # otherwise we need to add the parent name in + choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) + + # all cog names + choices.update(self.context.bot.cogs) + + # all category names + choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) + return choices + + async def command_not_found(self, string: str) -> "HelpQueryNotFound": + """ + Handles when a query does not match a valid command, group, cog or category. + + Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. + """ + choices = await self.get_all_help_choices() + + # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty + # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters + if (processed := full_process(string)): + result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) + else: + result = [] + + return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) + + async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": + """ + Redirects the error to `command_not_found`. + + `command_not_found` deals with searching and getting best choices for both commands and subcommands. + """ + return await self.command_not_found(f"{command.qualified_name} {string}") + + async def send_error_message(self, error: HelpQueryNotFound) -> None: + """Send the error message to the channel.""" + embed = Embed(colour=Colour.red(), title=str(error)) + + if getattr(error, "possible_matches", None): + matches = "\n".join(f"`{match}`" for match in error.possible_matches) + embed.description = f"**Did you mean:**\n{matches}" + + await self.context.send(embed=embed) + + async def command_formatting(self, command: Command) -> Embed: + """ + Takes a command and turns it into an embed. + + It will add an author, command signature + help, aliases and a note if the user can't run the command. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + parent = command.full_parent_name + + name = str(command) if not parent else f"{parent} {command.name}" + command_details = f"**```{PREFIX}{name} {command.signature}```**\n" + + # show command aliases + aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) + if aliases: + command_details += f"**Can also use:** {aliases}\n\n" + + # check if the user is allowed to run this command + if not await command.can_run(self.context): + command_details += "***You cannot run this command.***\n\n" + + command_details += f"*{command.help or 'No details provided.'}*\n" + embed.description = command_details + + return embed + + async def send_command_help(self, command: Command) -> None: + """Send help for a single command.""" + embed = await self.command_formatting(command) + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: + """ + Formats the prefix, command name and signature, and short doc for an iterable of commands. + + return_as_list is helpful for passing these command details into the paginator as a list of command details. + """ + details = [] + for command in commands_: + signature = f" {command.signature}" if command.signature else "" + details.append( + f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" + ) + if return_as_list: + return details + else: + return "".join(details) + + async def send_group_help(self, group: Group) -> None: + """Sends help for a group command.""" + subcommands = group.commands + + if len(subcommands) == 0: + # no subcommands, just treat it like a regular command + await self.send_command_help(group) + return + + # remove commands that the user can't run and are hidden, and sort by name + commands_ = await self.filter_commands(subcommands, sort=True) + + embed = await self.command_formatting(group) + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n**Subcommands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + async def send_cog_help(self, cog: Cog) -> None: + """Send help for a cog.""" + # sort commands by name, and remove any the user cant run or are hidden. + commands_ = await self.filter_commands(cog.get_commands(), sort=True) + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n\n**Commands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def _category_key(command: Command) -> str: + """ + Returns a cog name of a given command for use as a key for `sorted` and `groupby`. + + A zero width space is used as a prefix for results with no cogs to force them last in ordering. + """ + if command.cog: + with suppress(AttributeError): + if command.cog.category: + return f"**{command.cog.category}**" + return f"**{command.cog_name}**" + else: + return "**\u200bNo Category:**" + + async def send_category_help(self, category: Category) -> None: + """ + Sends help for a bot category. + + This sends a brief help for all commands in all cogs registered to the category. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + all_commands = [] + for cog in category.cogs: + all_commands.extend(cog.get_commands()) + + filtered_commands = await self.filter_commands(all_commands, sort=True) + + command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) + description = f"**{category.name}**\n*{category.description}*" + + if command_detail_lines: + description += "\n\n**Commands:**" + + await LinePaginator.paginate( + command_detail_lines, + self.context, + embed, + prefix=description, + max_lines=COMMANDS_PER_PAGE, + max_size=2000, + ) + + async def send_bot_help(self, mapping: dict) -> None: + """Sends help for all bot commands and cogs.""" + bot = self.context.bot + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) + + cog_or_category_pages = [] + + for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): + sorted_commands = sorted(_commands, key=lambda c: c.name) + + if len(sorted_commands) == 0: + continue + + command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) + + # Split cogs or categories which have too many commands to fit in one page. + # The length of commands is included for later use when aggregating into pages for the paginator. + for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): + truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] + joined_lines = "".join(truncated_lines) + cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) + + pages = [] + counter = 0 + page = "" + for page_details, length in cog_or_category_pages: + counter += length + if counter > COMMANDS_PER_PAGE: + # force a new page on paginator even if it falls short of the max pages + # since we still want to group categories/cogs. + counter = length + pages.append(page) + page = f"{page_details}\n\n" + else: + page += f"{page_details}\n\n" + + if page: + # add any remaining command help that didn't get added in the last iteration above. + pages.append(page) + + await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) + + +class Help(Cog): + """Custom Embed Pagination Help feature.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.old_help_command = bot.help_command + bot.help_command = CustomHelpCommand() + bot.help_command.cog = self + + def cog_unload(self) -> None: + """Reset the help command when the cog is unloaded.""" + self.bot.help_command = self.old_help_command + + +def setup(bot: Bot) -> None: + """Load the Help cog.""" + bot.add_cog(Help(bot)) + log.info("Cog loaded: Help") diff --git a/bot/cogs/info/information.py b/bot/cogs/info/information.py new file mode 100644 index 000000000..8982196d1 --- /dev/null +++ b/bot/cogs/info/information.py @@ -0,0 +1,422 @@ +import colorsys +import logging +import pprint +import textwrap +from collections import Counter, defaultdict +from string import Template +from typing import Any, Mapping, Optional, Union + +from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils +from discord.abc import GuildChannel +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group +from discord.utils import escape_markdown + +from bot import constants +from bot.bot import Bot +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + + +class Information(Cog): + """A cog with commands for generating embeds with server info, such as server stats and user info.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @staticmethod + def role_can_read(channel: GuildChannel, role: Role) -> bool: + """Return True if `role` can read messages in `channel`.""" + overwrites = channel.overwrites_for(role) + return overwrites.read_messages is True + + def get_staff_channel_count(self, guild: Guild) -> int: + """ + Get the number of channels that are staff-only. + + We need to know two things about a channel: + - Does the @everyone role have explicit read deny permissions? + - Do staff roles have explicit read allow permissions? + + If the answer to both of these questions is yes, it's a staff channel. + """ + channel_ids = set() + for channel in guild.channels: + if channel.type is ChannelType.category: + continue + + everyone_can_read = self.role_can_read(channel, guild.default_role) + + for role in constants.STAFF_ROLES: + role_can_read = self.role_can_read(channel, guild.get_role(role)) + if role_can_read and not everyone_can_read: + channel_ids.add(channel.id) + break + + return len(channel_ids) + + @staticmethod + def get_channel_type_counts(guild: Guild) -> str: + """Return the total amounts of the various types of channels in `guild`.""" + channel_counter = Counter(c.type for c in guild.channels) + channel_type_list = [] + for channel, count in channel_counter.items(): + channel_type = str(channel).title() + channel_type_list.append(f"{channel_type} channels: {count}") + + channel_type_list = sorted(channel_type_list) + return "\n".join(channel_type_list) + + @with_role(*constants.MODERATION_ROLES) + @command(name="roles") + async def roles_info(self, ctx: Context) -> None: + """Returns a list of all roles and their corresponding IDs.""" + # Sort the roles alphabetically and remove the @everyone role + roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) + + # Build a list + role_list = [] + for role in roles: + role_list.append(f"`{role.id}` - {role.mention}") + + # Build an embed + embed = Embed( + title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", + colour=Colour.blurple() + ) + + await LinePaginator.paginate(role_list, ctx, embed, empty=False) + + @with_role(*constants.MODERATION_ROLES) + @command(name="role") + async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: + """ + Return information on a role or list of roles. + + To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. + """ + parsed_roles = [] + failed_roles = [] + + for role_name in roles: + if isinstance(role_name, Role): + # Role conversion has already succeeded + parsed_roles.append(role_name) + continue + + role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) + + if not role: + failed_roles.append(role_name) + continue + + parsed_roles.append(role) + + if failed_roles: + await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") + + for role in parsed_roles: + h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) + + embed = Embed( + title=f"{role.name} info", + colour=role.colour, + ) + embed.add_field(name="ID", value=role.id, inline=True) + embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) + embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) + embed.add_field(name="Member count", value=len(role.members), inline=True) + embed.add_field(name="Position", value=role.position) + embed.add_field(name="Permission code", value=role.permissions.value, inline=True) + + await ctx.send(embed=embed) + + @command(name="server", aliases=["server_info", "guild", "guild_info"]) + async def server_info(self, ctx: Context) -> None: + """Returns an embed full of server information.""" + created = time_since(ctx.guild.created_at, precision="days") + features = ", ".join(ctx.guild.features) + region = ctx.guild.region + + roles = len(ctx.guild.roles) + member_count = ctx.guild.member_count + channel_counts = self.get_channel_type_counts(ctx.guild) + + # How many of each user status? + statuses = Counter(member.status for member in ctx.guild.members) + embed = Embed(colour=Colour.blurple()) + + # How many staff members and staff channels do we have? + staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) + staff_channel_count = self.get_staff_channel_count(ctx.guild) + + # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the + # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting + # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts + # after the dedent is made. + embed.description = Template( + textwrap.dedent(f""" + **Server information** + Created: {created} + Voice region: {region} + Features: {features} + + **Channel counts** + $channel_counts + Staff channels: {staff_channel_count} + + **Member counts** + Members: {member_count:,} + Staff members: {staff_member_count} + Roles: {roles} + + **Member statuses** + {constants.Emojis.status_online} {statuses[Status.online]:,} + {constants.Emojis.status_idle} {statuses[Status.idle]:,} + {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} + {constants.Emojis.status_offline} {statuses[Status.offline]:,} + """) + ).substitute({"channel_counts": channel_counts}) + embed.set_thumbnail(url=ctx.guild.icon_url) + + await ctx.send(embed=embed) + + @command(name="user", aliases=["user_info", "member", "member_info"]) + async def user_info(self, ctx: Context, user: Member = None) -> None: + """Returns info about a user.""" + if user is None: + user = ctx.author + + # Do a role check if this is being executed on someone other than the caller + elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): + await ctx.send("You may not use this command on users other than yourself.") + return + + # Non-staff may only do this in #bot-commands + if not with_role_check(ctx, *constants.STAFF_ROLES): + if not ctx.channel.id == constants.Channels.bot_commands: + raise InWhitelistCheckFailure(constants.Channels.bot_commands) + + embed = await self.create_user_embed(ctx, user) + + await ctx.send(embed=embed) + + async def create_user_embed(self, ctx: Context, user: Member) -> Embed: + """Creates an embed containing information on the `user`.""" + created = time_since(user.created_at, max_units=3) + + # Custom status + custom_status = '' + for activity in user.activities: + # Check activity.state for None value if user has a custom status set + # This guards against a custom status with an emoji but no text, which will cause + # escape_markdown to raise an exception + # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class + if activity.name == 'Custom Status' and activity.state: + state = escape_markdown(activity.state) + custom_status = f'Status: {state}\n' + + name = str(user) + if user.nick: + name = f"{user.nick} ({name})" + + joined = time_since(user.joined_at, max_units=3) + roles = ", ".join(role.mention for role in user.roles[1:]) + + description = [ + textwrap.dedent(f""" + **User Information** + Created: {created} + Profile: {user.mention} + ID: {user.id} + {custom_status} + **Member Information** + Joined: {joined} + Roles: {roles or None} + """).strip() + ] + + # Show more verbose output in moderation channels for infractions and nominations + if ctx.channel.id in constants.MODERATION_CHANNELS: + description.append(await self.expanded_user_infraction_counts(user)) + description.append(await self.user_nomination_counts(user)) + else: + description.append(await self.basic_user_infraction_counts(user)) + + # Let's build the embed now + embed = Embed( + title=name, + description="\n\n".join(description) + ) + + embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) + embed.colour = user.top_role.colour if roles else Colour.blurple() + + return embed + + async def basic_user_infraction_counts(self, member: Member) -> str: + """Gets the total and active infraction counts for the given `member`.""" + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'hidden': 'False', + 'user__id': str(member.id) + } + ) + + total_infractions = len(infractions) + active_infractions = sum(infraction['active'] for infraction in infractions) + + infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" + + return infraction_output + + async def expanded_user_infraction_counts(self, member: Member) -> str: + """ + Gets expanded infraction counts for the given `member`. + + The counts will be split by infraction type and the number of active infractions for each type will indicated + in the output as well. + """ + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'user__id': str(member.id) + } + ) + + infraction_output = ["**Infractions**"] + if not infractions: + infraction_output.append("This user has never received an infraction.") + else: + # Count infractions split by `type` and `active` status for this user + infraction_types = set() + infraction_counter = defaultdict(int) + for infraction in infractions: + infraction_type = infraction["type"] + infraction_active = 'active' if infraction["active"] else 'inactive' + + infraction_types.add(infraction_type) + infraction_counter[f"{infraction_active} {infraction_type}"] += 1 + + # Format the output of the infraction counts + for infraction_type in sorted(infraction_types): + active_count = infraction_counter[f"active {infraction_type}"] + total_count = active_count + infraction_counter[f"inactive {infraction_type}"] + + line = f"{infraction_type.capitalize()}s: {total_count}" + if active_count: + line += f" ({active_count} active)" + + infraction_output.append(line) + + return "\n".join(infraction_output) + + async def user_nomination_counts(self, member: Member) -> str: + """Gets the active and historical nomination counts for the given `member`.""" + nominations = await self.bot.api_client.get( + 'bot/nominations', + params={ + 'user__id': str(member.id) + } + ) + + output = ["**Nominations**"] + + if not nominations: + output.append("This user has never been nominated.") + else: + count = len(nominations) + is_currently_nominated = any(nomination["active"] for nomination in nominations) + nomination_noun = "nomination" if count == 1 else "nominations" + + if is_currently_nominated: + output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") + else: + output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") + + return "\n".join(output) + + def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: + """Format a mapping to be readable to a human.""" + # sorting is technically superfluous but nice if you want to look for a specific field + fields = sorted(mapping.items(), key=lambda item: item[0]) + + if field_width is None: + field_width = len(max(mapping.keys(), key=len)) + + out = '' + + for key, val in fields: + if isinstance(val, dict): + # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries + inner_width = int(field_width * 1.6) + val = '\n' + self.format_fields(val, field_width=inner_width) + + elif isinstance(val, str): + # split up text since it might be long + text = textwrap.fill(val, width=100, replace_whitespace=False) + + # indent it, I guess you could do this with `wrap` and `join` but this is nicer + val = textwrap.indent(text, ' ' * (field_width + len(': '))) + + # the first line is already indented so we `str.lstrip` it + val = val.lstrip() + + if key == 'color': + # makes the base 10 representation of a hex number readable to humans + val = hex(val) + + out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) + + # remove trailing whitespace + return out.rstrip() + + @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) + @group(invoke_without_command=True) + @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) + async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: + """Shows information about the raw API response.""" + # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling + # doing this extra request is also much easier than trying to convert everything back into a dictionary again + raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) + + paginator = Paginator() + + def add_content(title: str, content: str) -> None: + paginator.add_line(f'== {title} ==\n') + # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. + # we hope it's not close to 2000 + paginator.add_line(content.replace('```', '`` `')) + paginator.close_page() + + if message.content: + add_content('Raw message', message.content) + + transformer = pprint.pformat if json else self.format_fields + for field_name in ('embeds', 'attachments'): + data = raw_data[field_name] + + if not data: + continue + + total = len(data) + for current, item in enumerate(data, start=1): + title = f'Raw {field_name} ({current}/{total})' + add_content(title, transformer(item)) + + for page in paginator.pages: + await ctx.send(page) + + @raw.command() + async def json(self, ctx: Context, message: Message) -> None: + """Shows information about the raw API response in a copy-pasteable Python format.""" + await ctx.invoke(self.raw, message=message, json=True) + + +def setup(bot: Bot) -> None: + """Load the Information cog.""" + bot.add_cog(Information(bot)) diff --git a/bot/cogs/info/python_news.py b/bot/cogs/info/python_news.py new file mode 100644 index 000000000..0ab5738a4 --- /dev/null +++ b/bot/cogs/info/python_news.py @@ -0,0 +1,232 @@ +import logging +import typing as t +from datetime import date, datetime + +import discord +import feedparser +from bs4 import BeautifulSoup +from discord.ext.commands import Cog +from discord.ext.tasks import loop + +from bot import constants +from bot.bot import Bot +from bot.utils.webhooks import send_webhook + +PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" + +RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" +THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" +MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" +THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" + +AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + +log = logging.getLogger(__name__) + + +class PythonNews(Cog): + """Post new PEPs and Python News to `#python-news`.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_names = {} + self.webhook: t.Optional[discord.Webhook] = None + + self.bot.loop.create_task(self.get_webhook_names()) + self.bot.loop.create_task(self.get_webhook_and_channel()) + + async def start_tasks(self) -> None: + """Start the tasks for fetching new PEPs and mailing list messages.""" + self.fetch_new_media.start() + + @loop(minutes=20) + async def fetch_new_media(self) -> None: + """Fetch new mailing list messages and then new PEPs.""" + await self.post_maillist_news() + await self.post_pep_news() + + async def sync_maillists(self) -> None: + """Sync currently in-use maillists with API.""" + # Wait until guild is available to avoid running before everything is ready + await self.bot.wait_until_guild_available() + + response = await self.bot.api_client.get("bot/bot-settings/news") + for mail in constants.PythonNews.mail_lists: + if mail not in response["data"]: + response["data"][mail] = [] + + # Because we are handling PEPs differently, we don't include it to mail lists + if "pep" not in response["data"]: + response["data"]["pep"] = [] + + await self.bot.api_client.put("bot/bot-settings/news", json=response) + + async def get_webhook_names(self) -> None: + """Get webhook author names from maillist API.""" + await self.bot.wait_until_guild_available() + + async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: + lists = await resp.json() + + for mail in lists: + if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: + self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] + + async def post_pep_news(self) -> None: + """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" + # Wait until everything is ready and http_session available + await self.bot.wait_until_guild_available() + await self.sync_maillists() + + async with self.bot.http_session.get(PEPS_RSS_URL) as resp: + data = feedparser.parse(await resp.text("utf-8")) + + news_listing = await self.bot.api_client.get("bot/bot-settings/news") + payload = news_listing.copy() + pep_numbers = news_listing["data"]["pep"] + + # Reverse entries to send oldest first + data["entries"].reverse() + for new in data["entries"]: + try: + new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") + except ValueError: + log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") + continue + pep_nr = new["title"].split(":")[0].split()[1] + if ( + pep_nr in pep_numbers + or new_datetime.date() < date.today() + ): + continue + + # Build an embed and send a webhook + embed = discord.Embed( + title=new["title"], + description=new["summary"], + timestamp=new_datetime, + url=new["link"], + colour=constants.Colours.soft_green + ) + embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) + msg = await send_webhook( + webhook=self.webhook, + username=data["feed"]["title"], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"]["pep"].append(pep_nr) + + # Increase overall PEP new stat + self.bot.stats.incr("python_news.posted.pep") + + if msg.channel.is_news(): + log.trace("Publishing PEP annnouncement because it was in a news channel") + await msg.publish() + + # Apply new sent news to DB to avoid duplicate sending + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def post_maillist_news(self) -> None: + """Send new maillist threads to #python-news that is listed in configuration.""" + await self.bot.wait_until_guild_available() + await self.sync_maillists() + existing_news = await self.bot.api_client.get("bot/bot-settings/news") + payload = existing_news.copy() + + for maillist in constants.PythonNews.mail_lists: + async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: + recents = BeautifulSoup(await resp.text(), features="lxml") + + # When a

element is present in the response then the mailing list + # has not had any activity during the current month, so therefore it + # can be ignored. + if recents.p: + continue + + for thread in recents.html.body.div.find_all("a", href=True): + # We want only these threads that have identifiers + if "latest" in thread["href"]: + continue + + thread_information, email_information = await self.get_thread_and_first_mail( + maillist, thread["href"].split("/")[-2] + ) + + try: + new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") + except ValueError: + log.warning(f"Invalid datetime from Thread email: {email_information['date']}") + continue + + if ( + thread_information["thread_id"] in existing_news["data"][maillist] + or 'Re: ' in thread_information["subject"] + or new_date.date() < date.today() + ): + continue + + content = email_information["content"] + link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) + + # Build an embed and send a message to the webhook + embed = discord.Embed( + title=thread_information["subject"], + description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, + timestamp=new_date, + url=link, + colour=constants.Colours.soft_green + ) + embed.set_author( + name=f"{email_information['sender_name']} ({email_information['sender']['address']})", + url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), + ) + embed.set_footer( + text=f"Posted to {self.webhook_names[maillist]}", + icon_url=AVATAR_URL, + ) + msg = await send_webhook( + webhook=self.webhook, + username=self.webhook_names[maillist], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"][maillist].append(thread_information["thread_id"]) + + # Increase this specific maillist counter in stats + self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") + + if msg.channel.is_news(): + log.trace("Publishing mailing list message because it was in a news channel") + await msg.publish() + + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: + """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" + async with self.bot.http_session.get( + THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) + ) as resp: + thread_information = await resp.json() + + async with self.bot.http_session.get(thread_information["starting_email"]) as resp: + email_information = await resp.json() + return thread_information, email_information + + async def get_webhook_and_channel(self) -> None: + """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" + await self.bot.wait_until_guild_available() + self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) + + await self.start_tasks() + + def cog_unload(self) -> None: + """Stop news posting tasks on cog unload.""" + self.fetch_new_media.cancel() + + +def setup(bot: Bot) -> None: + """Add `News` cog.""" + bot.add_cog(PythonNews(bot)) diff --git a/bot/cogs/info/reddit.py b/bot/cogs/info/reddit.py new file mode 100644 index 000000000..d853ab2ea --- /dev/null +++ b/bot/cogs/info/reddit.py @@ -0,0 +1,304 @@ +import asyncio +import logging +import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List + +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks +from bot.converters import Subreddit +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils.messages import sub_clyde + +log = logging.getLogger(__name__) + +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + + +class Reddit(Cog): + """Track subreddit posts and show detailed statistics about them.""" + + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 + + def __init__(self, bot: Bot): + self.bot = bot + + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() + + def cog_unload(self) -> None: + """Stop the loop task and revoke the access token when the cog is unloaded.""" + self.auto_poster_loop.cancel() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) + + async def init_reddit_ready(self) -> None: + """Sets the reddit webhook when the cog is loaded.""" + await self.bot.wait_until_guild_available() + if not self.webhook: + self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) + + @property + def channel(self) -> TextChannel: + """Get the #reddit channel object from the bot's cache.""" + return self.bot.get_channel(Channels.reddit) + + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: + """A helper method to fetch a certain amount of Reddit posts at a given route.""" + # Reddit's JSON responses only provide 25 posts at most. + if not 25 >= amount > 0: + raise ValueError("Invalid amount of subreddit posts requested.") + + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.get( + url=url, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + params=params + ) + if response.status == 200 and response.content_type == 'application/json': + # Got appropriate response - process and return. + content = await response.json() + posts = content["data"]["children"] + return posts[:amount] + + await asyncio.sleep(3) + + log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") + return list() # Failed to get appropriate response within allowed number of retries. + + async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: + """ + Get the top amount of posts for a given subreddit within a specified timeframe. + + A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top + weekly posts. + + The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. + """ + embed = Embed(description="") + + posts = await self.fetch_posts( + route=f"{subreddit}/top", + amount=amount, + params={"t": time} + ) + + if not posts: + embed.title = random.choice(ERROR_REPLIES) + embed.colour = Colour.red() + embed.description = ( + "Sorry! We couldn't find any posts from that subreddit. " + "If this problem persists, please let us know." + ) + + return embed + + for post in posts: + data = post["data"] + + text = data["selftext"] + if text: + text = textwrap.shorten(text, width=128, placeholder="...") + text += "\n" # Add newline to separate embed info + + ups = data["ups"] + comments = data["num_comments"] + author = data["author"] + + title = textwrap.shorten(data["title"], width=64, placeholder="...") + link = self.URL + data["permalink"] + + embed.description += ( + f"**[{title}]({link})**\n" + f"{text}" + f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" + ) + + embed.colour = Colour.blurple() + return embed + + @loop() + async def auto_poster_loop(self) -> None: + """Post the top 5 posts daily, and the top 5 posts weekly.""" + # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter + now = datetime.utcnow() + tomorrow = now + timedelta(days=1) + midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + seconds_until = (midnight_tomorrow - now).total_seconds() + + await asyncio.sleep(seconds_until) + + await self.bot.wait_until_guild_available() + if not self.webhook: + await self.bot.fetch_webhook(Webhooks.reddit) + + if datetime.utcnow().weekday() == 0: + await self.top_weekly_posts() + # if it's a monday send the top weekly posts + + for subreddit in RedditConfig.subreddits: + top_posts = await self.get_top_posts(subreddit=subreddit, time="day") + username = sub_clyde(f"{subreddit} Top Daily Posts") + message = await self.webhook.send(username=username, embed=top_posts, wait=True) + + if message.channel.is_news(): + await message.publish() + + async def top_weekly_posts(self) -> None: + """Post a summary of the top posts.""" + for subreddit in RedditConfig.subreddits: + # Send and pin the new weekly posts. + top_posts = await self.get_top_posts(subreddit=subreddit, time="week") + username = sub_clyde(f"{subreddit} Top Weekly Posts") + message = await self.webhook.send(wait=True, username=username, embed=top_posts) + + if subreddit.lower() == "r/python": + if not self.channel: + log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") + return + + # Remove the oldest pins so that only 12 remain at most. + pins = await self.channel.pins() + + while len(pins) >= 12: + await pins[-1].unpin() + del pins[-1] + + await message.pin() + + if message.channel.is_news(): + await message.publish() + + @group(name="reddit", invoke_without_command=True) + async def reddit_group(self, ctx: Context) -> None: + """View the top posts from various subreddits.""" + await ctx.send_help(ctx.command) + + @reddit_group.command(name="top") + async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of all time from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="all") + + await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) + + @reddit_group.command(name="daily") + async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of today from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="day") + + await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) + + @reddit_group.command(name="weekly") + async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of this week from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="week") + + await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) + + @with_role(*STAFF_ROLES) + @reddit_group.command(name="subreddits", aliases=("subs",)) + async def subreddits_command(self, ctx: Context) -> None: + """Send a paginated embed of all the subreddits we're relaying.""" + embed = Embed() + embed.title = "Relayed subreddits." + embed.colour = Colour.blurple() + + await LinePaginator.paginate( + RedditConfig.subreddits, + ctx, embed, + footer_text="Use the reddit commands along with these to view their posts.", + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return + bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/info/site.py b/bot/cogs/info/site.py new file mode 100644 index 000000000..ac29daa1d --- /dev/null +++ b/bot/cogs/info/site.py @@ -0,0 +1,146 @@ +import logging + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import URLs +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" + + +class Site(Cog): + """Commands for linking to different parts of the site.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="site", aliases=("s",), invoke_without_command=True) + async def site_group(self, ctx: Context) -> None: + """Commands for getting info about our website.""" + await ctx.send_help(ctx.command) + + @site_group.command(name="home", aliases=("about",)) + async def site_main(self, ctx: Context) -> None: + """Info about the website itself.""" + url = f"{URLs.site_schema}{URLs.site}/" + + embed = Embed(title="Python Discord website") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + f"[Our official website]({url}) is an open-source community project " + "created with Python and Django. It contains information about the server " + "itself, lets you sign up for upcoming events, has its own wiki, contains " + "a list of valuable learning resources, and much more." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="resources") + async def site_resources(self, ctx: Context) -> None: + """Info about the site's Resources page.""" + learning_url = f"{PAGES_URL}/resources" + + embed = Embed(title="Resources") + embed.set_footer(text=f"{learning_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Resources page]({learning_url}) on our website contains a " + "list of hand-selected learning resources that we regularly recommend " + f"to both beginners and experts." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="tools") + async def site_tools(self, ctx: Context) -> None: + """Info about the site's Tools page.""" + tools_url = f"{PAGES_URL}/resources/tools" + + embed = Embed(title="Tools") + embed.set_footer(text=f"{tools_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Tools page]({tools_url}) on our website contains a " + f"couple of the most popular tools for programming in Python." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="help") + async def site_help(self, ctx: Context) -> None: + """Info about the site's Getting Help page.""" + url = f"{PAGES_URL}/resources/guides/asking-good-questions" + + embed = Embed(title="Asking Good Questions") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "Asking the right question about something that's new to you can sometimes be tricky. " + f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " + "It contains everything you need to get the very best help from our community." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="faq") + async def site_faq(self, ctx: Context) -> None: + """Info about the site's FAQ page.""" + url = f"{PAGES_URL}/frequently-asked-questions" + + embed = Embed(title="FAQ") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "As the largest Python community on Discord, we get hundreds of questions every day. " + "Many of these questions have been asked before. We've compiled a list of the most " + "frequently asked questions along with their answers, which can be found on " + f"our [FAQ page]({url})." + ) + + await ctx.send(embed=embed) + + @site_group.command(aliases=['r', 'rule'], name='rules') + async def site_rules(self, ctx: Context, *rules: int) -> None: + """Provides a link to all rules or, if specified, displays specific rule(s).""" + rules_embed = Embed(title='Rules', color=Colour.blurple()) + rules_embed.url = f"{PAGES_URL}/rules" + + if not rules: + # Rules were not submitted. Return the default description. + rules_embed.description = ( + "The rules and guidelines that apply to this community can be found on" + f" our [rules page]({PAGES_URL}/rules). We expect" + " all members of the community to have read and understood these." + ) + + await ctx.send(embed=rules_embed) + return + + full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) + invalid_indices = tuple( + pick + for pick in rules + if pick < 1 or pick > len(full_rules) + ) + + if invalid_indices: + indices = ', '.join(map(str, invalid_indices)) + await ctx.send(f":x: Invalid rule indices: {indices}") + return + + for rule in rules: + self.bot.stats.incr(f"rule_uses.{rule}") + + final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) + + await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) + + +def setup(bot: Bot) -> None: + """Load the Site cog.""" + bot.add_cog(Site(bot)) diff --git a/bot/cogs/info/source.py b/bot/cogs/info/source.py new file mode 100644 index 000000000..205e0ba81 --- /dev/null +++ b/bot/cogs/info/source.py @@ -0,0 +1,141 @@ +import inspect +from pathlib import Path +from typing import Optional, Tuple, Union + +from discord import Embed +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import URLs + +SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] + + +class SourceConverter(commands.Converter): + """Convert an argument into a help command, tag, command, or cog.""" + + async def convert(self, ctx: commands.Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower().startswith("help"): + return ctx.bot.help_command + + cog = ctx.bot.get_cog(argument) + if cog: + return cog + + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd + + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + elif argument.lower() in tags_cog._cache: + return argument.lower() + + raise commands.BadArgument( + f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." + ) + + +class BotSource(commands.Cog): + """Displays information about the bot's source code.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command(name="source", aliases=("src",)) + async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: + """Display information and a GitHub link to the source code of a command, tag, or cog.""" + if not source_item: + embed = Embed(title="Bot's GitHub Repository") + embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") + embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") + await ctx.send(embed=embed) + return + + embed = await self.build_embed(source_item) + await ctx.send(embed=embed) + + def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: + """ + Build GitHub link of source item, return this link, file location and first line number. + + Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). + """ + if isinstance(source_item, commands.Command): + if source_item.cog_name == "Alias": + cmd_name = source_item.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + src = cmd.callback.__code__ + filename = src.co_filename + else: + src = source_item.callback.__code__ + filename = src.co_filename + elif isinstance(source_item, str): + tags_cog = self.bot.get_cog("Tags") + filename = tags_cog._cache[source_item]["location"] + else: + src = type(source_item) + try: + filename = inspect.getsourcefile(src) + except TypeError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + if not isinstance(source_item, str): + try: + lines, first_line_no = inspect.getsourcelines(src) + except OSError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" + else: + first_line_no = None + lines_extension = "" + + # Handle tag file location differently than others to avoid errors in some cases + if not first_line_no: + file_location = Path(filename).relative_to("/bot/") + else: + file_location = Path(filename).relative_to(Path.cwd()).as_posix() + + url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" + + return url, file_location, first_line_no or None + + async def build_embed(self, source_object: SourceType) -> Optional[Embed]: + """Build embed based on source object.""" + url, location, first_line = self.get_source_link(source_object) + + if isinstance(source_object, commands.HelpCommand): + title = "Help Command" + description = source_object.__doc__.splitlines()[1] + elif isinstance(source_object, commands.Command): + if source_object.cog_name == "Alias": + cmd_name = source_object.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + description = cmd.short_doc + else: + description = source_object.short_doc + + title = f"Command: {source_object.qualified_name}" + elif isinstance(source_object, str): + title = f"Tag: {source_object}" + description = "" + else: + title = f"Cog: {source_object.qualified_name}" + description = source_object.description.splitlines()[0] + + embed = Embed(title=title, description=description) + embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") + line_text = f":{first_line}" if first_line else "" + embed.set_footer(text=f"{location}{line_text}") + + return embed + + +def setup(bot: Bot) -> None: + """Load the BotSource cog.""" + bot.add_cog(BotSource(bot)) diff --git a/bot/cogs/info/stats.py b/bot/cogs/info/stats.py new file mode 100644 index 000000000..d42f55466 --- /dev/null +++ b/bot/cogs/info/stats.py @@ -0,0 +1,129 @@ +import string +from datetime import datetime + +from discord import Member, Message, Status +from discord.ext.commands import Cog, Context +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Categories, Channels, Guild, Stats as StatConf + + +CHANNEL_NAME_OVERRIDES = { + Channels.off_topic_0: "off_topic_0", + Channels.off_topic_1: "off_topic_1", + Channels.off_topic_2: "off_topic_2", + Channels.staff_lounge: "staff_lounge" +} + +ALLOWED_CHARS = string.ascii_letters + string.digits + "_" + + +class Stats(Cog): + """A cog which provides a way to hook onto Discord events and forward to stats.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.last_presence_update = None + self.update_guild_boost.start() + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Report message events in the server to statsd.""" + if message.guild is None: + return + + if message.guild.id != Guild.id: + return + + cat = getattr(message.channel, "category", None) + if cat is not None and cat.id == Categories.modmail: + if message.channel.id != Channels.incidents: + # Do not report modmail channels to stats, there are too many + # of them for interesting statistics to be drawn out of this. + return + + reformatted_name = message.channel.name.replace('-', '_') + + if CHANNEL_NAME_OVERRIDES.get(message.channel.id): + reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) + + reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) + + stat_name = f"channels.{reformatted_name}" + self.bot.stats.incr(stat_name) + + # Increment the total message count + self.bot.stats.incr("messages") + + @Cog.listener() + async def on_command_completion(self, ctx: Context) -> None: + """Report completed commands to statsd.""" + command_name = ctx.command.qualified_name.replace(" ", "_") + + self.bot.stats.incr(f"commands.{command_name}") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Update member count stat on member join.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_leave(self, member: Member) -> None: + """Update member count stat on member leave.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_update(self, _before: Member, after: Member) -> None: + """Update presence estimates on member update.""" + if after.guild.id != Guild.id: + return + + if self.last_presence_update: + if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: + return + + self.last_presence_update = datetime.now() + + online = 0 + idle = 0 + dnd = 0 + offline = 0 + + for member in after.guild.members: + if member.status is Status.online: + online += 1 + elif member.status is Status.dnd: + dnd += 1 + elif member.status is Status.idle: + idle += 1 + elif member.status is Status.offline: + offline += 1 + + self.bot.stats.gauge("guild.status.online", online) + self.bot.stats.gauge("guild.status.idle", idle) + self.bot.stats.gauge("guild.status.do_not_disturb", dnd) + self.bot.stats.gauge("guild.status.offline", offline) + + @loop(hours=1) + async def update_guild_boost(self) -> None: + """Post the server boost level and tier every hour.""" + await self.bot.wait_until_guild_available() + g = self.bot.get_guild(Guild.id) + self.bot.stats.gauge("boost.amount", g.premium_subscription_count) + self.bot.stats.gauge("boost.tier", g.premium_tier) + + def cog_unload(self) -> None: + """Stop the boost statistic task on unload of the Cog.""" + self.update_guild_boost.stop() + + +def setup(bot: Bot) -> None: + """Load the stats cog.""" + bot.add_cog(Stats(bot)) diff --git a/bot/cogs/info/tags.py b/bot/cogs/info/tags.py new file mode 100644 index 000000000..3d76c5c08 --- /dev/null +++ b/bot/cogs/info/tags.py @@ -0,0 +1,277 @@ +import logging +import re +import time +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot import constants +from bot.bot import Bot +from bot.converters import TagNameConverter +from bot.pagination import LinePaginator +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +TEST_CHANNELS = ( + constants.Channels.bot_commands, + constants.Channels.helpers +) + +REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) +FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." + + +class Tags(Cog): + """Save new tags and fetch existing tags.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.tag_cooldowns = {} + self._cache = self.get_tags() + + @staticmethod + def get_tags() -> dict: + """Get all tags.""" + cache = {} + + base_path = Path("bot", "resources", "tags") + for file in base_path.glob("**/*"): + if file.is_file(): + tag_title = file.stem + tag = { + "title": tag_title, + "embed": { + "description": file.read_text(encoding="utf8"), + }, + "restricted_to": "developers", + "location": f"/bot/{file}" + } + + # Convert to a list to allow negative indexing. + parents = list(file.relative_to(base_path).parents) + if len(parents) > 1: + # -1 would be '.' hence -2 is used as the index. + tag["restricted_to"] = parents[-2].name + + cache[tag_title] = tag + + return cache + + @staticmethod + def check_accessibility(user: Member, tag: dict) -> bool: + """Check if user can access a tag.""" + return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] + + @staticmethod + def _fuzzy_search(search: str, target: str) -> float: + """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" + current, index = 0, 0 + _search = REGEX_NON_ALPHABET.sub('', search.lower()) + _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + _target = next(_targets) + try: + while True: + while index < len(_target) and _search[current] == _target[index]: + current += 1 + index += 1 + index, _target = 0, next(_targets) + except (StopIteration, IndexError): + pass + return current / len(_search) * 100 + + def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: + """Return a list of suggested tags.""" + scores: Dict[str, int] = { + tag_title: Tags._fuzzy_search(tag_name, tag['title']) + for tag_title, tag in self._cache.items() + } + + thresholds = thresholds or [100, 90, 80, 70, 60] + + for threshold in thresholds: + suggestions = [ + self._cache[tag_title] + for tag_title, matching_score in scores.items() + if matching_score >= threshold + ] + if suggestions: + return suggestions + + return [] + + def _get_tag(self, tag_name: str) -> list: + """Get a specific tag.""" + found = [self._cache.get(tag_name.lower(), None)] + if not found[0]: + return self._get_suggestions(tag_name) + return found + + def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: + """ + Search for tags via contents. + + `predicate` will be the built-in any, all, or a custom callable. Must return a bool. + """ + keywords_processed: List[str] = [] + for keyword in keywords.split(','): + keyword_sanitized = keyword.strip().casefold() + if not keyword_sanitized: + # this happens when there are leading / trailing / consecutive comma. + continue + keywords_processed.append(keyword_sanitized) + + if not keywords_processed: + # after sanitizing, we can end up with an empty list, for example when keywords is ',' + # in that case, we simply want to search for such keywords directly instead. + keywords_processed = [keywords] + + matching_tags = [] + for tag in self._cache.values(): + matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) + if self.check_accessibility(user, tag) and check(matches): + matching_tags.append(tag) + + return matching_tags + + async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: + """Send the result of matching tags to user.""" + if not matching_tags: + pass + elif len(matching_tags) == 1: + await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) + else: + is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 + embed = Embed( + title=f"Here are the tags containing the given keyword{'s' * is_plural}:", + description='\n'.join(tag['title'] for tag in matching_tags[:10]) + ) + await LinePaginator.paginate( + sorted(f"**»** {tag['title']}" for tag in matching_tags), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) + async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Show all known tags, a single tag, or run a subcommand.""" + await ctx.invoke(self.get_command, tag_name=tag_name) + + @tags_group.group(name='search', invoke_without_command=True) + async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Only search for tags that has ALL the keywords. + """ + matching_tags = self._get_tags_via_content(all, keywords, ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @search_tag_content.command(name='any') + async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Search for tags that has ANY of the keywords. + """ + matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @tags_group.command(name='get', aliases=('show', 'g')) + async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Get a specified tag, or a list of all tags if no tag is specified.""" + + def _command_on_cooldown(tag_name: str) -> bool: + """ + Check if the command is currently on cooldown, on a per-tag, per-channel basis. + + The cooldown duration is set in constants.py. + """ + now = time.time() + + cooldown_conditions = ( + tag_name + and tag_name in self.tag_cooldowns + and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags + and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id + ) + + if cooldown_conditions: + return True + return False + + if _command_on_cooldown(tag_name): + time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] + time_left = constants.Cooldowns.tags - time_elapsed + log.info( + f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " + f"Cooldown ends in {time_left:.1f} seconds." + ) + return + + if tag_name is not None: + temp_founds = self._get_tag(tag_name) + + founds = [] + + for found_tag in temp_founds: + if self.check_accessibility(ctx.author, found_tag): + founds.append(found_tag) + + if len(founds) == 1: + tag = founds[0] + if ctx.channel.id not in TEST_CHANNELS: + self.tag_cooldowns[tag_name] = { + "time": time.time(), + "channel": ctx.channel.id + } + + self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") + + await wait_for_deletion( + await ctx.send(embed=Embed.from_dict(tag['embed'])), + [ctx.author.id], + client=self.bot + ) + elif founds and len(tag_name) >= 3: + await wait_for_deletion( + await ctx.send( + embed=Embed( + title='Did you mean ...', + description='\n'.join(tag['title'] for tag in founds[:10]) + ) + ), + [ctx.author.id], + client=self.bot + ) + + else: + tags = self._cache.values() + if not tags: + await ctx.send(embed=Embed( + description="**There are no tags in the database!**", + colour=Colour.red() + )) + else: + embed: Embed = Embed(title="**Current tags**") + await LinePaginator.paginate( + sorted( + f"**»** {tag['title']}" for tag in tags + if self.check_accessibility(ctx.author, tag) + ), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Tags cog.""" + bot.add_cog(Tags(bot)) diff --git a/bot/cogs/info/wolfram.py b/bot/cogs/info/wolfram.py new file mode 100644 index 000000000..e6cae3bb8 --- /dev/null +++ b/bot/cogs/info/wolfram.py @@ -0,0 +1,280 @@ +import logging +from io import BytesIO +from typing import Callable, List, Optional, Tuple +from urllib import parse + +import discord +from dateutil.relativedelta import relativedelta +from discord import Embed +from discord.ext import commands +from discord.ext.commands import BucketType, Cog, Context, check, group + +from bot.bot import Bot +from bot.constants import Colours, STAFF_ROLES, Wolfram +from bot.pagination import ImagePaginator +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +APPID = Wolfram.key +DEFAULT_OUTPUT_FORMAT = "JSON" +QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" +WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" + +MAX_PODS = 20 + +# Allows for 10 wolfram calls pr user pr day +usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) + +# Allows for max api requests / days in month per day for the entire guild (Temporary) +guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) + + +async def send_embed( + ctx: Context, + message_txt: str, + colour: int = Colours.soft_red, + footer: str = None, + img_url: str = None, + f: discord.File = None +) -> None: + """Generate & send a response embed with Wolfram as the author.""" + embed = Embed(colour=colour) + embed.description = message_txt + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + if footer: + embed.set_footer(text=footer) + + if img_url: + embed.set_image(url=img_url) + + await ctx.send(embed=embed, file=f) + + +def custom_cooldown(*ignore: List[int]) -> Callable: + """ + Implement per-user and per-guild cooldowns for requests to the Wolfram API. + + A list of roles may be provided to ignore the per-user cooldown + """ + async def predicate(ctx: Context) -> bool: + if ctx.invoked_with == 'help': + # if the invoked command is help we don't want to increase the ratelimits since it's not actually + # invoking the command/making a request, so instead just check if the user/guild are on cooldown. + guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown + if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored + return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 + return guild_cooldown + + user_bucket = usercd.get_bucket(ctx.message) + + if all(role.id not in ignore for role in ctx.author.roles): + user_rate = user_bucket.update_rate_limit() + + if user_rate: + # Can't use api; cause: member limit + delta = relativedelta(seconds=int(user_rate)) + cooldown = humanize_delta(delta) + message = ( + "You've used up your limit for Wolfram|Alpha requests.\n" + f"Cooldown: {cooldown}" + ) + await send_embed(ctx, message) + return False + + guild_bucket = guildcd.get_bucket(ctx.message) + guild_rate = guild_bucket.update_rate_limit() + + # Repr has a token attribute to read requests left + log.debug(guild_bucket) + + if guild_rate: + # Can't use api; cause: guild limit + message = ( + "The max limit of requests for the server has been reached for today.\n" + f"Cooldown: {int(guild_rate)}" + ) + await send_embed(ctx, message) + return False + + return True + return check(predicate) + + +async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: + """Get the Wolfram API pod pages for the provided query.""" + async with ctx.channel.typing(): + url_str = parse.urlencode({ + "input": query, + "appid": APPID, + "output": DEFAULT_OUTPUT_FORMAT, + "format": "image,plaintext" + }) + request_url = QUERY.format(request="query", data=url_str) + + async with bot.http_session.get(request_url) as response: + json = await response.json(content_type='text/plain') + + result = json["queryresult"] + + if result["error"]: + # API key not set up correctly + if result["error"]["msg"] == "Invalid appid": + message = "Wolfram API key is invalid or missing." + log.warning( + "API key seems to be missing, or invalid when " + f"processing a wolfram request: {url_str}, Response: {json}" + ) + await send_embed(ctx, message) + return + + message = "Something went wrong internally with your request, please notify staff!" + log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") + await send_embed(ctx, message) + return + + if not result["success"]: + message = f"I couldn't find anything for {query}." + await send_embed(ctx, message) + return + + if not result["numpods"]: + message = "Could not find any results." + await send_embed(ctx, message) + return + + pods = result["pods"] + pages = [] + for pod in pods[:MAX_PODS]: + subs = pod.get("subpods") + + for sub in subs: + title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") + img = sub["img"]["src"] + pages.append((title, img)) + return pages + + +class Wolfram(Cog): + """Commands for interacting with the Wolfram|Alpha API.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_command(self, ctx: Context, *, query: str) -> None: + """Requests all answers on a single image, sends an image of all related pods.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="simple", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + image_bytes = await response.read() + + f = discord.File(BytesIO(image_bytes), filename="image.png") + image_url = "attachment://image.png" + + if status == 501: + message = "Failed to get response" + footer = "" + color = Colours.soft_red + elif status == 400: + message = "No input found" + footer = "" + color = Colours.soft_red + elif status == 403: + message = "Wolfram API key is invalid or missing." + footer = "" + color = Colours.soft_red + else: + message = "" + footer = "View original for a bigger picture." + color = Colours.soft_orange + + # Sends a "blank" embed if no request is received, unsure how to fix + await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) + + @wolfram_command.command(name="page", aliases=("pa", "p")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + embed = Embed() + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + embed.colour = Colours.soft_orange + + await ImagePaginator.paginate(pages, ctx, embed) + + @wolfram_command.command(name="cut", aliases=("c",)) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + if len(pages) >= 2: + page = pages[1] + else: + page = pages[0] + + await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) + + @wolfram_command.command(name="short", aliases=("sh", "s")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: + """Requests an answer to a simple question.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="result", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + response_text = await response.text() + + if status == 501: + message = "Failed to get response" + color = Colours.soft_red + elif status == 400: + message = "No input found" + color = Colours.soft_red + elif response_text == "Error 1: Invalid appid": + message = "Wolfram API key is invalid or missing." + color = Colours.soft_red + else: + message = response_text + color = Colours.soft_orange + + await send_embed(ctx, message, color) + + +def setup(bot: Bot) -> None: + """Load the Wolfram cog.""" + bot.add_cog(Wolfram(bot)) diff --git a/bot/cogs/information.py b/bot/cogs/information.py deleted file mode 100644 index 8982196d1..000000000 --- a/bot/cogs/information.py +++ /dev/null @@ -1,422 +0,0 @@ -import colorsys -import logging -import pprint -import textwrap -from collections import Counter, defaultdict -from string import Template -from typing import Any, Mapping, Optional, Union - -from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils -from discord.abc import GuildChannel -from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group -from discord.utils import escape_markdown - -from bot import constants -from bot.bot import Bot -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - - -class Information(Cog): - """A cog with commands for generating embeds with server info, such as server stats and user info.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @staticmethod - def role_can_read(channel: GuildChannel, role: Role) -> bool: - """Return True if `role` can read messages in `channel`.""" - overwrites = channel.overwrites_for(role) - return overwrites.read_messages is True - - def get_staff_channel_count(self, guild: Guild) -> int: - """ - Get the number of channels that are staff-only. - - We need to know two things about a channel: - - Does the @everyone role have explicit read deny permissions? - - Do staff roles have explicit read allow permissions? - - If the answer to both of these questions is yes, it's a staff channel. - """ - channel_ids = set() - for channel in guild.channels: - if channel.type is ChannelType.category: - continue - - everyone_can_read = self.role_can_read(channel, guild.default_role) - - for role in constants.STAFF_ROLES: - role_can_read = self.role_can_read(channel, guild.get_role(role)) - if role_can_read and not everyone_can_read: - channel_ids.add(channel.id) - break - - return len(channel_ids) - - @staticmethod - def get_channel_type_counts(guild: Guild) -> str: - """Return the total amounts of the various types of channels in `guild`.""" - channel_counter = Counter(c.type for c in guild.channels) - channel_type_list = [] - for channel, count in channel_counter.items(): - channel_type = str(channel).title() - channel_type_list.append(f"{channel_type} channels: {count}") - - channel_type_list = sorted(channel_type_list) - return "\n".join(channel_type_list) - - @with_role(*constants.MODERATION_ROLES) - @command(name="roles") - async def roles_info(self, ctx: Context) -> None: - """Returns a list of all roles and their corresponding IDs.""" - # Sort the roles alphabetically and remove the @everyone role - roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) - - # Build a list - role_list = [] - for role in roles: - role_list.append(f"`{role.id}` - {role.mention}") - - # Build an embed - embed = Embed( - title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", - colour=Colour.blurple() - ) - - await LinePaginator.paginate(role_list, ctx, embed, empty=False) - - @with_role(*constants.MODERATION_ROLES) - @command(name="role") - async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: - """ - Return information on a role or list of roles. - - To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. - """ - parsed_roles = [] - failed_roles = [] - - for role_name in roles: - if isinstance(role_name, Role): - # Role conversion has already succeeded - parsed_roles.append(role_name) - continue - - role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) - - if not role: - failed_roles.append(role_name) - continue - - parsed_roles.append(role) - - if failed_roles: - await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") - - for role in parsed_roles: - h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) - - embed = Embed( - title=f"{role.name} info", - colour=role.colour, - ) - embed.add_field(name="ID", value=role.id, inline=True) - embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) - embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) - embed.add_field(name="Member count", value=len(role.members), inline=True) - embed.add_field(name="Position", value=role.position) - embed.add_field(name="Permission code", value=role.permissions.value, inline=True) - - await ctx.send(embed=embed) - - @command(name="server", aliases=["server_info", "guild", "guild_info"]) - async def server_info(self, ctx: Context) -> None: - """Returns an embed full of server information.""" - created = time_since(ctx.guild.created_at, precision="days") - features = ", ".join(ctx.guild.features) - region = ctx.guild.region - - roles = len(ctx.guild.roles) - member_count = ctx.guild.member_count - channel_counts = self.get_channel_type_counts(ctx.guild) - - # How many of each user status? - statuses = Counter(member.status for member in ctx.guild.members) - embed = Embed(colour=Colour.blurple()) - - # How many staff members and staff channels do we have? - staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) - staff_channel_count = self.get_staff_channel_count(ctx.guild) - - # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the - # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting - # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts - # after the dedent is made. - embed.description = Template( - textwrap.dedent(f""" - **Server information** - Created: {created} - Voice region: {region} - Features: {features} - - **Channel counts** - $channel_counts - Staff channels: {staff_channel_count} - - **Member counts** - Members: {member_count:,} - Staff members: {staff_member_count} - Roles: {roles} - - **Member statuses** - {constants.Emojis.status_online} {statuses[Status.online]:,} - {constants.Emojis.status_idle} {statuses[Status.idle]:,} - {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} - {constants.Emojis.status_offline} {statuses[Status.offline]:,} - """) - ).substitute({"channel_counts": channel_counts}) - embed.set_thumbnail(url=ctx.guild.icon_url) - - await ctx.send(embed=embed) - - @command(name="user", aliases=["user_info", "member", "member_info"]) - async def user_info(self, ctx: Context, user: Member = None) -> None: - """Returns info about a user.""" - if user is None: - user = ctx.author - - # Do a role check if this is being executed on someone other than the caller - elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): - await ctx.send("You may not use this command on users other than yourself.") - return - - # Non-staff may only do this in #bot-commands - if not with_role_check(ctx, *constants.STAFF_ROLES): - if not ctx.channel.id == constants.Channels.bot_commands: - raise InWhitelistCheckFailure(constants.Channels.bot_commands) - - embed = await self.create_user_embed(ctx, user) - - await ctx.send(embed=embed) - - async def create_user_embed(self, ctx: Context, user: Member) -> Embed: - """Creates an embed containing information on the `user`.""" - created = time_since(user.created_at, max_units=3) - - # Custom status - custom_status = '' - for activity in user.activities: - # Check activity.state for None value if user has a custom status set - # This guards against a custom status with an emoji but no text, which will cause - # escape_markdown to raise an exception - # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class - if activity.name == 'Custom Status' and activity.state: - state = escape_markdown(activity.state) - custom_status = f'Status: {state}\n' - - name = str(user) - if user.nick: - name = f"{user.nick} ({name})" - - joined = time_since(user.joined_at, max_units=3) - roles = ", ".join(role.mention for role in user.roles[1:]) - - description = [ - textwrap.dedent(f""" - **User Information** - Created: {created} - Profile: {user.mention} - ID: {user.id} - {custom_status} - **Member Information** - Joined: {joined} - Roles: {roles or None} - """).strip() - ] - - # Show more verbose output in moderation channels for infractions and nominations - if ctx.channel.id in constants.MODERATION_CHANNELS: - description.append(await self.expanded_user_infraction_counts(user)) - description.append(await self.user_nomination_counts(user)) - else: - description.append(await self.basic_user_infraction_counts(user)) - - # Let's build the embed now - embed = Embed( - title=name, - description="\n\n".join(description) - ) - - embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) - embed.colour = user.top_role.colour if roles else Colour.blurple() - - return embed - - async def basic_user_infraction_counts(self, member: Member) -> str: - """Gets the total and active infraction counts for the given `member`.""" - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'hidden': 'False', - 'user__id': str(member.id) - } - ) - - total_infractions = len(infractions) - active_infractions = sum(infraction['active'] for infraction in infractions) - - infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" - - return infraction_output - - async def expanded_user_infraction_counts(self, member: Member) -> str: - """ - Gets expanded infraction counts for the given `member`. - - The counts will be split by infraction type and the number of active infractions for each type will indicated - in the output as well. - """ - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'user__id': str(member.id) - } - ) - - infraction_output = ["**Infractions**"] - if not infractions: - infraction_output.append("This user has never received an infraction.") - else: - # Count infractions split by `type` and `active` status for this user - infraction_types = set() - infraction_counter = defaultdict(int) - for infraction in infractions: - infraction_type = infraction["type"] - infraction_active = 'active' if infraction["active"] else 'inactive' - - infraction_types.add(infraction_type) - infraction_counter[f"{infraction_active} {infraction_type}"] += 1 - - # Format the output of the infraction counts - for infraction_type in sorted(infraction_types): - active_count = infraction_counter[f"active {infraction_type}"] - total_count = active_count + infraction_counter[f"inactive {infraction_type}"] - - line = f"{infraction_type.capitalize()}s: {total_count}" - if active_count: - line += f" ({active_count} active)" - - infraction_output.append(line) - - return "\n".join(infraction_output) - - async def user_nomination_counts(self, member: Member) -> str: - """Gets the active and historical nomination counts for the given `member`.""" - nominations = await self.bot.api_client.get( - 'bot/nominations', - params={ - 'user__id': str(member.id) - } - ) - - output = ["**Nominations**"] - - if not nominations: - output.append("This user has never been nominated.") - else: - count = len(nominations) - is_currently_nominated = any(nomination["active"] for nomination in nominations) - nomination_noun = "nomination" if count == 1 else "nominations" - - if is_currently_nominated: - output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") - else: - output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") - - return "\n".join(output) - - def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: - """Format a mapping to be readable to a human.""" - # sorting is technically superfluous but nice if you want to look for a specific field - fields = sorted(mapping.items(), key=lambda item: item[0]) - - if field_width is None: - field_width = len(max(mapping.keys(), key=len)) - - out = '' - - for key, val in fields: - if isinstance(val, dict): - # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries - inner_width = int(field_width * 1.6) - val = '\n' + self.format_fields(val, field_width=inner_width) - - elif isinstance(val, str): - # split up text since it might be long - text = textwrap.fill(val, width=100, replace_whitespace=False) - - # indent it, I guess you could do this with `wrap` and `join` but this is nicer - val = textwrap.indent(text, ' ' * (field_width + len(': '))) - - # the first line is already indented so we `str.lstrip` it - val = val.lstrip() - - if key == 'color': - # makes the base 10 representation of a hex number readable to humans - val = hex(val) - - out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) - - # remove trailing whitespace - return out.rstrip() - - @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) - @group(invoke_without_command=True) - @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) - async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: - """Shows information about the raw API response.""" - # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling - # doing this extra request is also much easier than trying to convert everything back into a dictionary again - raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) - - paginator = Paginator() - - def add_content(title: str, content: str) -> None: - paginator.add_line(f'== {title} ==\n') - # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. - # we hope it's not close to 2000 - paginator.add_line(content.replace('```', '`` `')) - paginator.close_page() - - if message.content: - add_content('Raw message', message.content) - - transformer = pprint.pformat if json else self.format_fields - for field_name in ('embeds', 'attachments'): - data = raw_data[field_name] - - if not data: - continue - - total = len(data) - for current, item in enumerate(data, start=1): - title = f'Raw {field_name} ({current}/{total})' - add_content(title, transformer(item)) - - for page in paginator.pages: - await ctx.send(page) - - @raw.command() - async def json(self, ctx: Context, message: Message) -> None: - """Shows information about the raw API response in a copy-pasteable Python format.""" - await ctx.invoke(self.raw, message=message, json=True) - - -def setup(bot: Bot) -> None: - """Load the Information cog.""" - bot.add_cog(Information(bot)) diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py deleted file mode 100644 index b3102db2f..000000000 --- a/bot/cogs/jams.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import typing as t - -from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role -from discord.ext import commands -from more_itertools import unique_everseen - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -MAX_CHANNELS = 50 -CATEGORY_NAME = "Code Jam" - - -class CodeJams(commands.Cog): - """Manages the code-jam related parts of our server.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command() - @with_role(Roles.admins) - async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: - """ - Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. - - The first user passed will always be the team leader. - """ - # Ignore duplicate members - members = list(unique_everseen(members)) - - # We had a little issue during Code Jam 4 here, the greedy converter did it's job - # and ignored anything which wasn't a valid argument which left us with teams of - # two members or at some times even 1 member. This fixes that by checking that there - # are always 3 members in the members list. - if len(members) < 3: - await ctx.send( - ":no_entry_sign: One of your arguments was invalid\n" - f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" - " members" - ) - return - - team_channel = await self.create_channels(ctx.guild, team_name, members) - await self.add_roles(ctx.guild, members) - - await ctx.send( - f":ok_hand: Team created: {team_channel}\n" - f"**Team Leader:** {members[0].mention}\n" - f"**Team Members:** {' '.join(member.mention for member in members[1:])}" - ) - - async def get_category(self, guild: Guild) -> CategoryChannel: - """ - Return a code jam category. - - If all categories are full or none exist, create a new category. - """ - for category in guild.categories: - # Need 2 available spaces: one for the text channel and one for voice. - if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: - return category - - return await self.create_category(guild) - - @staticmethod - async def create_category(guild: Guild) -> CategoryChannel: - """Create a new code jam category and return it.""" - log.info("Creating a new code jam category.") - - category_overwrites = { - guild.default_role: PermissionOverwrite(read_messages=False), - guild.me: PermissionOverwrite(read_messages=True) - } - - return await guild.create_category_channel( - CATEGORY_NAME, - overwrites=category_overwrites, - reason="It's code jam time!" - ) - - @staticmethod - def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: - """Get code jam team channels permission overwrites.""" - # First member is always the team leader - team_channel_overwrites = { - members[0]: PermissionOverwrite( - manage_messages=True, - read_messages=True, - manage_webhooks=True, - connect=True - ), - guild.default_role: PermissionOverwrite(read_messages=False, connect=False), - guild.get_role(Roles.verified): PermissionOverwrite( - read_messages=False, - connect=False - ) - } - - # Rest of members should just have read_messages - for member in members[1:]: - team_channel_overwrites[member] = PermissionOverwrite( - read_messages=True, - connect=True - ) - - return team_channel_overwrites - - async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: - """Create team text and voice channels. Return the mention for the text channel.""" - # Get permission overwrites and category - team_channel_overwrites = self.get_overwrites(members, guild) - code_jam_category = await self.get_category(guild) - - # Create a text channel for the team - team_channel = await guild.create_text_channel( - team_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - # Create a voice channel for the team - team_voice_name = " ".join(team_name.split("-")).title() - - await guild.create_voice_channel( - team_voice_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - return team_channel.mention - - @staticmethod - async def add_roles(guild: Guild, members: t.List[Member]) -> None: - """Assign team leader and jammer roles.""" - # Assign team leader role - await members[0].add_roles(guild.get_role(Roles.team_leaders)) - - # Assign rest of roles - jammer_role = guild.get_role(Roles.jammers) - for member in members: - await member.add_roles(jammer_role) - - -def setup(bot: Bot) -> None: - """Load the CodeJams cog.""" - bot.add_cog(CodeJams(bot)) diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py deleted file mode 100644 index 94fa2b139..000000000 --- a/bot/cogs/logging.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging - -from discord import Embed -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, DEBUG_MODE - - -log = logging.getLogger(__name__) - - -class Logging(Cog): - """Debug logging module.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.bot.loop.create_task(self.startup_greeting()) - - async def startup_greeting(self) -> None: - """Announce our presence to the configured devlog channel.""" - await self.bot.wait_until_guild_available() - log.info("Bot connected!") - - embed = Embed(description="Connected!") - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=( - "https://raw.githubusercontent.com/" - "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" - ) - ) - - if not DEBUG_MODE: - await self.bot.get_channel(Channels.dev_log).send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the Logging cog.""" - bot.add_cog(Logging(bot)) diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 995187ef0..aad1f3c26 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,11 +1,11 @@ from bot.bot import Bot from .incidents import Incidents -from .infractions import Infractions -from .management import ModManagement +from .infraction.infractions import Infractions +from .infraction.management import ModManagement +from .infraction.superstarify import Superstarify from .modlog import ModLog from .silence import Silence from .slowmode import Slowmode -from .superstarify import Superstarify def setup(bot: Bot) -> None: diff --git a/bot/cogs/moderation/defcon.py b/bot/cogs/moderation/defcon.py new file mode 100644 index 000000000..4c0ad5914 --- /dev/null +++ b/bot/cogs/moderation/defcon.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import logging +from collections import namedtuple +from datetime import datetime, timedelta +from enum import Enum + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles +from bot.decorators import with_role + +log = logging.getLogger(__name__) + +REJECTION_MESSAGE = """ +Hi, {user} - Thanks for your interest in our server! + +Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since +your account is relatively new, we're unable to provide access to the server at this time. + +Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation +will be resolved soon. In the meantime, please feel free to peruse the resources on our site at +, and have a nice day! +""" + +BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" + + +class Action(Enum): + """Defcon Action.""" + + ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) + + ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") + DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") + UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") + + +class Defcon(Cog): + """Time-sensitive server defense mechanisms.""" + + days = None # type: timedelta + enabled = False # type: bool + + def __init__(self, bot: Bot): + self.bot = bot + self.channel = None + self.days = timedelta(days=0) + + self.bot.loop.create_task(self.sync_settings()) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def sync_settings(self) -> None: + """On cog load, try to synchronize DEFCON settings to the API.""" + await self.bot.wait_until_guild_available() + self.channel = await self.bot.fetch_channel(Channels.defcon) + + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + except Exception: # Yikes! + log.exception("Unable to get DEFCON settings!") + await self.bot.get_channel(Channels.dev_log).send( + f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" + ) + + else: + if data["enabled"]: + self.enabled = True + self.days = timedelta(days=data["days"]) + log.info(f"DEFCON enabled: {self.days.days} days") + + else: + self.enabled = False + self.days = timedelta(days=0) + log.info("DEFCON disabled") + + await self.update_channel_topic() + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" + if self.enabled and self.days.days > 0: + now = datetime.utcnow() + + if now - member.created_at < self.days: + log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") + + message_sent = False + + try: + await member.send(REJECTION_MESSAGE.format(user=member.mention)) + + message_sent = True + except Exception: + log.exception(f"Unable to send rejection message to user: {member}") + + await member.kick(reason="DEFCON active, user is too new") + self.bot.stats.incr("defcon.leaves") + + message = ( + f"{member} (`{member.id}`) was denied entry because their account is too new." + ) + + if not message_sent: + message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." + + await self.mod_log.send_log_message( + Icons.defcon_denied, Colours.soft_red, "Entry denied", + message, member.avatar_url_as(static_format="png") + ) + + @group(name='defcon', aliases=('dc',), invoke_without_command=True) + @with_role(Roles.admins, Roles.owners) + async def defcon_group(self, ctx: Context) -> None: + """Check the DEFCON status or run a subcommand.""" + await ctx.send_help(ctx.command) + + async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: + """Providing a structured way to do an defcon action.""" + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + if "enable_date" in data and action is Action.DISABLED: + enabled = datetime.fromisoformat(data["enable_date"]) + + delta = datetime.now() - enabled + + self.bot.stats.timing("defcon.enabled", delta) + except Exception: + pass + + error = None + try: + await self.bot.api_client.put( + 'bot/bot-settings/defcon', + json={ + 'name': 'defcon', + 'data': { + # TODO: retrieve old days count + 'days': days, + 'enabled': action is not Action.DISABLED, + 'enable_date': datetime.now().isoformat() + } + } + ) + except Exception as err: + log.exception("Unable to update DEFCON settings.") + error = err + finally: + await ctx.send(self.build_defcon_msg(action, error)) + await self.send_defcon_log(action, ctx.author, error) + + self.bot.stats.gauge("defcon.threshold", days) + + @defcon_group.command(name='enable', aliases=('on', 'e')) + @with_role(Roles.admins, Roles.owners) + async def enable_command(self, ctx: Context) -> None: + """ + Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! + + Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, + in days. + """ + self.enabled = True + await self._defcon_action(ctx, days=0, action=Action.ENABLED) + await self.update_channel_topic() + + @defcon_group.command(name='disable', aliases=('off', 'd')) + @with_role(Roles.admins, Roles.owners) + async def disable_command(self, ctx: Context) -> None: + """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" + self.enabled = False + await self._defcon_action(ctx, days=0, action=Action.DISABLED) + await self.update_channel_topic() + + @defcon_group.command(name='status', aliases=('s',)) + @with_role(Roles.admins, Roles.owners) + async def status_command(self, ctx: Context) -> None: + """Check the current status of DEFCON mode.""" + embed = Embed( + colour=Colour.blurple(), title="DEFCON Status", + description=f"**Enabled:** {self.enabled}\n" + f"**Days:** {self.days.days}" + ) + + await ctx.send(embed=embed) + + @defcon_group.command(name='days') + @with_role(Roles.admins, Roles.owners) + async def days_command(self, ctx: Context, days: int) -> None: + """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" + self.days = timedelta(days=days) + self.enabled = True + await self._defcon_action(ctx, days=days, action=Action.UPDATED) + await self.update_channel_topic() + + async def update_channel_topic(self) -> None: + """Update the #defcon channel topic with the current DEFCON status.""" + if self.enabled: + day_str = "days" if self.days.days > 1 else "day" + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" + else: + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" + + self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) + await self.channel.edit(topic=new_topic) + + def build_defcon_msg(self, action: Action, e: Exception = None) -> str: + """Build in-channel response string for DEFCON action.""" + if action is Action.ENABLED: + msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" + elif action is Action.DISABLED: + msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" + elif action is Action.UPDATED: + msg = ( + f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " + f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" + ) + + if e: + msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + return msg + + async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: + """Send log message for DEFCON action.""" + info = action.value + log_msg: str = ( + f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" + f"{info.template.format(days=self.days.days)}" + ) + status_msg = f"DEFCON {action.name.lower()}" + + if e: + log_msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) + + +def setup(bot: Bot) -> None: + """Load the Defcon cog.""" + bot.add_cog(Defcon(bot)) diff --git a/bot/cogs/moderation/infraction/__init__.py b/bot/cogs/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/moderation/infraction/infractions.py b/bot/cogs/moderation/infraction/infractions.py new file mode 100644 index 000000000..8df642428 --- /dev/null +++ b/bot/cogs/moderation/infraction/infractions.py @@ -0,0 +1,370 @@ +import logging +import textwrap +import typing as t + +import discord +from discord import Member +from discord.ext import commands +from discord.ext.commands import Context, command + +from bot import constants +from bot.bot import Bot +from bot.constants import Event +from bot.converters import Expiry, FetchedMember +from bot.decorators import respect_role_hierarchy +from bot.utils.checks import with_role_check +from . import utils +from .scheduler import InfractionScheduler +from .utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class Infractions(InfractionScheduler, commands.Cog): + """Apply and pardon infractions on users for moderation purposes.""" + + category = "Moderation" + category_description = "Server moderation tools." + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) + + self.category = "Moderation" + self._muted_role = discord.Object(constants.Roles.muted) + + @commands.Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active mute infractions for returning members.""" + active_mutes = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "mute", + "user__id": member.id + } + ) + + if active_mutes: + reason = f"Re-applying active mute: {active_mutes[0]['id']}" + action = member.add_roles(self._muted_role, reason=reason) + + await self.reapply_infraction(active_mutes[0], action) + + # region: Permanent infractions + + @command() + async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Warn a user for the given reason.""" + infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command() + async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason.""" + await self.apply_kick(ctx, user, reason) + + @command() + async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason and stop watching them with Big Brother.""" + await self.apply_ban(ctx, user, reason) + + # endregion + # region: Temporary infractions + + @command(aliases=["mute"]) + async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: + """ + Temporarily mute a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration) + + @command() + async def tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration) + + # endregion + # region: Permanent shadow infractions + + @command(hidden=True) + async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Create a private note for a user with the given reason without notifying the user.""" + infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command(hidden=True, aliases=['shadowkick', 'skick']) + async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason without notifying the user.""" + await self.apply_kick(ctx, user, reason, hidden=True) + + @command(hidden=True, aliases=['shadowban', 'sban']) + async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason without notifying the user.""" + await self.apply_ban(ctx, user, reason, hidden=True) + + # endregion + # region: Temporary shadow infractions + + @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) + async def shadow_tempmute( + self, ctx: Context, + user: Member, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily mute a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) + + @command(hidden=True, aliases=["shadowtempban, stempban"]) + async def shadow_tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) + + # endregion + # region: Remove infractions (un- commands) + + @command() + async def unmute(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active mute infraction for the user.""" + await self.pardon_infraction(ctx, "mute", user) + + @command() + async def unban(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active ban infraction for the user.""" + await self.pardon_infraction(ctx, "ban", user) + + # endregion + # region: Base apply functions + + async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a mute infraction with kwargs passed to `post_infraction`.""" + if await utils.get_active_infraction(ctx, user, "mute"): + return + + infraction = await utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_update, user.id) + + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) + + @respect_role_hierarchy() + async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a kick infraction with kwargs passed to `post_infraction`.""" + infraction = await utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = user.kick(reason=reason) + await self.apply_infraction(ctx, infraction, user, action) + + @respect_role_hierarchy() + async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: + """ + Apply a ban infraction with kwargs passed to `post_infraction`. + + Will also remove the banned user from the Big Brother watch list if applicable. + """ + # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active + is_temporary = kwargs.get("expires_at") is not None + active_infraction = await utils.get_active_infraction(ctx, user, "ban", is_temporary) + + if active_infraction: + if is_temporary: + log.trace("Tempban ignored as it cannot overwrite an active ban.") + return + + if active_infraction.get('expires_at') is None: + log.trace("Permaban already exists, notify.") + await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") + return + + log.trace("Old tempban is being replaced by new permaban.") + await self.pardon_infraction(ctx, "ban", user, is_temporary) + + infraction = await utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = ctx.guild.ban(user, reason=reason, delete_message_days=0) + await self.apply_infraction(ctx, infraction, user, action) + + if infraction.get('expires_at') is not None: + log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") + return + + bb_cog = self.bot.get_cog("Big Brother") + if not bb_cog: + log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") + return + + log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") + + bb_reason = "User has been permanently banned from the server. Automatically removed." + await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) + + # endregion + # region: Base pardon functions + + async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's muted role, DM them a notification, and return a log dict.""" + user = guild.get_member(user_id) + log_text = {} + + if user: + # Remove the muted role. + self.mod_log.ignore(Event.member_update, user.id) + await user.remove_roles(self._muted_role, reason=reason) + + # DM the user about the expiration. + notified = await utils.notify_pardon( + user=user, + title="You have been unmuted", + content="You may now send messages in the server.", + icon_url=utils.INFRACTION_ICONS["mute"][1] + ) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["DM"] = "Sent" if notified else "**Failed**" + else: + log.info(f"Failed to unmute user {user_id}: user not found") + log_text["Failure"] = "User was not found in the guild." + + return log_text + + async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's ban on the Discord guild and return a log dict.""" + user = discord.Object(user_id) + log_text = {} + + self.mod_log.ignore(Event.member_unban, user_id) + + try: + await guild.unban(user, reason=reason) + except discord.NotFound: + log.info(f"Failed to unban user {user_id}: no active ban found on Discord") + log_text["Note"] = "No active ban found on Discord." + + return log_text + + async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + guild = self.bot.get_guild(constants.Guild.id) + user_id = infraction["user"] + reason = f"Infraction #{infraction['id']} expired or was pardoned." + + if infraction["type"] == "mute": + return await self.pardon_mute(user_id, guild, reason) + elif infraction["type"] == "ban": + return await self.pardon_ban(user_id, guild, reason) + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters or discord.Member in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True diff --git a/bot/cogs/moderation/infraction/management.py b/bot/cogs/moderation/infraction/management.py new file mode 100644 index 000000000..791585b6e --- /dev/null +++ b/bot/cogs/moderation/infraction/management.py @@ -0,0 +1,305 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext import commands +from discord.ext.commands import Context + +from bot import constants +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user +from bot.pagination import LinePaginator +from bot.utils import time +from bot.utils.checks import in_whitelist_check, with_role_check +from . import utils +from .infractions import Infractions + +log = logging.getLogger(__name__) + + +class ModManagement(commands.Cog): + """Management of infractions.""" + + category = "Moderation" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @property + def infractions_cog(self) -> Infractions: + """Get currently loaded Infractions cog instance.""" + return self.bot.get_cog("Infractions") + + # region: Edit infraction commands + + @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) + async def infraction_group(self, ctx: Context) -> None: + """Infraction manipulation commands.""" + await ctx.send_help(ctx.command) + + @infraction_group.command(name='edit') + async def infraction_edit( + self, + ctx: Context, + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 + duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 + *, + reason: str = None + ) -> None: + """ + Edit the duration and/or the reason of an infraction. + + Durations are relative to the time of updating and should be appended with a unit of time. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. + """ + if duration is None and reason is None: + # Unlike UserInputError, the error handler will show a specified message for BadArgument + raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") + + # Retrieve the previous infraction for its information. + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get("bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + ":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") + + request_data = {} + confirm_messages = [] + log_text = "" + + if duration is not None and not old_infraction['active']: + if reason is None: + await ctx.send(":x: Cannot edit the expiration of an expired infraction.") + return + confirm_messages.append("expiry unchanged (infraction already expired)") + elif isinstance(duration, str): + request_data['expires_at'] = None + confirm_messages.append("marked as permanent") + elif duration is not None: + request_data['expires_at'] = duration.isoformat() + expiry = time.format_infraction_with_duration(request_data['expires_at']) + confirm_messages.append(f"set to expire on {expiry}") + else: + confirm_messages.append("expiry unchanged") + + if reason: + request_data['reason'] = reason + confirm_messages.append("set a new reason") + log_text += f""" + Previous reason: {old_infraction['reason']} + New reason: {reason} + """.rstrip() + else: + confirm_messages.append("reason unchanged") + + # Update the infraction + new_infraction = await self.bot.api_client.patch( + f'bot/infractions/{infraction_id}', + json=request_data, + ) + + # Re-schedule infraction if the expiration has been updated + if 'expires_at' in request_data: + # A scheduled task should only exist if the old infraction wasn't permanent + if old_infraction['expires_at']: + self.infractions_cog.scheduler.cancel(new_infraction['id']) + + # If the infraction was not marked as permanent, schedule a new expiration task + if request_data['expires_at']: + self.infractions_cog.schedule_expiration(new_infraction) + + log_text += f""" + Previous expiry: {old_infraction['expires_at'] or "Permanent"} + New expiry: {new_infraction['expires_at'] or "Permanent"} + """.rstrip() + + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") + + # Get information about the infraction's user + user_id = new_infraction['user'] + user = ctx.guild.get_member(user_id) + + if user: + user_text = f"{user.mention} (`{user.id}`)" + thumbnail = user.avatar_url_as(static_format="png") + else: + user_text = f"`{user_id}`" + thumbnail = None + + # The infraction's actor + actor_id = new_infraction['actor'] + actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" + + await self.mod_log.send_log_message( + icon_url=constants.Icons.pencil, + colour=discord.Colour.blurple(), + title="Infraction edited", + thumbnail=thumbnail, + text=textwrap.dedent(f""" + Member: {user_text} + Actor: {actor} + Edited by: {ctx.message.author}{log_text} + """) + ) + + # endregion + # region: Search infractions + + @infraction_group.group(name="search", invoke_without_command=True) + async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: + """Searches for infractions in the database.""" + if isinstance(query, discord.User): + await ctx.invoke(self.search_user, query) + else: + await ctx.invoke(self.search_reason, query) + + @infraction_search_group.command(name="user", aliases=("member", "id")) + async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: + """Search for infractions by member.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'user__id': str(user.id)} + ) + embed = discord.Embed( + title=f"Infractions for {user} ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) + async def search_reason(self, ctx: Context, reason: str) -> None: + """Search for infractions by their reason. Use Re2 for matching.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'search': reason} + ) + embed = discord.Embed( + title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + # endregion + # region: Utility functions + + async def send_infraction_list( + self, + ctx: Context, + embed: discord.Embed, + infractions: t.Iterable[utils.Infraction] + ) -> None: + """Send a paginated embed of infractions for the specified user.""" + if not infractions: + await ctx.send(":warning: No infractions could be found for that query.") + return + + lines = tuple( + self.infraction_to_string(infraction) + for infraction in infractions + ) + + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + def infraction_to_string(self, infraction: utils.Infraction) -> str: + """Convert the infraction object to a string representation.""" + actor_id = infraction["actor"] + guild = self.bot.get_guild(constants.Guild.id) + actor = guild.get_member(actor_id) + active = infraction["active"] + user_id = infraction["user"] + hidden = infraction["hidden"] + created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + + if infraction["expires_at"] is None: + expires = "*Permanent*" + else: + date_from = datetime.strptime(created, time.INFRACTION_FORMAT) + expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) + + lines = textwrap.dedent(f""" + {"**===============**" if active else "==============="} + Status: {"__**Active**__" if active else "Inactive"} + User: {self.bot.get_user(user_id)} (`{user_id}`) + Type: **{infraction["type"]}** + Shadow: {hidden} + Created: {created} + Expires: {expires} + Remaining: {remaining} + Actor: {actor.mention if actor else actor_id} + ID: `{infraction["id"]}` + Reason: {infraction["reason"] or "*None*"} + {"**===============**" if active else "==============="} + """) + + return lines.strip() + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators inside moderator channels to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=constants.MODERATION_CHANNELS, + categories=[constants.Categories.modmail], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True diff --git a/bot/cogs/moderation/infraction/scheduler.py b/bot/cogs/moderation/infraction/scheduler.py new file mode 100644 index 000000000..b3d27fe76 --- /dev/null +++ b/bot/cogs/moderation/infraction/scheduler.py @@ -0,0 +1,463 @@ +import logging +import textwrap +import typing as t +from abc import abstractmethod +from datetime import datetime +from gettext import ngettext + +import dateutil.parser +import discord +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.constants import Colours, STAFF_CHANNELS +from bot.utils import time +from bot.utils.scheduling import Scheduler +from . import utils +from .utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class InfractionScheduler: + """Handles the application, pardoning, and expiration of infractions.""" + + def __init__(self, bot: Bot, supported_infractions: t.Container[str]): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + @property + def mod_log(self) -> ModLog: + """Get the currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: + """Schedule expiration for previous infractions.""" + await self.bot.wait_until_guild_available() + + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={'active': 'true'} + ) + for infraction in infractions: + if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: + self.schedule_expiration(infraction) + + async def reapply_infraction( + self, + infraction: utils.Infraction, + apply_coro: t.Optional[t.Awaitable] + ) -> None: + """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" + # Calculate the time remaining, in seconds, for the mute. + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + delta = (expiry - datetime.utcnow()).total_seconds() + + # Mark as inactive if less than a minute remains. + if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) + await self.deactivate_infraction(infraction) + return + + # Allowing mod log since this is a passive action that should be logged. + await apply_coro + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + + async def apply_infraction( + self, + ctx: Context, + infraction: utils.Infraction, + user: UserSnowflake, + action_coro: t.Optional[t.Awaitable] = None + ) -> None: + """Apply an infraction to the user, log the infraction, and optionally notify the user.""" + infr_type = infraction["type"] + icon = utils.INFRACTION_ICONS[infr_type][0] + reason = infraction["reason"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") + + # Default values for the confirmation message and mod log. + confirm_msg = ":ok_hand: applied" + + # Specifying an expiry for a note or warning makes no sense. + if infr_type in ("note", "warning"): + expiry_msg = "" + else: + expiry_msg = f" until {expiry}" if expiry else " permanently" + + dm_result = "" + dm_log_text = "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" + log_title = "applied" + log_content = None + failed = False + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # This needs to happen before we apply the infraction, as the bot cannot + # send DMs to user that it doesn't share a guild with. If we were to + # apply kick/ban infractions first, this would mean that we'd make it + # impossible for us to deliver a DM. See python-discord/bot#982. + if not infraction["hidden"]: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + end_msg = "" + if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + else: + log.trace(f"Fetching total infraction count for {user}.") + + infractions = await self.bot.api_client.get( + "bot/infractions", + params={"user__id": str(user.id)} + ) + total = len(infractions) + end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" + + # Execute the necessary actions to apply the infraction on Discord. + if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") + try: + await action_coro + if expiry: + # Schedule the expiration of the infraction. + self.schedule_expiration(infraction) + except discord.HTTPException as e: + # Accordingly display that applying the infraction failed. + confirm_msg = ":x: failed to apply" + expiry_msg = "" + log_content = ctx.author.mention + log_title = "failed to apply" + + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + failed = True + + if failed: + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" + + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") + + # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=icon, + colour=Colours.soft_red, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {user.mention} (`{user.id}`) + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} + Reason: {reason} + """), + content=log_content, + footer=f"ID {infraction['id']}" + ) + + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + + async def pardon_infraction( + self, + ctx: Context, + infr_type: str, + user: UserSnowflake, + send_msg: bool = True + ) -> None: + """ + Prematurely end an infraction for a user and log the action in the mod log. + + If `send_msg` is True, then a pardoning confirmation message will be sent to + the context channel. Otherwise, no such message will be sent. + """ + log.trace(f"Pardoning {infr_type} infraction for {user}.") + + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") + response = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': user.id + } + ) + + if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") + await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") + return + + # Deactivate the infraction and cancel its scheduled expiration task. + log_text = await self.deactivate_infraction(response[0], send_log=False) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["Actor"] = str(ctx.message.author) + log_content = None + id_ = response[0]['id'] + footer = f"ID: {id_}" + + # If multiple active infractions were found, mark them as inactive in the database + # and cancel their expiration tasks. + if len(response) > 1: + log.info( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) + + footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" + + log_note = f"Found multiple **active** {infr_type} infractions in the database." + if "Note" in log_text: + log_text["Note"] = f" {log_note}" + else: + log_text["Note"] = log_note + + # deactivate_infraction() is not called again because: + # 1. Discord cannot store multiple active bans or assign multiples of the same role + # 2. It would send a pardon DM for each active infraction, which is redundant + for infraction in response[1:]: + id_ = infraction['id'] + try: + # Mark infraction as inactive in the database. + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError: + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") + # This is simpler and cleaner than trying to concatenate all the errors. + log_text["Failure"] = "See bot's logs for details." + + # Cancel pending expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Accordingly display whether the user was successfully notified via DM. + dm_emoji = "" + if log_text.get("DM") == "Sent": + dm_emoji = ":incoming_envelope: " + elif "DM" in log_text: + dm_emoji = f"{constants.Emojis.failmail} " + + # Accordingly display whether the pardon failed. + if "Failure" in log_text: + confirm_msg = ":x: failed to pardon" + log_title = "pardon failed" + log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") + else: + confirm_msg = ":ok_hand: pardoned" + log_title = "pardoned" + + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + + # Send a confirmation message to the invoking context. + if send_msg: + log.trace(f"Sending infraction #{id_} pardon confirmation message.") + await ctx.send( + f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " + f"{log_text.get('Failure', '')}" + ) + + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + + # Send a log message to the mod log. + await self.mod_log.send_log_message( + icon_url=utils.INFRACTION_ICONS[infr_type][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=footer, + content=log_content, + ) + + async def deactivate_infraction( + self, + infraction: utils.Infraction, + send_log: bool = True + ) -> t.Dict[str, str]: + """ + Deactivate an active infraction and return a dictionary of lines to send in a mod log. + + The infraction is removed from Discord, marked as inactive in the database, and has its + expiration task cancelled. If `send_log` is True, a mod log is sent for the + deactivation of the infraction. + + Infractions of unsupported types will raise a ValueError. + """ + guild = self.bot.get_guild(constants.Guild.id) + mod_role = guild.get_role(constants.Roles.moderators) + user_id = infraction["user"] + actor = infraction["actor"] + type_ = infraction["type"] + id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] + + log.info(f"Marking infraction #{id_} as inactive (expired).") + + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + + log_content = None + log_text = { + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, + } + + try: + log.trace("Awaiting the pardon action coroutine.") + returned_log = await self._pardon_action(infraction) + + if returned_log is not None: + log_text = {**log_text, **returned_log} # Merge the logs together + else: + raise ValueError( + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" + ) + except discord.Forbidden: + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") + log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" + log_content = mod_role.mention + except discord.HTTPException as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." + log_content = mod_role.mention + + # Check if the user is currently being watched by Big Brother. + try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + + active_watch = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "watch", + "user__id": user_id + } + ) + + log_text["Watching"] = "Yes" if active_watch else "No" + except ResponseCodeError: + log.exception(f"Failed to fetch watch status for user {user_id}") + log_text["Watching"] = "Unknown - failed to fetch watch status." + + try: + # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_line = f"API request failed with code {e.status}." + log_content = mod_role.mention + + # Append to an existing failure message if possible + if "Failure" in log_text: + log_text["Failure"] += f" {log_line}" + else: + log_text["Failure"] = log_line + + # Cancel the expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Send a log message to the mod log. + if send_log: + log_title = "expiration failed" if "Failure" in log_text else "expired" + + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + + log.trace(f"Sending deactivation mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=utils.INFRACTION_ICONS[type_][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=f"ID: {id_}", + content=log_content, + ) + + return log_text + + @abstractmethod + async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + raise NotImplementedError + + def schedule_expiration(self, infraction: utils.Infraction) -> None: + """ + Marks an infraction expired after the delay from time of scheduling to time of expiration. + + At the time of expiration, the infraction is marked as inactive on the website and the + expiration task is cancelled. + """ + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/superstarify.py b/bot/cogs/moderation/infraction/superstarify.py new file mode 100644 index 000000000..867de815a --- /dev/null +++ b/bot/cogs/moderation/infraction/superstarify.py @@ -0,0 +1,239 @@ +import json +import logging +import random +import textwrap +import typing as t +from pathlib import Path + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.converters import Expiry +from bot.utils.checks import with_role_check +from bot.utils.time import format_infraction +from . import utils +from .scheduler import InfractionScheduler + +log = logging.getLogger(__name__) +NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" + +with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: + STAR_NAMES = json.load(stars_file) + + +class Superstarify(InfractionScheduler, Cog): + """A set of commands to moderate terrible nicknames.""" + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"superstar"}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Revert nickname edits if the user has an active superstarify infraction.""" + if before.display_name == after.display_name: + return # User didn't change their nickname. Abort! + + log.trace( + f"{before} ({before.display_name}) is trying to change their nickname to " + f"{after.display_name}. Checking if the user is in superstar-prison..." + ) + + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": str(before.id) + } + ) + + if not active_superstarifies: + log.trace(f"{before} has no active superstar infractions.") + return + + infraction = active_superstarifies[0] + forced_nick = self.get_nick(infraction["id"], before.id) + if after.display_name == forced_nick: + return # Nick change was triggered by this event. Ignore. + + log.info( + f"{after.display_name} ({after.id}) tried to escape superstar prison. " + f"Changing the nick back to {before.display_name}." + ) + await after.edit( + nick=forced_nick, + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + notified = await utils.notify_infraction( + user=after, + infr_type="Superstarify", + expires_at=format_infraction(infraction["expires_at"]), + reason=( + "You have tried to change your nickname on the **Python Discord** server " + f"from **{before.display_name}** to **{after.display_name}**, but as you " + "are currently in superstar-prison, you do not have permission to do so." + ), + icon_url=utils.INFRACTION_ICONS["superstar"][0] + ) + + if not notified: + log.info("Failed to DM user about why they cannot change their nickname.") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active superstar infractions for returning members.""" + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": member.id + } + ) + + if active_superstarifies: + infraction = active_superstarifies[0] + action = member.edit( + nick=self.get_nick(infraction["id"], member.id), + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + await self.reapply_infraction(infraction, action) + + @command(name="superstarify", aliases=("force_nick", "star")) + async def superstarify( + self, + ctx: Context, + member: Member, + duration: Expiry, + *, + reason: str = None, + ) -> None: + """ + Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + + An optional reason can be provided. If no reason is given, the original name will be shown + in a generated reason. + """ + if await utils.get_active_infraction(ctx, member, "superstar"): + return + + # Post the infraction to the API + reason = reason or f"old nick: {member.display_name}" + infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) + id_ = infraction["id"] + + old_nick = member.display_name + forced_nick = self.get_nick(id_, member.id) + expiry_str = format_infraction(infraction["expires_at"]) + + # Apply the infraction and schedule the expiration task. + log.debug(f"Changing nickname of {member} to {forced_nick}.") + self.mod_log.ignore(constants.Event.member_update, member.id) + await member.edit(nick=forced_nick, reason=reason) + self.schedule_expiration(infraction) + + # Send a DM to the user to notify them of their new infraction. + await utils.notify_infraction( + user=member, + infr_type="Superstarify", + expires_at=expiry_str, + icon_url=utils.INFRACTION_ICONS["superstar"][0], + reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." + ) + + # Send an embed with the infraction information to the invoking context. + log.trace(f"Sending superstar #{id_} embed.") + embed = Embed( + title="Congratulations!", + colour=constants.Colours.soft_orange, + description=( + f"Your previous nickname, **{old_nick}**, " + f"was so bad that we have decided to change it. " + f"Your new nickname will be **{forced_nick}**.\n\n" + f"You will be unable to change your nickname until **{expiry_str}**.\n\n" + "If you're confused by this, please read our " + f"[official nickname policy]({NICKNAME_POLICY_URL})." + ) + ) + await ctx.send(embed=embed) + + # Log to the mod log channel. + log.trace(f"Sending apply mod log for superstar #{id_}.") + await self.mod_log.send_log_message( + icon_url=utils.INFRACTION_ICONS["superstar"][0], + colour=Colour.gold(), + title="Member achieved superstardom", + thumbnail=member.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {member.mention} (`{member.id}`) + Actor: {ctx.message.author} + Expires: {expiry_str} + Old nickname: `{old_nick}` + New nickname: `{forced_nick}` + Reason: {reason} + """), + footer=f"ID {id_}" + ) + + @command(name="unsuperstarify", aliases=("release_nick", "unstar")) + async def unsuperstarify(self, ctx: Context, member: Member) -> None: + """Remove the superstarify infraction and allow the user to change their nickname.""" + await self.pardon_infraction(ctx, "superstar", member) + + async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """Pardon a superstar infraction and return a log dict.""" + if infraction["type"] != "superstar": + return + + guild = self.bot.get_guild(constants.Guild.id) + user = guild.get_member(infraction["user"]) + + # Don't bother sending a notification if the user left the guild. + if not user: + log.debug( + "User left the guild and therefore won't be notified about superstar " + f"{infraction['id']} pardon." + ) + return {} + + # DM the user about the expiration. + notified = await utils.notify_pardon( + user=user, + title="You are no longer superstarified", + content="You may now change your nickname on the server.", + icon_url=utils.INFRACTION_ICONS["superstar"][1] + ) + + return { + "Member": f"{user.mention}(`{user.id}`)", + "DM": "Sent" if notified else "**Failed**" + } + + @staticmethod + def get_nick(infraction_id: int, member_id: int) -> str: + """Randomly select a nickname from the Superstarify nickname list.""" + log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") + + rng = random.Random(str(infraction_id) + str(member_id)) + return rng.choice(STAR_NAMES) + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) diff --git a/bot/cogs/moderation/infraction/utils.py b/bot/cogs/moderation/infraction/utils.py new file mode 100644 index 000000000..fb55287b6 --- /dev/null +++ b/bot/cogs/moderation/infraction/utils.py @@ -0,0 +1,201 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext.commands import Context + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons + +log = logging.getLogger(__name__) + +# apply icon, pardon icon +INFRACTION_ICONS = { + "ban": (Icons.user_ban, Icons.user_unban), + "kick": (Icons.sign_out, None), + "mute": (Icons.user_mute, Icons.user_unmute), + "note": (Icons.user_warn, None), + "superstar": (Icons.superstarify, Icons.unsuperstarify), + "warning": (Icons.user_warn, None), +} +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEALABLE_INFRACTIONS = ("ban", "mute") + +# Type aliases +UserObject = t.Union[discord.Member, discord.User] +UserSnowflake = t.Union[UserObject, discord.Object] +Infraction = t.Dict[str, t.Union[str, int, bool]] + + +async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: + """ + Create a new user in the database. + + Used when an infraction needs to be applied on a user absent in the guild. + """ + log.trace(f"Attempting to add user {user.id} to the database.") + + if not isinstance(user, (discord.Member, discord.User)): + log.debug("The user being added to the DB is not a Member or User object.") + + payload = { + 'discriminator': int(getattr(user, 'discriminator', 0)), + 'id': user.id, + 'in_guild': False, + 'name': getattr(user, 'name', 'Name unknown'), + 'roles': [] + } + + try: + response = await ctx.bot.api_client.post('bot/users', json=payload) + log.info(f"User {user.id} added to the DB.") + return response + except ResponseCodeError as e: + log.error(f"Failed to add user {user.id} to the DB. {e}") + await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") + + +async def post_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + reason: str, + expires_at: datetime = None, + hidden: bool = False, + active: bool = True +) -> t.Optional[dict]: + """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + + payload = { + "actor": ctx.message.author.id, + "hidden": hidden, + "reason": reason, + "type": infr_type, + "user": user.id, + "active": active + } + if expires_at: + payload['expires_at'] = expires_at.isoformat() + + # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. + for should_post_user in (True, False): + try: + response = await ctx.bot.api_client.post('bot/infractions', json=payload) + return response + except ResponseCodeError as e: + if e.status == 400 and 'user' in e.response_json: + # Only one attempt to add the user to the database, not two: + if not should_post_user or await post_user(ctx, user) is None: + return + else: + log.exception(f"Unexpected error while adding an infraction for {user}:") + await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") + return + + +async def get_active_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + send_msg: bool = True +) -> t.Optional[dict]: + """ + Retrieves an active infraction of the given type for the user. + + If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, + then a message for the moderator will be sent to the context channel letting them know. + Otherwise, no message will be sent. + """ + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + + active_infractions = await ctx.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': str(user.id) + } + ) + if active_infractions: + # Checks to see if the moderator should be told there is an active infraction + if send_msg: + log.trace(f"{user} has active infractions of type {infr_type}.") + await ctx.send( + f":x: According to my records, this user already has a {infr_type} infraction. " + f"See infraction **#{active_infractions[0]['id']}**." + ) + return active_infractions[0] + else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") + + +async def notify_infraction( + user: UserObject, + infr_type: str, + expires_at: t.Optional[str] = None, + reason: t.Optional[str] = None, + icon_url: str = Icons.token_removed +) -> bool: + """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + + embed = discord.Embed( + description=textwrap.shorten(text, width=2048, placeholder="..."), + colour=Colours.soft_red + ) + + embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) + embed.title = f"Please review our rules over at {RULES_URL}" + embed.url = RULES_URL + + if infr_type in APPEALABLE_INFRACTIONS: + embed.set_footer( + text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + ) + + return await send_private_embed(user, embed) + + +async def notify_pardon( + user: UserObject, + title: str, + content: str, + icon_url: str = Icons.user_verified +) -> bool: + """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + + embed = discord.Embed( + description=content, + colour=Colours.soft_green + ) + + embed.set_author(name=title, icon_url=icon_url) + + return await send_private_embed(user, embed) + + +async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: + """ + A helper method for sending an embed to a user's DMs. + + Returns a boolean indicator of DM success. + """ + try: + await user.send(embed=embed) + return True + except (discord.HTTPException, discord.Forbidden, discord.NotFound): + log.debug( + f"Infraction-related information could not be sent to user {user} ({user.id}). " + "The user either could not be retrieved or probably disabled their DMs." + ) + return False diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py deleted file mode 100644 index 8df642428..000000000 --- a/bot/cogs/moderation/infractions.py +++ /dev/null @@ -1,370 +0,0 @@ -import logging -import textwrap -import typing as t - -import discord -from discord import Member -from discord.ext import commands -from discord.ext.commands import Context, command - -from bot import constants -from bot.bot import Bot -from bot.constants import Event -from bot.converters import Expiry, FetchedMember -from bot.decorators import respect_role_hierarchy -from bot.utils.checks import with_role_check -from . import utils -from .scheduler import InfractionScheduler -from .utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class Infractions(InfractionScheduler, commands.Cog): - """Apply and pardon infractions on users for moderation purposes.""" - - category = "Moderation" - category_description = "Server moderation tools." - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) - - self.category = "Moderation" - self._muted_role = discord.Object(constants.Roles.muted) - - @commands.Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active mute infractions for returning members.""" - active_mutes = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "mute", - "user__id": member.id - } - ) - - if active_mutes: - reason = f"Re-applying active mute: {active_mutes[0]['id']}" - action = member.add_roles(self._muted_role, reason=reason) - - await self.reapply_infraction(active_mutes[0], action) - - # region: Permanent infractions - - @command() - async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Warn a user for the given reason.""" - infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command() - async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason.""" - await self.apply_kick(ctx, user, reason) - - @command() - async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason and stop watching them with Big Brother.""" - await self.apply_ban(ctx, user, reason) - - # endregion - # region: Temporary infractions - - @command(aliases=["mute"]) - async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: - """ - Temporarily mute a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration) - - @command() - async def tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration) - - # endregion - # region: Permanent shadow infractions - - @command(hidden=True) - async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Create a private note for a user with the given reason without notifying the user.""" - infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command(hidden=True, aliases=['shadowkick', 'skick']) - async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason without notifying the user.""" - await self.apply_kick(ctx, user, reason, hidden=True) - - @command(hidden=True, aliases=['shadowban', 'sban']) - async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason without notifying the user.""" - await self.apply_ban(ctx, user, reason, hidden=True) - - # endregion - # region: Temporary shadow infractions - - @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) - async def shadow_tempmute( - self, ctx: Context, - user: Member, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily mute a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) - - @command(hidden=True, aliases=["shadowtempban, stempban"]) - async def shadow_tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) - - # endregion - # region: Remove infractions (un- commands) - - @command() - async def unmute(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active mute infraction for the user.""" - await self.pardon_infraction(ctx, "mute", user) - - @command() - async def unban(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active ban infraction for the user.""" - await self.pardon_infraction(ctx, "ban", user) - - # endregion - # region: Base apply functions - - async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await utils.get_active_infraction(ctx, user, "mute"): - return - - infraction = await utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_update, user.id) - - async def action() -> None: - await user.add_roles(self._muted_role, reason=reason) - - log.trace(f"Attempting to kick {user} from voice because they've been muted.") - await user.move_to(None, reason=reason) - - await self.apply_infraction(ctx, infraction, user, action()) - - @respect_role_hierarchy() - async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a kick infraction with kwargs passed to `post_infraction`.""" - infraction = await utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = user.kick(reason=reason) - await self.apply_infraction(ctx, infraction, user, action) - - @respect_role_hierarchy() - async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: - """ - Apply a ban infraction with kwargs passed to `post_infraction`. - - Will also remove the banned user from the Big Brother watch list if applicable. - """ - # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active - is_temporary = kwargs.get("expires_at") is not None - active_infraction = await utils.get_active_infraction(ctx, user, "ban", is_temporary) - - if active_infraction: - if is_temporary: - log.trace("Tempban ignored as it cannot overwrite an active ban.") - return - - if active_infraction.get('expires_at') is None: - log.trace("Permaban already exists, notify.") - await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") - return - - log.trace("Old tempban is being replaced by new permaban.") - await self.pardon_infraction(ctx, "ban", user, is_temporary) - - infraction = await utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = ctx.guild.ban(user, reason=reason, delete_message_days=0) - await self.apply_infraction(ctx, infraction, user, action) - - if infraction.get('expires_at') is not None: - log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") - return - - bb_cog = self.bot.get_cog("Big Brother") - if not bb_cog: - log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") - return - - log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") - - bb_reason = "User has been permanently banned from the server. Automatically removed." - await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) - - # endregion - # region: Base pardon functions - - async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's muted role, DM them a notification, and return a log dict.""" - user = guild.get_member(user_id) - log_text = {} - - if user: - # Remove the muted role. - self.mod_log.ignore(Event.member_update, user.id) - await user.remove_roles(self._muted_role, reason=reason) - - # DM the user about the expiration. - notified = await utils.notify_pardon( - user=user, - title="You have been unmuted", - content="You may now send messages in the server.", - icon_url=utils.INFRACTION_ICONS["mute"][1] - ) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["DM"] = "Sent" if notified else "**Failed**" - else: - log.info(f"Failed to unmute user {user_id}: user not found") - log_text["Failure"] = "User was not found in the guild." - - return log_text - - async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's ban on the Discord guild and return a log dict.""" - user = discord.Object(user_id) - log_text = {} - - self.mod_log.ignore(Event.member_unban, user_id) - - try: - await guild.unban(user, reason=reason) - except discord.NotFound: - log.info(f"Failed to unban user {user_id}: no active ban found on Discord") - log_text["Note"] = "No active ban found on Discord." - - return log_text - - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - guild = self.bot.get_guild(constants.Guild.id) - user_id = infraction["user"] - reason = f"Infraction #{infraction['id']} expired or was pardoned." - - if infraction["type"] == "mute": - return await self.pardon_mute(user_id, guild, reason) - elif infraction["type"] == "ban": - return await self.pardon_ban(user_id, guild, reason) - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters or discord.Member in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py deleted file mode 100644 index 672bb0e9c..000000000 --- a/bot/cogs/moderation/management.py +++ /dev/null @@ -1,305 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext import commands -from discord.ext.commands import Context - -from bot import constants -from bot.bot import Bot -from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user -from bot.pagination import LinePaginator -from bot.utils import time -from bot.utils.checks import in_whitelist_check, with_role_check -from . import utils -from .infractions import Infractions -from .modlog import ModLog - -log = logging.getLogger(__name__) - - -class ModManagement(commands.Cog): - """Management of infractions.""" - - category = "Moderation" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @property - def infractions_cog(self) -> Infractions: - """Get currently loaded Infractions cog instance.""" - return self.bot.get_cog("Infractions") - - # region: Edit infraction commands - - @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) - async def infraction_group(self, ctx: Context) -> None: - """Infraction manipulation commands.""" - await ctx.send_help(ctx.command) - - @infraction_group.command(name='edit') - async def infraction_edit( - self, - ctx: Context, - infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 - duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 - *, - reason: str = None - ) -> None: - """ - Edit the duration and/or the reason of an infraction. - - Durations are relative to the time of updating and should be appended with a unit of time. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction - authored by the command invoker should be edited. - - Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 - timestamp can be provided for the duration. - """ - if duration is None and reason is None: - # Unlike UserInputError, the error handler will show a specified message for BadArgument - raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") - - # Retrieve the previous infraction for its information. - if isinstance(infraction_id, str): - params = { - "actor__id": ctx.author.id, - "ordering": "-inserted_at" - } - infractions = await self.bot.api_client.get("bot/infractions", params=params) - - if infractions: - old_infraction = infractions[0] - infraction_id = old_infraction["id"] - else: - await ctx.send( - ":x: Couldn't find most recent infraction; you have never given an infraction." - ) - return - else: - old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") - - request_data = {} - confirm_messages = [] - log_text = "" - - if duration is not None and not old_infraction['active']: - if reason is None: - await ctx.send(":x: Cannot edit the expiration of an expired infraction.") - return - confirm_messages.append("expiry unchanged (infraction already expired)") - elif isinstance(duration, str): - request_data['expires_at'] = None - confirm_messages.append("marked as permanent") - elif duration is not None: - request_data['expires_at'] = duration.isoformat() - expiry = time.format_infraction_with_duration(request_data['expires_at']) - confirm_messages.append(f"set to expire on {expiry}") - else: - confirm_messages.append("expiry unchanged") - - if reason: - request_data['reason'] = reason - confirm_messages.append("set a new reason") - log_text += f""" - Previous reason: {old_infraction['reason']} - New reason: {reason} - """.rstrip() - else: - confirm_messages.append("reason unchanged") - - # Update the infraction - new_infraction = await self.bot.api_client.patch( - f'bot/infractions/{infraction_id}', - json=request_data, - ) - - # Re-schedule infraction if the expiration has been updated - if 'expires_at' in request_data: - # A scheduled task should only exist if the old infraction wasn't permanent - if old_infraction['expires_at']: - self.infractions_cog.scheduler.cancel(new_infraction['id']) - - # If the infraction was not marked as permanent, schedule a new expiration task - if request_data['expires_at']: - self.infractions_cog.schedule_expiration(new_infraction) - - log_text += f""" - Previous expiry: {old_infraction['expires_at'] or "Permanent"} - New expiry: {new_infraction['expires_at'] or "Permanent"} - """.rstrip() - - changes = ' & '.join(confirm_messages) - await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") - - # Get information about the infraction's user - user_id = new_infraction['user'] - user = ctx.guild.get_member(user_id) - - if user: - user_text = f"{user.mention} (`{user.id}`)" - thumbnail = user.avatar_url_as(static_format="png") - else: - user_text = f"`{user_id}`" - thumbnail = None - - # The infraction's actor - actor_id = new_infraction['actor'] - actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" - - await self.mod_log.send_log_message( - icon_url=constants.Icons.pencil, - colour=discord.Colour.blurple(), - title="Infraction edited", - thumbnail=thumbnail, - text=textwrap.dedent(f""" - Member: {user_text} - Actor: {actor} - Edited by: {ctx.message.author}{log_text} - """) - ) - - # endregion - # region: Search infractions - - @infraction_group.group(name="search", invoke_without_command=True) - async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: - """Searches for infractions in the database.""" - if isinstance(query, discord.User): - await ctx.invoke(self.search_user, query) - else: - await ctx.invoke(self.search_reason, query) - - @infraction_search_group.command(name="user", aliases=("member", "id")) - async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: - """Search for infractions by member.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'user__id': str(user.id)} - ) - embed = discord.Embed( - title=f"Infractions for {user} ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) - async def search_reason(self, ctx: Context, reason: str) -> None: - """Search for infractions by their reason. Use Re2 for matching.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'search': reason} - ) - embed = discord.Embed( - title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - # endregion - # region: Utility functions - - async def send_infraction_list( - self, - ctx: Context, - embed: discord.Embed, - infractions: t.Iterable[utils.Infraction] - ) -> None: - """Send a paginated embed of infractions for the specified user.""" - if not infractions: - await ctx.send(":warning: No infractions could be found for that query.") - return - - lines = tuple( - self.infraction_to_string(infraction) - for infraction in infractions - ) - - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - def infraction_to_string(self, infraction: utils.Infraction) -> str: - """Convert the infraction object to a string representation.""" - actor_id = infraction["actor"] - guild = self.bot.get_guild(constants.Guild.id) - actor = guild.get_member(actor_id) - active = infraction["active"] - user_id = infraction["user"] - hidden = infraction["hidden"] - created = time.format_infraction(infraction["inserted_at"]) - - if active: - remaining = time.until_expiration(infraction["expires_at"]) or "Expired" - else: - remaining = "Inactive" - - if infraction["expires_at"] is None: - expires = "*Permanent*" - else: - date_from = datetime.strptime(created, time.INFRACTION_FORMAT) - expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) - - lines = textwrap.dedent(f""" - {"**===============**" if active else "==============="} - Status: {"__**Active**__" if active else "Inactive"} - User: {self.bot.get_user(user_id)} (`{user_id}`) - Type: **{infraction["type"]}** - Shadow: {hidden} - Created: {created} - Expires: {expires} - Remaining: {remaining} - Actor: {actor.mention if actor else actor_id} - ID: `{infraction["id"]}` - Reason: {infraction["reason"] or "*None*"} - {"**===============**" if active else "==============="} - """) - - return lines.strip() - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators inside moderator channels to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=constants.MODERATION_CHANNELS, - categories=[constants.Categories.modmail], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py deleted file mode 100644 index 75028d851..000000000 --- a/bot/cogs/moderation/scheduler.py +++ /dev/null @@ -1,463 +0,0 @@ -import logging -import textwrap -import typing as t -from abc import abstractmethod -from datetime import datetime -from gettext import ngettext - -import dateutil.parser -import discord -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Colours, STAFF_CHANNELS -from bot.utils import time -from bot.utils.scheduling import Scheduler -from . import utils -from .modlog import ModLog -from .utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class InfractionScheduler: - """Handles the application, pardoning, and expiration of infractions.""" - - def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - @property - def mod_log(self) -> ModLog: - """Get the currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: - """Schedule expiration for previous infractions.""" - await self.bot.wait_until_guild_available() - - log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") - - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={'active': 'true'} - ) - for infraction in infractions: - if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_expiration(infraction) - - async def reapply_infraction( - self, - infraction: utils.Infraction, - apply_coro: t.Optional[t.Awaitable] - ) -> None: - """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" - # Calculate the time remaining, in seconds, for the mute. - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - delta = (expiry - datetime.utcnow()).total_seconds() - - # Mark as inactive if less than a minute remains. - if delta < 60: - log.info( - "Infraction will be deactivated instead of re-applied " - "because less than 1 minute remains." - ) - await self.deactivate_infraction(infraction) - return - - # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") - - async def apply_infraction( - self, - ctx: Context, - infraction: utils.Infraction, - user: UserSnowflake, - action_coro: t.Optional[t.Awaitable] = None - ) -> None: - """Apply an infraction to the user, log the infraction, and optionally notify the user.""" - infr_type = infraction["type"] - icon = utils.INFRACTION_ICONS[infr_type][0] - reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) - id_ = infraction['id'] - - log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - - # Default values for the confirmation message and mod log. - confirm_msg = ":ok_hand: applied" - - # Specifying an expiry for a note or warning makes no sense. - if infr_type in ("note", "warning"): - expiry_msg = "" - else: - expiry_msg = f" until {expiry}" if expiry else " permanently" - - dm_result = "" - dm_log_text = "" - expiry_log_text = f"\nExpires: {expiry}" if expiry else "" - log_title = "applied" - log_content = None - failed = False - - # DM the user about the infraction if it's not a shadow/hidden infraction. - # This needs to happen before we apply the infraction, as the bot cannot - # send DMs to user that it doesn't share a guild with. If we were to - # apply kick/ban infractions first, this would mean that we'd make it - # impossible for us to deliver a DM. See python-discord/bot#982. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" - - end_msg = "" - if infraction["actor"] == self.bot.user.id: - log.trace( - f"Infraction #{id_} actor is bot; including the reason in the confirmation message." - ) - if reason: - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" - elif ctx.channel.id not in STAFF_CHANNELS: - log.trace( - f"Infraction #{id_} context is not in a staff channel; omitting infraction count." - ) - else: - log.trace(f"Fetching total infraction count for {user}.") - - infractions = await self.bot.api_client.get( - "bot/infractions", - params={"user__id": str(user.id)} - ) - total = len(infractions) - end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" - - # Execute the necessary actions to apply the infraction on Discord. - if action_coro: - log.trace(f"Awaiting the infraction #{id_} application action coroutine.") - try: - await action_coro - if expiry: - # Schedule the expiration of the infraction. - self.schedule_expiration(infraction) - except discord.HTTPException as e: - # Accordingly display that applying the infraction failed. - confirm_msg = ":x: failed to apply" - expiry_msg = "" - log_content = ctx.author.mention - log_title = "failed to apply" - - log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): - log.warning(f"{log_msg}: bot lacks permissions.") - else: - log.exception(log_msg) - failed = True - - if failed: - log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") - try: - await self.bot.api_client.delete(f"bot/infractions/{id_}") - except ResponseCodeError as e: - confirm_msg += " and failed to delete" - log_title += " and failed to delete" - log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") - infr_message = "" - else: - infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" - - # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") - - # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=icon, - colour=Colours.soft_red, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} - Reason: {reason} - """), - content=log_content, - footer=f"ID {infraction['id']}" - ) - - log.info(f"Applied {infr_type} infraction #{id_} to {user}.") - - async def pardon_infraction( - self, - ctx: Context, - infr_type: str, - user: UserSnowflake, - send_msg: bool = True - ) -> None: - """ - Prematurely end an infraction for a user and log the action in the mod log. - - If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. - """ - log.trace(f"Pardoning {infr_type} infraction for {user}.") - - # Check the current active infraction - log.trace(f"Fetching active {infr_type} infractions for {user}.") - response = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': user.id - } - ) - - if not response: - log.debug(f"No active {infr_type} infraction found for {user}.") - await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") - return - - # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["Actor"] = str(ctx.message.author) - log_content = None - id_ = response[0]['id'] - footer = f"ID: {id_}" - - # If multiple active infractions were found, mark them as inactive in the database - # and cancel their expiration tasks. - if len(response) > 1: - log.info( - f"Found more than one active {infr_type} infraction for user {user.id}; " - "deactivating the extra active infractions too." - ) - - footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" - - log_note = f"Found multiple **active** {infr_type} infractions in the database." - if "Note" in log_text: - log_text["Note"] = f" {log_note}" - else: - log_text["Note"] = log_note - - # deactivate_infraction() is not called again because: - # 1. Discord cannot store multiple active bans or assign multiples of the same role - # 2. It would send a pardon DM for each active infraction, which is redundant - for infraction in response[1:]: - id_ = infraction['id'] - try: - # Mark infraction as inactive in the database. - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") - # This is simpler and cleaner than trying to concatenate all the errors. - log_text["Failure"] = "See bot's logs for details." - - # Cancel pending expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Accordingly display whether the user was successfully notified via DM. - dm_emoji = "" - if log_text.get("DM") == "Sent": - dm_emoji = ":incoming_envelope: " - elif "DM" in log_text: - dm_emoji = f"{constants.Emojis.failmail} " - - # Accordingly display whether the pardon failed. - if "Failure" in log_text: - confirm_msg = ":x: failed to pardon" - log_title = "pardon failed" - log_content = ctx.author.mention - - log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") - else: - confirm_msg = ":ok_hand: pardoned" - log_title = "pardoned" - - log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") - - # Send a confirmation message to the invoking context. - if send_msg: - log.trace(f"Sending infraction #{id_} pardon confirmation message.") - await ctx.send( - f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " - f"{log_text.get('Failure', '')}" - ) - - # Move reason to end of entry to avoid cutting out some keys - log_text["Reason"] = log_text.pop("Reason") - - # Send a log message to the mod log. - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[infr_type][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=footer, - content=log_content, - ) - - async def deactivate_infraction( - self, - infraction: utils.Infraction, - send_log: bool = True - ) -> t.Dict[str, str]: - """ - Deactivate an active infraction and return a dictionary of lines to send in a mod log. - - The infraction is removed from Discord, marked as inactive in the database, and has its - expiration task cancelled. If `send_log` is True, a mod log is sent for the - deactivation of the infraction. - - Infractions of unsupported types will raise a ValueError. - """ - guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderators) - user_id = infraction["user"] - actor = infraction["actor"] - type_ = infraction["type"] - id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] - - log.info(f"Marking infraction #{id_} as inactive (expired).") - - expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) - - log_content = None - log_text = { - "Member": f"<@{user_id}>", - "Actor": str(self.bot.get_user(actor) or actor), - "Reason": infraction["reason"], - "Created": created, - } - - try: - log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) - - if returned_log is not None: - log_text = {**log_text, **returned_log} # Merge the logs together - else: - raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" - ) - except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") - log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" - log_content = mod_role.mention - except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." - log_content = mod_role.mention - - # Check if the user is currently being watched by Big Brother. - try: - log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") - - active_watch = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "watch", - "user__id": user_id - } - ) - - log_text["Watching"] = "Yes" if active_watch else "No" - except ResponseCodeError: - log.exception(f"Failed to fetch watch status for user {user_id}") - log_text["Watching"] = "Unknown - failed to fetch watch status." - - try: - # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{id_} as inactive in the database.") - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_line = f"API request failed with code {e.status}." - log_content = mod_role.mention - - # Append to an existing failure message if possible - if "Failure" in log_text: - log_text["Failure"] += f" {log_line}" - else: - log_text["Failure"] = log_line - - # Cancel the expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Send a log message to the mod log. - if send_log: - log_title = "expiration failed" if "Failure" in log_text else "expired" - - user = self.bot.get_user(user_id) - avatar = user.avatar_url_as(static_format="png") if user else None - - # Move reason to end so when reason is too long, this is not gonna cut out required items. - log_text["Reason"] = log_text.pop("Reason") - - log.trace(f"Sending deactivation mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[type_][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {type_}", - thumbnail=avatar, - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {id_}", - content=log_content, - ) - - return log_text - - @abstractmethod - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - raise NotImplementedError - - def schedule_expiration(self, infraction: utils.Infraction) -> None: - """ - Marks an infraction expired after the delay from time of scheduling to time of expiration. - - At the time of expiration, the infraction is marked as inactive on the website and the - expiration task is cancelled. - """ - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py deleted file mode 100644 index 867de815a..000000000 --- a/bot/cogs/moderation/superstarify.py +++ /dev/null @@ -1,239 +0,0 @@ -import json -import logging -import random -import textwrap -import typing as t -from pathlib import Path - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.converters import Expiry -from bot.utils.checks import with_role_check -from bot.utils.time import format_infraction -from . import utils -from .scheduler import InfractionScheduler - -log = logging.getLogger(__name__) -NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" - -with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: - STAR_NAMES = json.load(stars_file) - - -class Superstarify(InfractionScheduler, Cog): - """A set of commands to moderate terrible nicknames.""" - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"superstar"}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Revert nickname edits if the user has an active superstarify infraction.""" - if before.display_name == after.display_name: - return # User didn't change their nickname. Abort! - - log.trace( - f"{before} ({before.display_name}) is trying to change their nickname to " - f"{after.display_name}. Checking if the user is in superstar-prison..." - ) - - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": str(before.id) - } - ) - - if not active_superstarifies: - log.trace(f"{before} has no active superstar infractions.") - return - - infraction = active_superstarifies[0] - forced_nick = self.get_nick(infraction["id"], before.id) - if after.display_name == forced_nick: - return # Nick change was triggered by this event. Ignore. - - log.info( - f"{after.display_name} ({after.id}) tried to escape superstar prison. " - f"Changing the nick back to {before.display_name}." - ) - await after.edit( - nick=forced_nick, - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - notified = await utils.notify_infraction( - user=after, - infr_type="Superstarify", - expires_at=format_infraction(infraction["expires_at"]), - reason=( - "You have tried to change your nickname on the **Python Discord** server " - f"from **{before.display_name}** to **{after.display_name}**, but as you " - "are currently in superstar-prison, you do not have permission to do so." - ), - icon_url=utils.INFRACTION_ICONS["superstar"][0] - ) - - if not notified: - log.info("Failed to DM user about why they cannot change their nickname.") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active superstar infractions for returning members.""" - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": member.id - } - ) - - if active_superstarifies: - infraction = active_superstarifies[0] - action = member.edit( - nick=self.get_nick(infraction["id"], member.id), - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - await self.reapply_infraction(infraction, action) - - @command(name="superstarify", aliases=("force_nick", "star")) - async def superstarify( - self, - ctx: Context, - member: Member, - duration: Expiry, - *, - reason: str = None, - ) -> None: - """ - Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - - An optional reason can be provided. If no reason is given, the original name will be shown - in a generated reason. - """ - if await utils.get_active_infraction(ctx, member, "superstar"): - return - - # Post the infraction to the API - reason = reason or f"old nick: {member.display_name}" - infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) - id_ = infraction["id"] - - old_nick = member.display_name - forced_nick = self.get_nick(id_, member.id) - expiry_str = format_infraction(infraction["expires_at"]) - - # Apply the infraction and schedule the expiration task. - log.debug(f"Changing nickname of {member} to {forced_nick}.") - self.mod_log.ignore(constants.Event.member_update, member.id) - await member.edit(nick=forced_nick, reason=reason) - self.schedule_expiration(infraction) - - # Send a DM to the user to notify them of their new infraction. - await utils.notify_infraction( - user=member, - infr_type="Superstarify", - expires_at=expiry_str, - icon_url=utils.INFRACTION_ICONS["superstar"][0], - reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." - ) - - # Send an embed with the infraction information to the invoking context. - log.trace(f"Sending superstar #{id_} embed.") - embed = Embed( - title="Congratulations!", - colour=constants.Colours.soft_orange, - description=( - f"Your previous nickname, **{old_nick}**, " - f"was so bad that we have decided to change it. " - f"Your new nickname will be **{forced_nick}**.\n\n" - f"You will be unable to change your nickname until **{expiry_str}**.\n\n" - "If you're confused by this, please read our " - f"[official nickname policy]({NICKNAME_POLICY_URL})." - ) - ) - await ctx.send(embed=embed) - - # Log to the mod log channel. - log.trace(f"Sending apply mod log for superstar #{id_}.") - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS["superstar"][0], - colour=Colour.gold(), - title="Member achieved superstardom", - thumbnail=member.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {member.mention} (`{member.id}`) - Actor: {ctx.message.author} - Expires: {expiry_str} - Old nickname: `{old_nick}` - New nickname: `{forced_nick}` - Reason: {reason} - """), - footer=f"ID {id_}" - ) - - @command(name="unsuperstarify", aliases=("release_nick", "unstar")) - async def unsuperstarify(self, ctx: Context, member: Member) -> None: - """Remove the superstarify infraction and allow the user to change their nickname.""" - await self.pardon_infraction(ctx, "superstar", member) - - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """Pardon a superstar infraction and return a log dict.""" - if infraction["type"] != "superstar": - return - - guild = self.bot.get_guild(constants.Guild.id) - user = guild.get_member(infraction["user"]) - - # Don't bother sending a notification if the user left the guild. - if not user: - log.debug( - "User left the guild and therefore won't be notified about superstar " - f"{infraction['id']} pardon." - ) - return {} - - # DM the user about the expiration. - notified = await utils.notify_pardon( - user=user, - title="You are no longer superstarified", - content="You may now change your nickname on the server.", - icon_url=utils.INFRACTION_ICONS["superstar"][1] - ) - - return { - "Member": f"{user.mention}(`{user.id}`)", - "DM": "Sent" if notified else "**Failed**" - } - - @staticmethod - def get_nick(infraction_id: int, member_id: int) -> str: - """Randomly select a nickname from the Superstarify nickname list.""" - log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") - - rng = random.Random(str(infraction_id) + str(member_id)) - return rng.choice(STAR_NAMES) - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py deleted file mode 100644 index fb55287b6..000000000 --- a/bot/cogs/moderation/utils.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext.commands import Context - -from bot.api import ResponseCodeError -from bot.constants import Colours, Icons - -log = logging.getLogger(__name__) - -# apply icon, pardon icon -INFRACTION_ICONS = { - "ban": (Icons.user_ban, Icons.user_unban), - "kick": (Icons.sign_out, None), - "mute": (Icons.user_mute, Icons.user_unmute), - "note": (Icons.user_warn, None), - "superstar": (Icons.superstarify, Icons.unsuperstarify), - "warning": (Icons.user_warn, None), -} -RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") - -# Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object] -Infraction = t.Dict[str, t.Union[str, int, bool]] - - -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: - """ - Create a new user in the database. - - Used when an infraction needs to be applied on a user absent in the guild. - """ - log.trace(f"Attempting to add user {user.id} to the database.") - - if not isinstance(user, (discord.Member, discord.User)): - log.debug("The user being added to the DB is not a Member or User object.") - - payload = { - 'discriminator': int(getattr(user, 'discriminator', 0)), - 'id': user.id, - 'in_guild': False, - 'name': getattr(user, 'name', 'Name unknown'), - 'roles': [] - } - - try: - response = await ctx.bot.api_client.post('bot/users', json=payload) - log.info(f"User {user.id} added to the DB.") - return response - except ResponseCodeError as e: - log.error(f"Failed to add user {user.id} to the DB. {e}") - await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") - - -async def post_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - reason: str, - expires_at: datetime = None, - hidden: bool = False, - active: bool = True -) -> t.Optional[dict]: - """Posts an infraction to the API.""" - log.trace(f"Posting {infr_type} infraction for {user} to the API.") - - payload = { - "actor": ctx.message.author.id, - "hidden": hidden, - "reason": reason, - "type": infr_type, - "user": user.id, - "active": active - } - if expires_at: - payload['expires_at'] = expires_at.isoformat() - - # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. - for should_post_user in (True, False): - try: - response = await ctx.bot.api_client.post('bot/infractions', json=payload) - return response - except ResponseCodeError as e: - if e.status == 400 and 'user' in e.response_json: - # Only one attempt to add the user to the database, not two: - if not should_post_user or await post_user(ctx, user) is None: - return - else: - log.exception(f"Unexpected error while adding an infraction for {user}:") - await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") - return - - -async def get_active_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - send_msg: bool = True -) -> t.Optional[dict]: - """ - Retrieves an active infraction of the given type for the user. - - If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, - then a message for the moderator will be sent to the context channel letting them know. - Otherwise, no message will be sent. - """ - log.trace(f"Checking if {user} has active infractions of type {infr_type}.") - - active_infractions = await ctx.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': str(user.id) - } - ) - if active_infractions: - # Checks to see if the moderator should be told there is an active infraction - if send_msg: - log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) - return active_infractions[0] - else: - log.trace(f"{user} does not have active infractions of type {infr_type}.") - - -async def notify_infraction( - user: UserObject, - infr_type: str, - expires_at: t.Optional[str] = None, - reason: t.Optional[str] = None, - icon_url: str = Icons.token_removed -) -> bool: - """DM a user about their new infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) - - embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), - colour=Colours.soft_red - ) - - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" - embed.url = RULES_URL - - if infr_type in APPEALABLE_INFRACTIONS: - embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" - ) - - return await send_private_embed(user, embed) - - -async def notify_pardon( - user: UserObject, - title: str, - content: str, - icon_url: str = Icons.user_verified -) -> bool: - """DM a user about their pardoned infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their pardoned infraction.") - - embed = discord.Embed( - description=content, - colour=Colours.soft_green - ) - - embed.set_author(name=title, icon_url=icon_url) - - return await send_private_embed(user, embed) - - -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: - """ - A helper method for sending an embed to a user's DMs. - - Returns a boolean indicator of DM success. - """ - try: - await user.send(embed=embed) - return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): - log.debug( - f"Infraction-related information could not be sent to user {user} ({user.id}). " - "The user either could not be retrieved or probably disabled their DMs." - ) - return False diff --git a/bot/cogs/moderation/verification.py b/bot/cogs/moderation/verification.py new file mode 100644 index 000000000..ae156cf70 --- /dev/null +++ b/bot/cogs/moderation/verification.py @@ -0,0 +1,191 @@ +import logging +from contextlib import suppress + +from discord import Colour, Forbidden, Message, NotFound, Object +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.decorators import in_whitelist, without_role +from bot.utils.checks import InWhitelistCheckFailure, without_role_check + +log = logging.getLogger(__name__) + +WELCOME_MESSAGE = f""" +Hello! Welcome to the server, and thanks for verifying yourself! + +For your records, these are the documents you accepted: + +`1)` Our rules, here: +`2)` Our privacy policy, here: - you can find information on how to have \ +your information removed here as well. + +Feel free to review them at any point! + +Additionally, if you'd like to receive notifications for the announcements \ +we post in <#{constants.Channels.announcements}> +from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ +to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. + +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{constants.Channels.bot_commands}>. +""" + +BOT_MESSAGE_DELETE_DELAY = 10 + + +class Verification(Cog): + """User verification and role self-management.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Check new message event for messages to the checkpoint channel & process.""" + if message.channel.id != constants.Channels.verification: + return # Only listen for #checkpoint messages + + if message.author.bot: + # They're a bot, delete their message after the delay. + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return + + # if a user mentions a role or guild member + # alert the mods in mod-alerts channel + if message.mentions or message.role_mentions: + log.debug( + f"{message.author} mentioned one or more users " + f"and/or roles in {message.channel.name}" + ) + + embed_text = ( + f"{message.author.mention} sent a message in " + f"{message.channel.mention} that contained user and/or role mentions." + f"\n\n**Original message:**\n>>> {message.content}" + ) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=constants.Icons.filtering, + colour=Colour(constants.Colours.soft_red), + title=f"User/Role mentioned in {message.channel.name}", + text=embed_text, + thumbnail=message.author.avatar_url_as(static_format="png"), + channel_id=constants.Channels.mod_alerts, + ) + + ctx: Context = await self.bot.get_context(message) + if ctx.command is not None and ctx.command.name == "accept": + return + + if any(r.id == constants.Roles.verified for r in ctx.author.roles): + log.info( + f"{ctx.author} posted '{ctx.message.content}' " + "in the verification channel, but is already verified." + ) + return + + log.debug( + f"{ctx.author} posted '{ctx.message.content}' in the verification " + "channel. We are providing instructions how to verify." + ) + await ctx.send( + f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " + f"and gain access to the rest of the server.", + delete_after=20 + ) + + log.trace(f"Deleting the message posted by {ctx.author}") + with suppress(NotFound): + await ctx.message.delete() + + @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) + @without_role(constants.Roles.verified) + @in_whitelist(channels=(constants.Channels.verification,)) + async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Accept our rules and gain access to the rest of the server.""" + log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") + await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") + try: + await ctx.author.send(WELCOME_MESSAGE) + except Forbidden: + log.info(f"Sending welcome message failed for {ctx.author}.") + finally: + log.trace(f"Deleting accept message by {ctx.author}.") + with suppress(NotFound): + self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) + await ctx.message.delete() + + @command(name='subscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Subscribe to announcement notifications by assigning yourself the role.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if has_role: + await ctx.send(f"{ctx.author.mention} You're already subscribed!") + return + + log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") + await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", + ) + + @command(name='unsubscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Unsubscribe from announcement notifications by removing the role from yourself.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if not has_role: + await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") + return + + log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") + await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." + ) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Check for & ignore any InWhitelistCheckFailure.""" + if isinstance(error, InWhitelistCheckFailure): + error.handled = True + + @staticmethod + def bot_check(ctx: Context) -> bool: + """Block any command within the verification channel that is not !accept.""" + if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): + return ctx.command.name == "accept" + else: + return True + + +def setup(bot: Bot) -> None: + """Load the Verification cog.""" + bot.add_cog(Verification(bot)) diff --git a/bot/cogs/moderation/watchchannels/__init__.py b/bot/cogs/moderation/watchchannels/__init__.py new file mode 100644 index 000000000..69d118df6 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/__init__.py @@ -0,0 +1,9 @@ +from bot.bot import Bot +from .bigbrother import BigBrother +from .talentpool import TalentPool + + +def setup(bot: Bot) -> None: + """Load the BigBrother and TalentPool cogs.""" + bot.add_cog(BigBrother(bot)) + bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/moderation/watchchannels/bigbrother.py b/bot/cogs/moderation/watchchannels/bigbrother.py new file mode 100644 index 000000000..0c72e88f7 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/bigbrother.py @@ -0,0 +1,165 @@ +import logging +import textwrap +from collections import ChainMap + +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.cogs.moderation.infraction.utils import post_infraction +from bot.constants import Channels, MODERATION_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from .watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class BigBrother(WatchChannel, Cog, name="Big Brother"): + """Monitors users by relaying their messages to a watch channel to assist with moderation.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.big_brother_logs, + webhook_id=Webhooks.big_brother, + api_endpoint='bot/infractions', + api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, + logger=log + ) + + @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def bigbrother_group(self, ctx: Context) -> None: + """Monitors users by relaying their messages to the Big Brother watch channel.""" + await ctx.send_help(ctx.command) + + @bigbrother_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored by Big Brother. + + The optional kwarg `oldest_first` can be used to order the list by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @bigbrother_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows Big Brother monitored users ordered by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @bigbrother_group.command(name='watch', aliases=('w',)) + @with_role(*MODERATION_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#big-brother` channel. + + A `reason` for adding the user to Big Brother is required and will be displayed + in the header when relaying messages of this user to the watchchannel. + """ + await self.apply_watch(ctx, user, reason) + + @bigbrother_group.command(name='unwatch', aliases=('uw',)) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Stop relaying messages by the given `user`.""" + await self.apply_unwatch(ctx, user, reason) + + async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: + """ + Add `user` to watched users and apply a watch infraction with `reason`. + + A message indicating the result of the operation is sent to `ctx`. + The message will include `user`'s previous watch infraction history, if it exists. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched.") + return + + response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) + + if response is not None: + self.watched_users[user.id] = response + msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + 'type': 'watch', + 'ordering': '-inserted_at' + } + ) + + if len(history) > 1: + total = f"({len(history) // 2} previous infractions in total)" + end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") + start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + else: + msg = ":x: Failed to post the infraction: response was empty." + + await ctx.send(msg) + + async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: + """ + Remove `user` from watched users and mark their infraction as inactive with `reason`. + + If `send_message` is True, a message indicating the result of the operation is sent to + `ctx`. + """ + active_watches = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + if active_watches: + log.trace("Active watches for user found. Attempting to remove.") + [infraction] = active_watches + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{infraction['id']}", + json={'active': False} + ) + + await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) + + self._remove_user(user.id) + + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"Perma-banned user {user} was unwatched.") + return + log.trace("User is not banned. Sending message to channel") + message = f":white_check_mark: Messages sent by {user} will no longer be relayed." + + else: + log.trace("No active watches found for user.") + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"{user} was not on the watch list; no removal necessary.") + return + log.trace("User is not perma banned. Send the error message.") + message = ":x: The specified user is currently not being watched." + + await ctx.send(message) diff --git a/bot/cogs/moderation/watchchannels/talentpool.py b/bot/cogs/moderation/watchchannels/talentpool.py new file mode 100644 index 000000000..89256e92e --- /dev/null +++ b/bot/cogs/moderation/watchchannels/talentpool.py @@ -0,0 +1,264 @@ +import logging +import textwrap +from collections import ChainMap + +from discord import Color, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils import time +from .watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class TalentPool(WatchChannel, Cog, name="Talentpool"): + """Relays messages of helper candidates to a watch channel to observe them.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.talent_pool, + webhook_id=Webhooks.talent_pool, + api_endpoint='bot/nominations', + api_default_params={'active': 'true', 'ordering': '-inserted_at'}, + logger=log, + ) + + @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_group(self, ctx: Context) -> None: + """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" + await ctx.send_help(ctx.command) + + @nomination_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored in the talent pool. + + The optional kwarg `oldest_first` can be used to order the list by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @nomination_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows talent pool monitored users ordered by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) + @with_role(*STAFF_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#talent-pool` channel. + + A `reason` for adding the user to the talent pool is required and will be displayed + in the header when relaying messages of this user to the channel. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): + await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update the user cache; can't add {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched in the talent pool") + return + + # Manual request with `raise_for_status` as False because we want the actual response + session = self.bot.api_client.session + url = self.bot.api_client._url_for(self.api_endpoint) + kwargs = { + 'json': { + 'actor': ctx.author.id, + 'reason': reason, + 'user': user.id + }, + 'raise_for_status': False, + } + async with session.post(url, **kwargs) as resp: + response_data = await resp.json() + + if resp.status == 400 and response_data.get('user', False): + await ctx.send(":x: The specified user can't be found in the database tables") + return + else: + resp.raise_for_status() + + self.watched_users[user.id] = response_data + msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + "ordering": "-inserted_at" + } + ) + + if history: + total = f"({len(history)} previous nominations in total)" + start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" + end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + + await ctx.send(msg) + + @nomination_group.command(name='history', aliases=('info', 'search')) + @with_role(*MODERATION_ROLES) + async def history_command(self, ctx: Context, user: FetchedMember) -> None: + """Shows the specified user's nomination history.""" + result = await self.bot.api_client.get( + self.api_endpoint, + params={ + 'user__id': str(user.id), + 'ordering': "-active,-inserted_at" + } + ) + if not result: + await ctx.send(":warning: This user has never been nominated") + return + + embed = Embed( + title=f"Nominations for {user.display_name} `({user.id})`", + color=Color.blue() + ) + lines = [self._nomination_to_string(nomination) for nomination in result] + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + @nomination_group.command(name='unwatch', aliases=('end', )) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Ends the active nomination of the specified user with the given reason. + + Providing a `reason` is required. + """ + active_nomination = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + + if not active_nomination: + await ctx.send(":x: The specified user does not have an active nomination") + return + + [nomination] = active_nomination + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination['id']}", + json={'end_reason': reason, 'active': False} + ) + await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") + self._remove_user(user.id) + + @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_edit_group(self, ctx: Context) -> None: + """Commands to edit nominations.""" + await ctx.send_help(ctx.command) + + @nomination_edit_group.command(name='reason') + @with_role(*MODERATION_ROLES) + async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: + """ + Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. + + If the nomination is active, the reason for nominating the user will be edited; + If the nomination is no longer active, the reason for ending the nomination will be edited instead. + """ + try: + nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") + except ResponseCodeError as e: + if e.response.status == 404: + self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") + await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") + return + else: + raise + + field = "reason" if nomination["active"] else "end_reason" + + self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination_id}", + json={field: reason} + ) + + await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") + + def _nomination_to_string(self, nomination_object: dict) -> str: + """Creates a string representation of a nomination.""" + guild = self.bot.get_guild(Guild.id) + + actor_id = nomination_object["actor"] + actor = guild.get_member(actor_id) + + active = nomination_object["active"] + log.debug(active) + log.debug(type(nomination_object["inserted_at"])) + + start_date = time.format_infraction(nomination_object["inserted_at"]) + if active: + lines = textwrap.dedent( + f""" + =============== + Status: **Active** + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + else: + end_date = time.format_infraction(nomination_object["ended_at"]) + lines = textwrap.dedent( + f""" + =============== + Status: Inactive + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + + End date: {end_date} + Unwatch reason: {nomination_object["end_reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + + return lines.strip() diff --git a/bot/cogs/moderation/watchchannels/watchchannel.py b/bot/cogs/moderation/watchchannels/watchchannel.py new file mode 100644 index 000000000..044077350 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/watchchannel.py @@ -0,0 +1,348 @@ +import asyncio +import logging +import re +import textwrap +from abc import abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Optional + +import dateutil.parser +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons +from bot.pagination import LinePaginator +from bot.utils import CogABCMeta, messages +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + +URL_RE = re.compile(r"(https?://[^\s]+)") + + +@dataclass +class MessageHistory: + """Represents a watch channel's message history.""" + + last_author: Optional[int] = None + last_channel: Optional[int] = None + message_count: int = 0 + + +class WatchChannel(metaclass=CogABCMeta): + """ABC with functionality for relaying users' messages to a certain channel.""" + + @abstractmethod + def __init__( + self, + bot: Bot, + destination: int, + webhook_id: int, + api_endpoint: str, + api_default_params: dict, + logger: logging.Logger + ) -> None: + self.bot = bot + + self.destination = destination # E.g., Channels.big_brother_logs + self.webhook_id = webhook_id # E.g., Webhooks.big_brother + self.api_endpoint = api_endpoint # E.g., 'bot/infractions' + self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} + self.log = logger # Logger of the child cog for a correct name in the logs + + self._consume_task = None + self.watched_users = defaultdict(dict) + self.message_queue = defaultdict(lambda: defaultdict(deque)) + self.consumption_queue = {} + self.retries = 5 + self.retry_delay = 10 + self.channel = None + self.webhook = None + self.message_history = MessageHistory() + + self._start = self.bot.loop.create_task(self.start_watchchannel()) + + @property + def modlog(self) -> ModLog: + """Provides access to the ModLog cog for alert purposes.""" + return self.bot.get_cog("ModLog") + + @property + def consuming_messages(self) -> bool: + """Checks if a consumption task is currently running.""" + if self._consume_task is None: + return False + + if self._consume_task.done(): + exc = self._consume_task.exception() + if exc: + self.log.exception( + "The message queue consume task has failed with:", + exc_info=exc + ) + return False + + return True + + async def start_watchchannel(self) -> None: + """Starts the watch channel by getting the channel, webhook, and user cache ready.""" + await self.bot.wait_until_guild_available() + + try: + self.channel = await self.bot.fetch_channel(self.destination) + except HTTPException: + self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + if self.channel is None or self.webhook is None: + self.log.error("Failed to start the watch channel; unloading the cog.") + + message = textwrap.dedent( + f""" + An error occurred while loading the text channel or webhook. + + TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} + Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} + + The Cog has been unloaded. + """ + ) + + await self.modlog.send_log_message( + title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", + text=message, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + if not await self.fetch_user_cache(): + await self.modlog.send_log_message( + title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", + text="Could not retrieve the list of watched users from the API and messages will not be relayed.", + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + async def fetch_user_cache(self) -> bool: + """ + Fetches watched users from the API and updates the watched user cache accordingly. + + This function returns `True` if the update succeeded. + """ + try: + data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) + except ResponseCodeError as err: + self.log.exception("Failed to fetch the watched users from the API", exc_info=err) + return False + + self.watched_users = defaultdict(dict) + + for entry in data: + user_id = entry.pop('user') + self.watched_users[user_id] = entry + + return True + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Queues up messages sent by watched users.""" + if msg.author.id in self.watched_users: + if not self.consuming_messages: + self._consume_task = self.bot.loop.create_task(self.consume_messages()) + + self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") + self.message_queue[msg.author.id][msg.channel.id].append(msg) + + async def consume_messages(self, delay_consumption: bool = True) -> None: + """Consumes the message queues to log watched users' messages.""" + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) + + self.log.trace("Started consuming the message queue") + + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() + + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() + + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) + + self.consumption_queue.clear() + + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + + async def webhook_send( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Sends a message to the webhook with the specified kwargs.""" + username = messages.sub_clyde(username) + try: + await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send a message to the webhook", + exc_info=exc + ) + + async def relay_message(self, msg: Message) -> None: + """Relays the message to the relevant watch channel.""" + limit = BigBrotherConfig.header_message_limit + + if ( + msg.author.id != self.message_history.last_author + or msg.channel.id != self.message_history.last_channel + or self.message_history.message_count >= limit + ): + self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) + + await self.send_header(msg) + + cleaned_content = msg.clean_content + + if cleaned_content: + # Put all non-media URLs in a code block to prevent embeds + media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} + for url in URL_RE.findall(cleaned_content): + if url not in media_urls: + cleaned_content = cleaned_content.replace(url, f"`{url}`") + await self.webhook_send( + cleaned_content, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + + if msg.attachments: + try: + await messages.send_attachments(msg, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.webhook_send( + embed=e, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send an attachment to the webhook", + exc_info=exc + ) + + self.message_history.message_count += 1 + + async def send_header(self, msg: Message) -> None: + """Sends a header embed with information about the relayed messages to the watch channel.""" + user_id = msg.author.id + + guild = self.bot.get_guild(GuildConfig.id) + actor = guild.get_member(self.watched_users[user_id]['actor']) + actor = actor.display_name if actor else self.watched_users[user_id]['actor'] + + inserted_at = self.watched_users[user_id]['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + + reason = self.watched_users[user_id]['reason'] + + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + footer = f"Added {time_delta} by {actor} | Reason: {reason}" + embed = Embed(description=f"{msg.author.mention} {message_jump}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) + + await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) + + async def list_watched_users( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Gives an overview of the watched user list for this channel. + + The optional kwarg `oldest_first` orders the list by oldest entry. + + The optional kwarg `update_cache` specifies whether the cache should + be refreshed by polling the API. + """ + if update_cache: + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") + update_cache = False + + lines = [] + for user_id, user_data in self.watched_users.items(): + inserted_at = user_data['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + lines.append(f"• <@{user_id}> (added {time_delta})") + + if oldest_first: + lines.reverse() + + lines = lines or ("There's nothing here yet.",) + + embed = Embed( + title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", + color=Color.blue() + ) + await LinePaginator.paginate(lines, ctx, embed, empty=False) + + @staticmethod + def _get_time_delta(time_string: str) -> str: + """Returns the time in human-readable time delta format.""" + date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) + time_delta = time_since(date_time, precision="minutes", max_units=1) + + return time_delta + + def _remove_user(self, user_id: int) -> None: + """Removes a user from a watch channel.""" + self.watched_users.pop(user_id, None) + self.message_queue.pop(user_id, None) + self.consumption_queue.pop(user_id, None) + + def cog_unload(self) -> None: + """Takes care of unloading the cog and canceling the consumption task.""" + self.log.trace("Unloading the cog") + if self._consume_task and not self._consume_task.done(): + self._consume_task.cancel() + try: + self._consume_task.result() + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) diff --git a/bot/cogs/python_news.py b/bot/cogs/python_news.py deleted file mode 100644 index 0ab5738a4..000000000 --- a/bot/cogs/python_news.py +++ /dev/null @@ -1,232 +0,0 @@ -import logging -import typing as t -from datetime import date, datetime - -import discord -import feedparser -from bs4 import BeautifulSoup -from discord.ext.commands import Cog -from discord.ext.tasks import loop - -from bot import constants -from bot.bot import Bot -from bot.utils.webhooks import send_webhook - -PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" - -RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" -THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" -MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" -THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" - -AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - -log = logging.getLogger(__name__) - - -class PythonNews(Cog): - """Post new PEPs and Python News to `#python-news`.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_names = {} - self.webhook: t.Optional[discord.Webhook] = None - - self.bot.loop.create_task(self.get_webhook_names()) - self.bot.loop.create_task(self.get_webhook_and_channel()) - - async def start_tasks(self) -> None: - """Start the tasks for fetching new PEPs and mailing list messages.""" - self.fetch_new_media.start() - - @loop(minutes=20) - async def fetch_new_media(self) -> None: - """Fetch new mailing list messages and then new PEPs.""" - await self.post_maillist_news() - await self.post_pep_news() - - async def sync_maillists(self) -> None: - """Sync currently in-use maillists with API.""" - # Wait until guild is available to avoid running before everything is ready - await self.bot.wait_until_guild_available() - - response = await self.bot.api_client.get("bot/bot-settings/news") - for mail in constants.PythonNews.mail_lists: - if mail not in response["data"]: - response["data"][mail] = [] - - # Because we are handling PEPs differently, we don't include it to mail lists - if "pep" not in response["data"]: - response["data"]["pep"] = [] - - await self.bot.api_client.put("bot/bot-settings/news", json=response) - - async def get_webhook_names(self) -> None: - """Get webhook author names from maillist API.""" - await self.bot.wait_until_guild_available() - - async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: - lists = await resp.json() - - for mail in lists: - if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: - self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] - - async def post_pep_news(self) -> None: - """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" - # Wait until everything is ready and http_session available - await self.bot.wait_until_guild_available() - await self.sync_maillists() - - async with self.bot.http_session.get(PEPS_RSS_URL) as resp: - data = feedparser.parse(await resp.text("utf-8")) - - news_listing = await self.bot.api_client.get("bot/bot-settings/news") - payload = news_listing.copy() - pep_numbers = news_listing["data"]["pep"] - - # Reverse entries to send oldest first - data["entries"].reverse() - for new in data["entries"]: - try: - new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") - except ValueError: - log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") - continue - pep_nr = new["title"].split(":")[0].split()[1] - if ( - pep_nr in pep_numbers - or new_datetime.date() < date.today() - ): - continue - - # Build an embed and send a webhook - embed = discord.Embed( - title=new["title"], - description=new["summary"], - timestamp=new_datetime, - url=new["link"], - colour=constants.Colours.soft_green - ) - embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) - msg = await send_webhook( - webhook=self.webhook, - username=data["feed"]["title"], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"]["pep"].append(pep_nr) - - # Increase overall PEP new stat - self.bot.stats.incr("python_news.posted.pep") - - if msg.channel.is_news(): - log.trace("Publishing PEP annnouncement because it was in a news channel") - await msg.publish() - - # Apply new sent news to DB to avoid duplicate sending - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def post_maillist_news(self) -> None: - """Send new maillist threads to #python-news that is listed in configuration.""" - await self.bot.wait_until_guild_available() - await self.sync_maillists() - existing_news = await self.bot.api_client.get("bot/bot-settings/news") - payload = existing_news.copy() - - for maillist in constants.PythonNews.mail_lists: - async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: - recents = BeautifulSoup(await resp.text(), features="lxml") - - # When a

element is present in the response then the mailing list - # has not had any activity during the current month, so therefore it - # can be ignored. - if recents.p: - continue - - for thread in recents.html.body.div.find_all("a", href=True): - # We want only these threads that have identifiers - if "latest" in thread["href"]: - continue - - thread_information, email_information = await self.get_thread_and_first_mail( - maillist, thread["href"].split("/")[-2] - ) - - try: - new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") - except ValueError: - log.warning(f"Invalid datetime from Thread email: {email_information['date']}") - continue - - if ( - thread_information["thread_id"] in existing_news["data"][maillist] - or 'Re: ' in thread_information["subject"] - or new_date.date() < date.today() - ): - continue - - content = email_information["content"] - link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) - - # Build an embed and send a message to the webhook - embed = discord.Embed( - title=thread_information["subject"], - description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, - timestamp=new_date, - url=link, - colour=constants.Colours.soft_green - ) - embed.set_author( - name=f"{email_information['sender_name']} ({email_information['sender']['address']})", - url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), - ) - embed.set_footer( - text=f"Posted to {self.webhook_names[maillist]}", - icon_url=AVATAR_URL, - ) - msg = await send_webhook( - webhook=self.webhook, - username=self.webhook_names[maillist], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"][maillist].append(thread_information["thread_id"]) - - # Increase this specific maillist counter in stats - self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") - - if msg.channel.is_news(): - log.trace("Publishing mailing list message because it was in a news channel") - await msg.publish() - - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: - """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" - async with self.bot.http_session.get( - THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) - ) as resp: - thread_information = await resp.json() - - async with self.bot.http_session.get(thread_information["starting_email"]) as resp: - email_information = await resp.json() - return thread_information, email_information - - async def get_webhook_and_channel(self) -> None: - """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" - await self.bot.wait_until_guild_available() - self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) - - await self.start_tasks() - - def cog_unload(self) -> None: - """Stop news posting tasks on cog unload.""" - self.fetch_new_media.cancel() - - -def setup(bot: Bot) -> None: - """Add `News` cog.""" - bot.add_cog(PythonNews(bot)) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py deleted file mode 100644 index d853ab2ea..000000000 --- a/bot/cogs/reddit.py +++ /dev/null @@ -1,304 +0,0 @@ -import asyncio -import logging -import random -import textwrap -from collections import namedtuple -from datetime import datetime, timedelta -from typing import List - -from aiohttp import BasicAuth, ClientError -from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks -from bot.converters import Subreddit -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - - -class Reddit(Cog): - """Track subreddit posts and show detailed statistics about them.""" - - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} - URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 - - def __init__(self, bot: Bot): - self.bot = bot - - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - - bot.loop.create_task(self.init_reddit_ready()) - self.auto_poster_loop.start() - - def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" - self.auto_poster_loop.cancel() - if self.access_token and self.access_token.expires_at > datetime.utcnow(): - asyncio.create_task(self.revoke_access_token()) - - async def init_reddit_ready(self) -> None: - """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_guild_available() - if not self.webhook: - self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) - - @property - def channel(self) -> TextChannel: - """Get the #reddit channel object from the bot's cache.""" - return self.bot.get_channel(Channels.reddit) - - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: - """A helper method to fetch a certain amount of Reddit posts at a given route.""" - # Reddit's JSON responses only provide 25 posts at most. - if not 25 >= amount > 0: - raise ValueError("Invalid amount of subreddit posts requested.") - - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): - response = await self.bot.http_session.get( - url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, - params=params - ) - if response.status == 200 and response.content_type == 'application/json': - # Got appropriate response - process and return. - content = await response.json() - posts = content["data"]["children"] - return posts[:amount] - - await asyncio.sleep(3) - - log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") - return list() # Failed to get appropriate response within allowed number of retries. - - async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: - """ - Get the top amount of posts for a given subreddit within a specified timeframe. - - A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top - weekly posts. - - The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. - """ - embed = Embed(description="") - - posts = await self.fetch_posts( - route=f"{subreddit}/top", - amount=amount, - params={"t": time} - ) - - if not posts: - embed.title = random.choice(ERROR_REPLIES) - embed.colour = Colour.red() - embed.description = ( - "Sorry! We couldn't find any posts from that subreddit. " - "If this problem persists, please let us know." - ) - - return embed - - for post in posts: - data = post["data"] - - text = data["selftext"] - if text: - text = textwrap.shorten(text, width=128, placeholder="...") - text += "\n" # Add newline to separate embed info - - ups = data["ups"] - comments = data["num_comments"] - author = data["author"] - - title = textwrap.shorten(data["title"], width=64, placeholder="...") - link = self.URL + data["permalink"] - - embed.description += ( - f"**[{title}]({link})**\n" - f"{text}" - f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" - ) - - embed.colour = Colour.blurple() - return embed - - @loop() - async def auto_poster_loop(self) -> None: - """Post the top 5 posts daily, and the top 5 posts weekly.""" - # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter - now = datetime.utcnow() - tomorrow = now + timedelta(days=1) - midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) - seconds_until = (midnight_tomorrow - now).total_seconds() - - await asyncio.sleep(seconds_until) - - await self.bot.wait_until_guild_available() - if not self.webhook: - await self.bot.fetch_webhook(Webhooks.reddit) - - if datetime.utcnow().weekday() == 0: - await self.top_weekly_posts() - # if it's a monday send the top weekly posts - - for subreddit in RedditConfig.subreddits: - top_posts = await self.get_top_posts(subreddit=subreddit, time="day") - username = sub_clyde(f"{subreddit} Top Daily Posts") - message = await self.webhook.send(username=username, embed=top_posts, wait=True) - - if message.channel.is_news(): - await message.publish() - - async def top_weekly_posts(self) -> None: - """Post a summary of the top posts.""" - for subreddit in RedditConfig.subreddits: - # Send and pin the new weekly posts. - top_posts = await self.get_top_posts(subreddit=subreddit, time="week") - username = sub_clyde(f"{subreddit} Top Weekly Posts") - message = await self.webhook.send(wait=True, username=username, embed=top_posts) - - if subreddit.lower() == "r/python": - if not self.channel: - log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") - return - - # Remove the oldest pins so that only 12 remain at most. - pins = await self.channel.pins() - - while len(pins) >= 12: - await pins[-1].unpin() - del pins[-1] - - await message.pin() - - if message.channel.is_news(): - await message.publish() - - @group(name="reddit", invoke_without_command=True) - async def reddit_group(self, ctx: Context) -> None: - """View the top posts from various subreddits.""" - await ctx.send_help(ctx.command) - - @reddit_group.command(name="top") - async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of all time from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="all") - - await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - - @reddit_group.command(name="daily") - async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of today from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="day") - - await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - - @reddit_group.command(name="weekly") - async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of this week from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="week") - - await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - - @with_role(*STAFF_ROLES) - @reddit_group.command(name="subreddits", aliases=("subs",)) - async def subreddits_command(self, ctx: Context) -> None: - """Send a paginated embed of all the subreddits we're relaying.""" - embed = Embed() - embed.title = "Relayed subreddits." - embed.colour = Colour.blurple() - - await LinePaginator.paginate( - RedditConfig.subreddits, - ctx, embed, - footer_text="Use the reddit commands along with these to view their posts.", - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Reddit cog.""" - if not RedditConfig.secret or not RedditConfig.client_id: - log.error("Credentials not provided, cog not loaded.") - return - bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py deleted file mode 100644 index 670493bcf..000000000 --- a/bot/cogs/reminders.py +++ /dev/null @@ -1,427 +0,0 @@ -import asyncio -import logging -import random -import textwrap -import typing as t -from datetime import datetime, timedelta -from operator import itemgetter - -import discord -from dateutil.parser import isoparse -from dateutil.relativedelta import relativedelta -from discord.ext.commands import Cog, Context, Greedy, group - -from bot.bot import Bot -from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES -from bot.converters import Duration -from bot.pagination import LinePaginator -from bot.utils.checks import without_role_check -from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -WHITELISTED_CHANNELS = Guild.reminder_whitelist -MAXIMUM_REMINDERS = 5 - -Mentionable = t.Union[discord.Member, discord.Role] - - -class Reminders(Cog): - """Provide in-channel reminder functionality.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_reminders()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - async def reschedule_reminders(self) -> None: - """Get all current reminders from the API and reschedule them.""" - await self.bot.wait_until_guild_available() - response = await self.bot.api_client.get( - 'bot/reminders', - params={'active': 'true'} - ) - - now = datetime.utcnow() - - for reminder in response: - is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) - if not is_valid: - continue - - remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) - - # If the reminder is already overdue ... - if remind_at < now: - late = relativedelta(now, remind_at) - await self.send_reminder(reminder, late) - else: - self.schedule_reminder(reminder) - - def ensure_valid_reminder( - self, - reminder: dict, - cancel_task: bool = True - ) -> t.Tuple[bool, discord.User, discord.TextChannel]: - """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" - user = self.bot.get_user(reminder['author']) - channel = self.bot.get_channel(reminder['channel_id']) - is_valid = True - if not user or not channel: - is_valid = False - log.info( - f"Reminder {reminder['id']} invalid: " - f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." - ) - asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) - - return is_valid, user, channel - - @staticmethod - async def _send_confirmation( - ctx: Context, - on_success: str, - reminder_id: str, - delivery_dt: t.Optional[datetime], - ) -> None: - """Send an embed confirming the reminder change was made successfully.""" - embed = discord.Embed() - embed.colour = discord.Colour.green() - embed.title = random.choice(POSITIVE_REPLIES) - embed.description = on_success - - footer_str = f"ID: {reminder_id}" - if delivery_dt: - # Reminder deletion will have a `None` `delivery_dt` - footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" - - embed.set_footer(text=footer_str) - - await ctx.send(embed=embed) - - @staticmethod - async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: - """ - Returns whether or not the list of mentions is allowed. - - Conditions: - - Role reminders are Mods+ - - Reminders for other users are Helpers+ - - If mentions aren't allowed, also return the type of mention(s) disallowed. - """ - if without_role_check(ctx, *STAFF_ROLES): - return False, "members/roles" - elif without_role_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, discord.Member) for mention in mentions), "roles" - else: - return True, "" - - @staticmethod - async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: - """ - Filter mentions to see if the user can mention, and sends a denial if not allowed. - - Returns whether or not the validation is successful. - """ - mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) - - if not mentions or mentions_allowed: - return True - else: - await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") - return False - - def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: - """Converts Role and Member ids to their corresponding objects if possible.""" - guild = self.bot.get_guild(Guild.id) - for mention_id in mention_ids: - if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): - yield mentionable - - def schedule_reminder(self, reminder: dict) -> None: - """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" - reminder_id = reminder["id"] - reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) - - async def _remind() -> None: - await self.send_reminder(reminder) - - log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder_id) - - self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) - - async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: - """Delete a reminder from the database, given its ID, and cancel the running task.""" - await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - - if cancel_task: - # Now we can remove it from the schedule list - self.scheduler.cancel(reminder_id) - - async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: - """ - Edits a reminder in the database given the ID and payload. - - Returns the edited reminder. - """ - # Send the request to update the reminder in the database - reminder = await self.bot.api_client.patch( - 'bot/reminders/' + str(reminder_id), - json=payload - ) - return reminder - - async def _reschedule_reminder(self, reminder: dict) -> None: - """Reschedule a reminder object.""" - log.trace(f"Cancelling old task #{reminder['id']}") - self.scheduler.cancel(reminder["id"]) - - log.trace(f"Scheduling new task #{reminder['id']}") - self.schedule_reminder(reminder) - - async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: - """Send the reminder.""" - is_valid, user, channel = self.ensure_valid_reminder(reminder) - if not is_valid: - return - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.set_author( - icon_url=Icons.remind_blurple, - name="It has arrived!" - ) - - embed.description = f"Here's your reminder: `{reminder['content']}`." - - if reminder.get("jump_url"): # keep backward compatibility - embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" - - if late: - embed.colour = discord.Colour.red() - embed.set_author( - icon_url=Icons.remind_red, - name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" - ) - - additional_mentions = ' '.join( - mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) - ) - - await channel.send( - content=f"{user.mention} {additional_mentions}", - embed=embed - ) - await self._delete_reminder(reminder["id"]) - - @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) - async def remind_group( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """Commands for managing your reminders.""" - await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) - - @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """ - Set yourself a simple reminder. - - Expiration is parsed per: http://strftime.org/ - """ - # If the user is not staff, we need to verify whether or not to make a reminder at all. - if without_role_check(ctx, *STAFF_ROLES): - - # If they don't have permission to set a reminder in this channel - if ctx.channel.id not in WHITELISTED_CHANNELS: - await send_denial(ctx, "Sorry, you can't do that here!") - return - - # Get their current active reminders - active_reminders = await self.bot.api_client.get( - 'bot/reminders', - params={ - 'author__id': str(ctx.author.id) - } - ) - - # Let's limit this, so we don't get 10 000 - # reminders from kip or something like that :P - if len(active_reminders) > MAXIMUM_REMINDERS: - await send_denial(ctx, "You have too many active reminders!") - return - - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - - # Now we can attempt to actually set the reminder. - reminder = await self.bot.api_client.post( - 'bot/reminders', - json={ - 'author': ctx.author.id, - 'channel_id': ctx.message.channel.id, - 'jump_url': ctx.message.jump_url, - 'content': content, - 'expiration': expiration.isoformat(), - 'mentions': mention_ids, - } - ) - - now = datetime.utcnow() - timedelta(seconds=1) - humanized_delta = humanize_delta(relativedelta(expiration, now)) - mention_string = ( - f"Your reminder will arrive in {humanized_delta} " - f"and will mention {len(mentions)} other(s)!" - ) - - # Confirm to the user that it worked. - await self._send_confirmation( - ctx, - on_success=mention_string, - reminder_id=reminder["id"], - delivery_dt=expiration, - ) - - self.schedule_reminder(reminder) - - @remind_group.command(name="list") - async def list_reminders(self, ctx: Context) -> None: - """View a paginated embed of all reminders for your user.""" - # Get all the user's reminders from the database. - data = await self.bot.api_client.get( - 'bot/reminders', - params={'author__id': str(ctx.author.id)} - ) - - now = datetime.utcnow() - - # Make a list of tuples so it can be sorted by time. - reminders = sorted( - ( - (rem['content'], rem['expiration'], rem['id'], rem['mentions']) - for rem in data - ), - key=itemgetter(1) - ) - - lines = [] - - for content, remind_at, id_, mentions in reminders: - # Parse and humanize the time, make it pretty :D - remind_datetime = isoparse(remind_at).replace(tzinfo=None) - time = humanize_delta(relativedelta(remind_datetime, now)) - - mentions = ", ".join( - # Both Role and User objects have the `name` attribute - mention.name for mention in self.get_mentionables(mentions) - ) - mention_string = f"\n**Mentions:** {mentions}" if mentions else "" - - text = textwrap.dedent(f""" - **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} - {content} - """).strip() - - lines.append(text) - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.title = f"Reminders for {ctx.author}" - - # Remind the user that they have no reminders :^) - if not lines: - embed.description = "No active reminders could be found." - await ctx.send(embed=embed) - return - - # Construct the embed and paginate it. - embed.colour = discord.Colour.blurple() - - await LinePaginator.paginate( - lines, - ctx, embed, - max_lines=3, - empty=True - ) - - @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) - async def edit_reminder_group(self, ctx: Context) -> None: - """Commands for modifying your current reminders.""" - await ctx.send_help(ctx.command) - - @edit_reminder_group.command(name="duration", aliases=("time",)) - async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: - """ - Edit one of your reminder's expiration. - - Expiration is parsed per: http://strftime.org/ - """ - await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) - - @edit_reminder_group.command(name="content", aliases=("reason",)) - async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: - """Edit one of your reminder's content.""" - await self.edit_reminder(ctx, id_, {"content": content}) - - @edit_reminder_group.command(name="mentions", aliases=("pings",)) - async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: - """Edit one of your reminder's mentions.""" - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) - - async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: - """Edits a reminder with the given payload, then sends a confirmation message.""" - reminder = await self._edit_reminder(id_, payload) - - # Parse the reminder expiration back into a datetime - expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) - - # Send a confirmation message to the channel - await self._send_confirmation( - ctx, - on_success="That reminder has been edited successfully!", - reminder_id=id_, - delivery_dt=expiration, - ) - await self._reschedule_reminder(reminder) - - @remind_group.command("delete", aliases=("remove", "cancel")) - async def delete_reminder(self, ctx: Context, id_: int) -> None: - """Delete one of your active reminders.""" - await self._delete_reminder(id_) - await self._send_confirmation( - ctx, - on_success="That reminder has been deleted successfully!", - reminder_id=id_, - delivery_dt=None, - ) - - -def setup(bot: Bot) -> None: - """Load the Reminders cog.""" - bot.add_cog(Reminders(bot)) diff --git a/bot/cogs/security.py b/bot/cogs/security.py deleted file mode 100644 index c680c5e27..000000000 --- a/bot/cogs/security.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging - -from discord.ext.commands import Cog, Context, NoPrivateMessage - -from bot.bot import Bot - -log = logging.getLogger(__name__) - - -class Security(Cog): - """Security-related helpers.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all - self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM - - def check_not_bot(self, ctx: Context) -> bool: - """Check if the context is a bot user.""" - return not ctx.author.bot - - def check_on_guild(self, ctx: Context) -> bool: - """Check if the context is in a guild.""" - if ctx.guild is None: - raise NoPrivateMessage("This command cannot be used in private messages.") - return True - - -def setup(bot: Bot) -> None: - """Load the Security cog.""" - bot.add_cog(Security(bot)) diff --git a/bot/cogs/site.py b/bot/cogs/site.py deleted file mode 100644 index ac29daa1d..000000000 --- a/bot/cogs/site.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import URLs -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" - - -class Site(Cog): - """Commands for linking to different parts of the site.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="site", aliases=("s",), invoke_without_command=True) - async def site_group(self, ctx: Context) -> None: - """Commands for getting info about our website.""" - await ctx.send_help(ctx.command) - - @site_group.command(name="home", aliases=("about",)) - async def site_main(self, ctx: Context) -> None: - """Info about the website itself.""" - url = f"{URLs.site_schema}{URLs.site}/" - - embed = Embed(title="Python Discord website") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - f"[Our official website]({url}) is an open-source community project " - "created with Python and Django. It contains information about the server " - "itself, lets you sign up for upcoming events, has its own wiki, contains " - "a list of valuable learning resources, and much more." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="resources") - async def site_resources(self, ctx: Context) -> None: - """Info about the site's Resources page.""" - learning_url = f"{PAGES_URL}/resources" - - embed = Embed(title="Resources") - embed.set_footer(text=f"{learning_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Resources page]({learning_url}) on our website contains a " - "list of hand-selected learning resources that we regularly recommend " - f"to both beginners and experts." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="tools") - async def site_tools(self, ctx: Context) -> None: - """Info about the site's Tools page.""" - tools_url = f"{PAGES_URL}/resources/tools" - - embed = Embed(title="Tools") - embed.set_footer(text=f"{tools_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Tools page]({tools_url}) on our website contains a " - f"couple of the most popular tools for programming in Python." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="help") - async def site_help(self, ctx: Context) -> None: - """Info about the site's Getting Help page.""" - url = f"{PAGES_URL}/resources/guides/asking-good-questions" - - embed = Embed(title="Asking Good Questions") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "Asking the right question about something that's new to you can sometimes be tricky. " - f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " - "It contains everything you need to get the very best help from our community." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="faq") - async def site_faq(self, ctx: Context) -> None: - """Info about the site's FAQ page.""" - url = f"{PAGES_URL}/frequently-asked-questions" - - embed = Embed(title="FAQ") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "As the largest Python community on Discord, we get hundreds of questions every day. " - "Many of these questions have been asked before. We've compiled a list of the most " - "frequently asked questions along with their answers, which can be found on " - f"our [FAQ page]({url})." - ) - - await ctx.send(embed=embed) - - @site_group.command(aliases=['r', 'rule'], name='rules') - async def site_rules(self, ctx: Context, *rules: int) -> None: - """Provides a link to all rules or, if specified, displays specific rule(s).""" - rules_embed = Embed(title='Rules', color=Colour.blurple()) - rules_embed.url = f"{PAGES_URL}/rules" - - if not rules: - # Rules were not submitted. Return the default description. - rules_embed.description = ( - "The rules and guidelines that apply to this community can be found on" - f" our [rules page]({PAGES_URL}/rules). We expect" - " all members of the community to have read and understood these." - ) - - await ctx.send(embed=rules_embed) - return - - full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) - invalid_indices = tuple( - pick - for pick in rules - if pick < 1 or pick > len(full_rules) - ) - - if invalid_indices: - indices = ', '.join(map(str, invalid_indices)) - await ctx.send(f":x: Invalid rule indices: {indices}") - return - - for rule in rules: - self.bot.stats.incr(f"rule_uses.{rule}") - - final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) - - await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) - - -def setup(bot: Bot) -> None: - """Load the Site cog.""" - bot.add_cog(Site(bot)) diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py deleted file mode 100644 index 52c8b6f88..000000000 --- a/bot/cogs/snekbox.py +++ /dev/null @@ -1,349 +0,0 @@ -import asyncio -import contextlib -import datetime -import logging -import re -import textwrap -from functools import partial -from signal import Signals -from typing import Optional, Tuple - -from discord import HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only - -from bot.bot import Bot -from bot.constants import Categories, Channels, Roles, URLs -from bot.decorators import in_whitelist -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") -FORMATTED_CODE_REGEX = re.compile( - r"^\s*" # any leading whitespace from the beginning of the string - r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block - r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) - r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all code inside the markup - r"\s*" # any more whitespace before the end of the code markup - r"(?P=delim)" # match the exact same delimiter from the start again - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive -) -RAW_CODE_REGEX = re.compile( - r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all the rest as code - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL # "." also matches newlines -) - -MAX_PASTE_LEN = 1000 - -# `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) - -SIGKILL = 9 - -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 - - -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_eval(self, code: str) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - - url = URLs.paste_service.format(key="documents") - try: - async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: - data = await resp.json() - - if "key" in data: - return URLs.paste_service.format(key=data["key"]) - except Exception: - # 400 (Bad Request) means there are too many characters - log.exception("Failed to upload full output to paste service!") - - @staticmethod - def prepare_input(code: str) -> str: - """Extract code from the Markdown, format it, and insert it into the code template.""" - match = FORMATTED_CODE_REGEX.fullmatch(code) - if match: - code, block, lang, delim = match.group("code", "block", "lang", "delim") - code = textwrap.dedent(code) - if block: - info = (f"'{lang}' highlighted" if lang else "plain") + " code block" - else: - info = f"{delim}-enclosed inline code" - log.trace(f"Extracted {info} for evaluation:\n{code}") - else: - code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) - log.trace( - f"Eval message contains unformatted or badly formatted code, " - f"stripping whitespace only:\n{code}" - ) - - return code - - @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: - """Return a user-friendly message and error corresponding to the process's return code.""" - stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" - error = "" - - if returncode is None: - msg = "Your eval job has failed" - error = stdout.strip() - elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" - elif returncode == 255: - msg = "Your eval job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - try: - name = Signals(returncode - 128).name - msg = f"{msg} ({name})" - except ValueError: - pass - - return msg, error - - @staticmethod - def get_status_emoji(results: dict) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - if not results["stdout"].strip(): # No output - return ":warning:" - elif results["returncode"] == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: - """ - Format the output and return a tuple of the formatted output and a URL to the full output. - - Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters - and upload the full output to a paste service. - """ - log.trace("Formatting output...") - - output = output.rstrip("\n") - original_output = output # To be uploaded to a pasting service if needed - paste_link = None - - if "<@" in output: - output = output.replace("<@", "<@\u200B") # Zero-width space - - if " 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] - output = output[:11] # Limiting to only 11 lines - output = "\n".join(output) - - if lines > 10: - truncated = True - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" - else: - output = f"{output}\n... (truncated - too many lines)" - elif len(output) >= 1000: - truncated = True - output = f"{output[:1000]}\n... (truncated - too long)" - - if truncated: - paste_link = await self.upload_output(original_output) - - output = output or "[No output]" - - return output, paste_link - - async def send_eval(self, ctx: Context, code: str) -> Message: - """ - Evaluate code, format it, and send the output to the corresponding channel. - - Return the bot response. - """ - async with ctx.typing(): - results = await self.post_eval(code) - msg, error = self.get_results_message(results) - - if error: - output, paste_link = error, None - else: - output, paste_link = await self.format_output(results["stdout"]) - - icon = self.get_status_emoji(results) - msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" - if paste_link: - msg = f"{msg}\nFull output: {paste_link}" - - # Collect stats of eval fails + successes - if icon == ":x:": - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - - filter_cog = self.bot.get_cog("Filtering") - filter_triggered = False - if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) - if filter_triggered: - response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") - else: - response = await ctx.send(msg) - self.bot.loop.create_task( - wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) - ) - - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") - return response - - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: - """ - Check if the eval session should continue. - - Return the new code to evaluate or None if the eval session should be terminated. - """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) - - with contextlib.suppress(NotFound): - try: - _, new_message = await self.bot.wait_for( - 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT - ) - await ctx.message.add_reaction(REEVAL_EMOJI) - await self.bot.wait_for( - 'reaction_add', - check=_predicate_emoji_reaction, - timeout=10 - ) - - code = await self.get_code(new_message) - await ctx.message.clear_reactions() - with contextlib.suppress(HTTPException): - await response.delete() - - except asyncio.TimeoutError: - await ctx.message.clear_reactions() - return None - - return code - - async def get_code(self, message: Message) -> Optional[str]: - """ - Return the code from `message` to be evaluated. - - If the message is an invocation of the eval command, return the first argument or None if it - doesn't exist. Otherwise, return the full content of the message. - """ - log.trace(f"Getting context for message {message.id}.") - new_ctx = await self.bot.get_context(message) - - if new_ctx.command is self.eval_command: - log.trace(f"Message {message.id} invokes eval command.") - split = message.content.split(maxsplit=1) - code = split[1] if len(split) > 1 else None - else: - log.trace(f"Message {message.id} does not invoke eval command.") - code = message.content - - return code - - @command(name="eval", aliases=("e",)) - @guild_only() - @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: - """ - Run Python code and get the results. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ - if ctx.author.id in self.jobs: - await ctx.send( - f"{ctx.author.mention} You've already got a job running - " - "please wait for it to finish!" - ) - return - - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - - if Roles.helpers in (role.id for role in ctx.author.roles): - self.bot.stats.incr("snekbox_usages.roles.helpers") - else: - self.bot.stats.incr("snekbox_usages.roles.developers") - - if ctx.channel.category_id == Categories.help_in_use: - self.bot.stats.incr("snekbox_usages.channels.help") - elif ctx.channel.id == Channels.bot_commands: - self.bot.stats.incr("snekbox_usages.channels.bot_commands") - else: - self.bot.stats.incr("snekbox_usages.channels.topical") - - log.info(f"Received code from {ctx.author} for evaluation:\n{code}") - - while True: - self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) - try: - response = await self.send_eval(ctx, code) - finally: - del self.jobs[ctx.author.id] - - code = await self.continue_eval(ctx, response) - if not code: - break - log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") - - -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: - """Return True if the edited message is the context message and the content was indeed modified.""" - return new_msg.id == ctx.message.id and old_msg.content != new_msg.content - - -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI - - -def setup(bot: Bot) -> None: - """Load the Snekbox cog.""" - bot.add_cog(Snekbox(bot)) diff --git a/bot/cogs/source.py b/bot/cogs/source.py deleted file mode 100644 index 205e0ba81..000000000 --- a/bot/cogs/source.py +++ /dev/null @@ -1,141 +0,0 @@ -import inspect -from pathlib import Path -from typing import Optional, Tuple, Union - -from discord import Embed -from discord.ext import commands - -from bot.bot import Bot -from bot.constants import URLs - -SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] - - -class SourceConverter(commands.Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - async def convert(self, ctx: commands.Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower().startswith("help"): - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - - raise commands.BadArgument( - f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." - ) - - -class BotSource(commands.Cog): - """Displays information about the bot's source code.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command(name="source", aliases=("src",)) - async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: - """Display information and a GitHub link to the source code of a command, tag, or cog.""" - if not source_item: - embed = Embed(title="Bot's GitHub Repository") - embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") - embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") - await ctx.send(embed=embed) - return - - embed = await self.build_embed(source_item) - await ctx.send(embed=embed) - - def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: - """ - Build GitHub link of source item, return this link, file location and first line number. - - Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). - """ - if isinstance(source_item, commands.Command): - if source_item.cog_name == "Alias": - cmd_name = source_item.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - src = cmd.callback.__code__ - filename = src.co_filename - else: - src = source_item.callback.__code__ - filename = src.co_filename - elif isinstance(source_item, str): - tags_cog = self.bot.get_cog("Tags") - filename = tags_cog._cache[source_item]["location"] - else: - src = type(source_item) - try: - filename = inspect.getsourcefile(src) - except TypeError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - if not isinstance(source_item, str): - try: - lines, first_line_no = inspect.getsourcelines(src) - except OSError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" - else: - first_line_no = None - lines_extension = "" - - # Handle tag file location differently than others to avoid errors in some cases - if not first_line_no: - file_location = Path(filename).relative_to("/bot/") - else: - file_location = Path(filename).relative_to(Path.cwd()).as_posix() - - url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" - - return url, file_location, first_line_no or None - - async def build_embed(self, source_object: SourceType) -> Optional[Embed]: - """Build embed based on source object.""" - url, location, first_line = self.get_source_link(source_object) - - if isinstance(source_object, commands.HelpCommand): - title = "Help Command" - description = source_object.__doc__.splitlines()[1] - elif isinstance(source_object, commands.Command): - if source_object.cog_name == "Alias": - cmd_name = source_object.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - description = cmd.short_doc - else: - description = source_object.short_doc - - title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, str): - title = f"Tag: {source_object}" - description = "" - else: - title = f"Cog: {source_object.qualified_name}" - description = source_object.description.splitlines()[0] - - embed = Embed(title=title, description=description) - embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") - line_text = f":{first_line}" if first_line else "" - embed.set_footer(text=f"{location}{line_text}") - - return embed - - -def setup(bot: Bot) -> None: - """Load the BotSource cog.""" - bot.add_cog(BotSource(bot)) diff --git a/bot/cogs/stats.py b/bot/cogs/stats.py deleted file mode 100644 index d42f55466..000000000 --- a/bot/cogs/stats.py +++ /dev/null @@ -1,129 +0,0 @@ -import string -from datetime import datetime - -from discord import Member, Message, Status -from discord.ext.commands import Cog, Context -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Categories, Channels, Guild, Stats as StatConf - - -CHANNEL_NAME_OVERRIDES = { - Channels.off_topic_0: "off_topic_0", - Channels.off_topic_1: "off_topic_1", - Channels.off_topic_2: "off_topic_2", - Channels.staff_lounge: "staff_lounge" -} - -ALLOWED_CHARS = string.ascii_letters + string.digits + "_" - - -class Stats(Cog): - """A cog which provides a way to hook onto Discord events and forward to stats.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.last_presence_update = None - self.update_guild_boost.start() - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Report message events in the server to statsd.""" - if message.guild is None: - return - - if message.guild.id != Guild.id: - return - - cat = getattr(message.channel, "category", None) - if cat is not None and cat.id == Categories.modmail: - if message.channel.id != Channels.incidents: - # Do not report modmail channels to stats, there are too many - # of them for interesting statistics to be drawn out of this. - return - - reformatted_name = message.channel.name.replace('-', '_') - - if CHANNEL_NAME_OVERRIDES.get(message.channel.id): - reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) - - reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) - - stat_name = f"channels.{reformatted_name}" - self.bot.stats.incr(stat_name) - - # Increment the total message count - self.bot.stats.incr("messages") - - @Cog.listener() - async def on_command_completion(self, ctx: Context) -> None: - """Report completed commands to statsd.""" - command_name = ctx.command.qualified_name.replace(" ", "_") - - self.bot.stats.incr(f"commands.{command_name}") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Update member count stat on member join.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_leave(self, member: Member) -> None: - """Update member count stat on member leave.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_update(self, _before: Member, after: Member) -> None: - """Update presence estimates on member update.""" - if after.guild.id != Guild.id: - return - - if self.last_presence_update: - if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: - return - - self.last_presence_update = datetime.now() - - online = 0 - idle = 0 - dnd = 0 - offline = 0 - - for member in after.guild.members: - if member.status is Status.online: - online += 1 - elif member.status is Status.dnd: - dnd += 1 - elif member.status is Status.idle: - idle += 1 - elif member.status is Status.offline: - offline += 1 - - self.bot.stats.gauge("guild.status.online", online) - self.bot.stats.gauge("guild.status.idle", idle) - self.bot.stats.gauge("guild.status.do_not_disturb", dnd) - self.bot.stats.gauge("guild.status.offline", offline) - - @loop(hours=1) - async def update_guild_boost(self) -> None: - """Post the server boost level and tier every hour.""" - await self.bot.wait_until_guild_available() - g = self.bot.get_guild(Guild.id) - self.bot.stats.gauge("boost.amount", g.premium_subscription_count) - self.bot.stats.gauge("boost.tier", g.premium_tier) - - def cog_unload(self) -> None: - """Stop the boost statistic task on unload of the Cog.""" - self.update_guild_boost.stop() - - -def setup(bot: Bot) -> None: - """Load the stats cog.""" - bot.add_cog(Stats(bot)) diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py deleted file mode 100644 index fe7df4e9b..000000000 --- a/bot/cogs/sync/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from bot.bot import Bot -from .cog import Sync - - -def setup(bot: Bot) -> None: - """Load the Sync cog.""" - bot.add_cog(Sync(bot)) diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py deleted file mode 100644 index 5ace957e7..000000000 --- a/bot/cogs/sync/cog.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from typing import Any, Dict - -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.sync import syncers - -log = logging.getLogger(__name__) - - -class Sync(Cog): - """Captures relevant events and sends them to the site.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.role_syncer = syncers.RoleSyncer(self.bot) - self.user_syncer = syncers.UserSyncer(self.bot) - - self.bot.loop.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(constants.Guild.id) - if guild is None: - return - - for syncer in (self.role_syncer, self.user_syncer): - await syncer.sync(guild) - - async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: - """Send a PATCH request to partially update a user in the database.""" - try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) - except ResponseCodeError as e: - if e.response.status != 404: - raise - if not ignore_404: - log.warning("Unable to update user, got 404. Assuming race condition from join event.") - - @Cog.listener() - async def on_guild_role_create(self, role: Role) -> None: - """Adds newly create role to the database table over the API.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.post( - 'bot/roles', - json={ - 'colour': role.colour.value, - 'id': role.id, - 'name': role.name, - 'permissions': role.permissions.value, - 'position': role.position, - } - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: Role) -> None: - """Deletes role from the database when it's deleted from the guild.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - @Cog.listener() - async def on_guild_role_update(self, before: Role, after: Role) -> None: - """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild.id != constants.Guild.id: - return - - was_updated = ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ) - - if was_updated: - await self.bot.api_client.put( - f'bot/roles/{after.id}', - json={ - 'colour': after.colour.value, - 'id': after.id, - 'name': after.name, - 'permissions': after.permissions.value, - 'position': after.position, - } - ) - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """ - Adds a new user or updates existing user to the database when a member joins the guild. - - If the joining member is a user that is already known to the database (i.e., a user that - previously left), it will update the user's information. If the user is not yet known by - the database, the user is added. - """ - if member.guild.id != constants.Guild.id: - return - - packed = { - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': True, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - - got_error = False - - try: - # First try an update of the user to set the `in_guild` field and other - # fields that may have changed since the last time we've seen them. - await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) - - except ResponseCodeError as e: - # If we didn't get 404, something else broke - propagate it up. - if e.response.status != 404: - raise - - got_error = True # yikes - - if got_error: - # If we got `404`, the user is new. Create them. - await self.bot.api_client.post('bot/users', json=packed) - - @Cog.listener() - async def on_member_remove(self, member: Member) -> None: - """Set the in_guild field to False when a member leaves the guild.""" - if member.guild.id != constants.Guild.id: - return - - await self.patch_user(member.id, json={"in_guild": False}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Update the roles of the member in the database if a change is detected.""" - if after.guild.id != constants.Guild.id: - return - - if before.roles != after.roles: - updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, json=updated_information) - - @Cog.listener() - async def on_user_update(self, before: User, after: User) -> None: - """Update the user information in the database if a relevant change is detected.""" - attrs = ("name", "discriminator") - if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): - updated_information = { - "name": after.name, - "discriminator": int(after.discriminator), - } - # A 404 likely means the user is in another guild. - await self.patch_user(after.id, json=updated_information, ignore_404=True) - - @commands.group(name='sync') - @commands.has_permissions(administrator=True) - async def sync_group(self, ctx: Context) -> None: - """Run synchronizations between the bot and site manually.""" - - @sync_group.command(name='roles') - @commands.has_permissions(administrator=True) - async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) - - @sync_group.command(name='users') - @commands.has_permissions(administrator=True) - async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py deleted file mode 100644 index f7ba811bc..000000000 --- a/bot/cogs/sync/syncers.py +++ /dev/null @@ -1,347 +0,0 @@ -import abc -import asyncio -import logging -import typing as t -from collections import namedtuple -from functools import partial - -import discord -from discord import Guild, HTTPException, Member, Message, Reaction, User -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# These objects are declared as namedtuples because tuples are hashable, -# something that we make use of when diffing site roles against guild roles. -_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) -_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - - -class Syncer(abc.ABC): - """Base class for synchronising the database with objects in the Discord cache.""" - - _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " - _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @property - @abc.abstractmethod - def name(self) -> str: - """The name of the syncer; used in output messages and logging.""" - raise NotImplementedError # pragma: no cover - - async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: - """ - Send a prompt to confirm or abort a sync using reactions and return the sent message. - - If a message is given, it is edited to display the prompt and reactions. Otherwise, a new - message is sent to the dev-core channel and mentions the core developers role. If the - channel cannot be retrieved, return None. - """ - log.trace(f"Sending {self.name} sync confirmation prompt.") - - msg_content = ( - f'Possible cache issue while syncing {self.name}s. ' - f'More than {constants.Sync.max_diff} {self.name}s were changed. ' - f'React to confirm or abort the sync.' - ) - - # Send to core developers if it's an automatic sync. - if not message: - log.trace("Message not provided for confirmation; creating a new one in dev-core.") - channel = self.bot.get_channel(constants.Channels.dev_core) - - if not channel: - log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") - try: - channel = await self.bot.fetch_channel(constants.Channels.dev_core) - except HTTPException: - log.exception( - f"Failed to fetch channel for sending sync confirmation prompt; " - f"aborting {self.name} sync." - ) - return None - - allowed_roles = [discord.Object(constants.Roles.core_developers)] - message = await channel.send( - f"{self._CORE_DEV_MENTION}{msg_content}", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - else: - await message.edit(content=msg_content) - - # Add the initial reactions. - log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") - for emoji in self._REACTION_EMOJIS: - await message.add_reaction(emoji) - - return message - - def _reaction_check( - self, - author: Member, - message: Message, - reaction: Reaction, - user: t.Union[Member, User] - ) -> bool: - """ - Return True if the `reaction` is a valid confirmation or abort reaction on `message`. - - If the `author` of the prompt is a bot, then a reaction by any core developer will be - considered valid. Otherwise, the author of the reaction (`user`) will have to be the - `author` of the prompt. - """ - # For automatic syncs, check for the core dev role instead of an exact author - has_role = any(constants.Roles.core_developers == role.id for role in user.roles) - return ( - reaction.message.id == message.id - and not user.bot - and (has_role if author.bot else user == author) - and str(reaction.emoji) in self._REACTION_EMOJIS - ) - - async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: - """ - Wait for a confirmation reaction by `author` on `message` and return True if confirmed. - - Uses the `_reaction_check` function to determine if a reaction is valid. - - If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. - To acknowledge the reaction (or lack thereof), `message` will be edited. - """ - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - reaction = None - try: - log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") - reaction, _ = await self.bot.wait_for( - 'reaction_add', - check=partial(self._reaction_check, author, message), - timeout=constants.Sync.confirm_timeout - ) - except asyncio.TimeoutError: - # reaction will remain none thus sync will be aborted in the finally block below. - log.debug(f"The {self.name} syncer confirmation prompt timed out.") - - if str(reaction) == constants.Emojis.check_mark: - log.trace(f"The {self.name} syncer was confirmed.") - await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') - return True - else: - log.info(f"The {self.name} syncer was aborted or timed out!") - await message.edit( - content=f':warning: {mention}{self.name} sync aborted or timed out!' - ) - return False - - @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference between the cache of `guild` and the database.""" - raise NotImplementedError # pragma: no cover - - @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: - """Perform the API calls for synchronisation.""" - raise NotImplementedError # pragma: no cover - - async def _get_confirmation_result( - self, - diff_size: int, - author: Member, - message: t.Optional[Message] = None - ) -> t.Tuple[bool, t.Optional[Message]]: - """ - Prompt for confirmation and return a tuple of the result and the prompt message. - - `diff_size` is the size of the diff of the sync. If it is greater than - `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the - sync and the `message` is an extant message to edit to display the prompt. - - If confirmed or no confirmation was needed, the result is True. The returned message will - either be the given `message` or a new one which was created when sending the prompt. - """ - log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") - if diff_size > constants.Sync.max_diff: - message = await self._send_prompt(message) - if not message: - return False, None # Couldn't get channel. - - confirmed = await self._wait_for_confirmation(author, message) - if not confirmed: - return False, message # Sync aborted. - - return True, message - - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If the differences between the cache and the database are greater than - `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core - channel. The confirmation can be optionally redirect to `ctx` instead. - """ - log.info(f"Starting {self.name} syncer.") - - message = None - author = self.bot.user - if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") - author = ctx.author - - diff = await self._get_diff(guild) - diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict - totals = {k: len(v) for k, v in diff_dict.items() if v is not None} - diff_size = sum(totals.values()) - - confirmed, message = await self._get_confirmation_result(diff_size, author, message) - if not confirmed: - return - - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - try: - await self._sync(diff) - except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") - - # Don't show response text because it's probably some really long HTML. - results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" - else: - results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" - - if message: - await message.edit(content=content) - - -class RoleSyncer(Syncer): - """Synchronise the database with roles in the cache.""" - - name = "role" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of roles between the cache of `guild` and the database.""" - log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_roles = {_Role(**role_dict) for role_dict in roles} - guild_roles = { - _Role( - id=role.id, - name=role.name, - colour=role.colour.value, - permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in db_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # DB guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - db_roles - roles_to_create - roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} - - return _Diff(roles_to_create, roles_to_update, roles_to_delete) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the role cache of `guild`.""" - log.trace("Syncing created roles...") - for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) - - log.trace("Syncing updated roles...") - for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) - - log.trace("Syncing deleted roles...") - for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - -class UserSyncer(Syncer): - """Synchronise the database with users in the cache.""" - - name = "user" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of users between the cache of `guild` and the database.""" - log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } - - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - if guild_user is not None: - if db_user != guild_user: - users_to_update.add(guild_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return _Diff(users_to_create, users_to_update, None) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the user cache of `guild`.""" - log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py deleted file mode 100644 index 3d76c5c08..000000000 --- a/bot/cogs/tags.py +++ /dev/null @@ -1,277 +0,0 @@ -import logging -import re -import time -from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot import constants -from bot.bot import Bot -from bot.converters import TagNameConverter -from bot.pagination import LinePaginator -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -TEST_CHANNELS = ( - constants.Channels.bot_commands, - constants.Channels.helpers -) - -REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) -FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." - - -class Tags(Cog): - """Save new tags and fetch existing tags.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.tag_cooldowns = {} - self._cache = self.get_tags() - - @staticmethod - def get_tags() -> dict: - """Get all tags.""" - cache = {} - - base_path = Path("bot", "resources", "tags") - for file in base_path.glob("**/*"): - if file.is_file(): - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text(encoding="utf8"), - }, - "restricted_to": "developers", - "location": f"/bot/{file}" - } - - # Convert to a list to allow negative indexing. - parents = list(file.relative_to(base_path).parents) - if len(parents) > 1: - # -1 would be '.' hence -2 is used as the index. - tag["restricted_to"] = parents[-2].name - - cache[tag_title] = tag - - return cache - - @staticmethod - def check_accessibility(user: Member, tag: dict) -> bool: - """Check if user can access a tag.""" - return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] - - @staticmethod - def _fuzzy_search(search: str, target: str) -> float: - """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" - current, index = 0, 0 - _search = REGEX_NON_ALPHABET.sub('', search.lower()) - _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) - _target = next(_targets) - try: - while True: - while index < len(_target) and _search[current] == _target[index]: - current += 1 - index += 1 - index, _target = 0, next(_targets) - except (StopIteration, IndexError): - pass - return current / len(_search) * 100 - - def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: - """Return a list of suggested tags.""" - scores: Dict[str, int] = { - tag_title: Tags._fuzzy_search(tag_name, tag['title']) - for tag_title, tag in self._cache.items() - } - - thresholds = thresholds or [100, 90, 80, 70, 60] - - for threshold in thresholds: - suggestions = [ - self._cache[tag_title] - for tag_title, matching_score in scores.items() - if matching_score >= threshold - ] - if suggestions: - return suggestions - - return [] - - def _get_tag(self, tag_name: str) -> list: - """Get a specific tag.""" - found = [self._cache.get(tag_name.lower(), None)] - if not found[0]: - return self._get_suggestions(tag_name) - return found - - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: - """ - Search for tags via contents. - - `predicate` will be the built-in any, all, or a custom callable. Must return a bool. - """ - keywords_processed: List[str] = [] - for keyword in keywords.split(','): - keyword_sanitized = keyword.strip().casefold() - if not keyword_sanitized: - # this happens when there are leading / trailing / consecutive comma. - continue - keywords_processed.append(keyword_sanitized) - - if not keywords_processed: - # after sanitizing, we can end up with an empty list, for example when keywords is ',' - # in that case, we simply want to search for such keywords directly instead. - keywords_processed = [keywords] - - matching_tags = [] - for tag in self._cache.values(): - matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) - if self.check_accessibility(user, tag) and check(matches): - matching_tags.append(tag) - - return matching_tags - - async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: - """Send the result of matching tags to user.""" - if not matching_tags: - pass - elif len(matching_tags) == 1: - await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) - else: - is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 - embed = Embed( - title=f"Here are the tags containing the given keyword{'s' * is_plural}:", - description='\n'.join(tag['title'] for tag in matching_tags[:10]) - ) - await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in matching_tags), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) - async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Show all known tags, a single tag, or run a subcommand.""" - await ctx.invoke(self.get_command, tag_name=tag_name) - - @tags_group.group(name='search', invoke_without_command=True) - async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Only search for tags that has ALL the keywords. - """ - matching_tags = self._get_tags_via_content(all, keywords, ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @search_tag_content.command(name='any') - async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Search for tags that has ANY of the keywords. - """ - matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @tags_group.command(name='get', aliases=('show', 'g')) - async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Get a specified tag, or a list of all tags if no tag is specified.""" - - def _command_on_cooldown(tag_name: str) -> bool: - """ - Check if the command is currently on cooldown, on a per-tag, per-channel basis. - - The cooldown duration is set in constants.py. - """ - now = time.time() - - cooldown_conditions = ( - tag_name - and tag_name in self.tag_cooldowns - and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags - and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id - ) - - if cooldown_conditions: - return True - return False - - if _command_on_cooldown(tag_name): - time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] - time_left = constants.Cooldowns.tags - time_elapsed - log.info( - f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds." - ) - return - - if tag_name is not None: - temp_founds = self._get_tag(tag_name) - - founds = [] - - for found_tag in temp_founds: - if self.check_accessibility(ctx.author, found_tag): - founds.append(found_tag) - - if len(founds) == 1: - tag = founds[0] - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } - - self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") - - await wait_for_deletion( - await ctx.send(embed=Embed.from_dict(tag['embed'])), - [ctx.author.id], - client=self.bot - ) - elif founds and len(tag_name) >= 3: - await wait_for_deletion( - await ctx.send( - embed=Embed( - title='Did you mean ...', - description='\n'.join(tag['title'] for tag in founds[:10]) - ) - ), - [ctx.author.id], - client=self.bot - ) - - else: - tags = self._cache.values() - if not tags: - await ctx.send(embed=Embed( - description="**There are no tags in the database!**", - colour=Colour.red() - )) - else: - embed: Embed = Embed(title="**Current tags**") - await LinePaginator.paginate( - sorted( - f"**»** {tag['title']}" for tag in tags - if self.check_accessibility(ctx.author, tag) - ), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Tags cog.""" - bot.add_cog(Tags(bot)) diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py deleted file mode 100644 index ef979f222..000000000 --- a/bot/cogs/token_remover.py +++ /dev/null @@ -1,182 +0,0 @@ -import base64 -import binascii -import logging -import re -import typing as t - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot import utils -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import Channels, Colours, Event, Icons - -log = logging.getLogger(__name__) - -LOG_MESSAGE = ( - "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " - "token was `{user_id}.{timestamp}.{hmac}`" -) -DELETION_MESSAGE_TEMPLATE = ( - "Hey {mention}! I noticed you posted a seemingly valid Discord API " - "token in your message and have removed your message. " - "This means that your token has been **compromised**. " - "Please change your token **immediately** at: " - "\n\n" - "Feel free to re-post it with the token removed. " - "If you believe this was a mistake, please let us know!" -) -DISCORD_EPOCH = 1_420_070_400 -TOKEN_EPOCH = 1_293_840_000 - -# Three parts delimited by dots: user ID, creation timestamp, HMAC. -# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. -# Each part only matches base64 URL-safe characters. -# Padding has never been observed, but the padding character '=' is matched just in case. -TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) - - -class Token(t.NamedTuple): - """A Discord Bot token.""" - - user_id: str - timestamp: str - hmac: str - - -class TokenRemover(Cog): - """Scans messages for potential discord.py bot tokens and removes them.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Check each message for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - found_token = self.find_token_in_message(msg) - if found_token: - await self.take_action(msg, found_token) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Check each edit for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - await self.on_message(after) - - async def take_action(self, msg: Message, found_token: Token) -> None: - """Remove the `msg` containing the `found_token` and send a mod log message.""" - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") - return - - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - log_message = self.format_log_message(msg, found_token) - log.debug(log_message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=log_message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) - - self.bot.stats.incr("tokens.removed_tokens") - - @staticmethod - def format_log_message(msg: Message, token: Token) -> str: - """Return the log message to send for `token` being censored in `msg`.""" - return LOG_MESSAGE.format( - author=msg.author, - author_id=msg.author.id, - channel=msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac='x' * len(token.hmac), - ) - - @classmethod - def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: - """Return a seemingly valid token found in `msg` or `None` if no token is found.""" - # Use finditer rather than search to guard against method calls prematurely returning the - # token check (e.g. `message.channel.send` also matches our token pattern) - for match in TOKEN_RE.finditer(msg.content): - token = Token(*match.groups()) - if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): - # Short-circuit on first match - return token - - # No matching substring - return - - @staticmethod - def is_valid_user_id(b64_content: str) -> bool: - """ - Check potential token to see if it contains a valid Discord user ID. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - string = decoded_bytes.decode('utf-8') - - # isdigit on its own would match a lot of other Unicode characters, hence the isascii. - return string.isascii() and string.isdigit() - except (binascii.Error, ValueError): - return False - - @staticmethod - def is_valid_timestamp(b64_content: str) -> bool: - """ - Return True if `b64_content` decodes to a valid timestamp. - - If the timestamp is greater than the Discord epoch, it's probably valid. - See: https://i.imgur.com/7WdehGn.png - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: - log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") - return False - - # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound - # is not checked. - if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: - return True - else: - log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") - return False - - -def setup(bot: Bot) -> None: - """Load the TokenRemover cog.""" - bot.add_cog(TokenRemover(bot)) diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py deleted file mode 100644 index d96abbd5a..000000000 --- a/bot/cogs/utils.py +++ /dev/null @@ -1,265 +0,0 @@ -import difflib -import logging -import re -import unicodedata -from email.parser import HeaderParser -from io import StringIO -from typing import Tuple, Union - -from discord import Colour, Embed, utils -from discord.ext.commands import BadArgument, Cog, Context, clean_content, command - -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils import messages - -log = logging.getLogger(__name__) - -ZEN_OF_PYTHON = """\ -Beautiful is better than ugly. -Explicit is better than implicit. -Simple is better than complex. -Complex is better than complicated. -Flat is better than nested. -Sparse is better than dense. -Readability counts. -Special cases aren't special enough to break the rules. -Although practicality beats purity. -Errors should never pass silently. -Unless explicitly silenced. -In the face of ambiguity, refuse the temptation to guess. -There should be one-- and preferably only one --obvious way to do it. -Although that way may not be obvious at first unless you're Dutch. -Now is better than never. -Although never is often better than *right* now. -If the implementation is hard to explain, it's a bad idea. -If the implementation is easy to explain, it may be a good idea. -Namespaces are one honking great idea -- let's do more of those! -""" - -ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - - -class Utils(Cog): - """A selection of utilities which don't have a clear category.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.base_pep_url = "http://www.python.org/dev/peps/pep-" - self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" - - @command(name='pep', aliases=('get_pep', 'p')) - async def pep_command(self, ctx: Context, pep_number: str) -> None: - """Fetches information about a PEP and sends it to the channel.""" - if pep_number.isdigit(): - pep_number = int(pep_number) - else: - await ctx.send_help(ctx.command) - return - - # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. - if pep_number == 0: - return await self.send_pep_zero(ctx) - - possible_extensions = ['.txt', '.rst'] - found_pep = False - for extension in possible_extensions: - # Attempt to fetch the PEP - pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" - log.trace(f"Requesting PEP {pep_number} with {pep_url}") - response = await self.bot.http_session.get(pep_url) - - if response.status == 200: - log.trace("PEP found") - found_pep = True - - pep_content = await response.text() - - # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 - pep_header = HeaderParser().parse(StringIO(pep_content)) - - # Assemble the embed - pep_embed = Embed( - title=f"**PEP {pep_number} - {pep_header['Title']}**", - description=f"[Link]({self.base_pep_url}{pep_number:04})", - ) - - pep_embed.set_thumbnail(url=ICON_URL) - - # Add the interesting information - fields_to_check = ("Status", "Python-Version", "Created", "Type") - for field in fields_to_check: - # Check for a PEP metadata field that is present but has an empty value - # embed field values can't contain an empty string - if pep_header.get(field, ""): - pep_embed.add_field(name=field, value=pep_header[field]) - - elif response.status != 404: - # any response except 200 and 404 is expected - found_pep = True # actually not, but it's easier to display this way - log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " - f"{response.status}.\n{response.text}") - - error_message = "Unexpected HTTP error during PEP search. Please let us know." - pep_embed = Embed(title="Unexpected error", description=error_message) - pep_embed.colour = Colour.red() - break - - if not found_pep: - log.trace("PEP was not found") - not_found = f"PEP {pep_number} does not exist." - pep_embed = Embed(title="PEP not found", description=not_found) - pep_embed.colour = Colour.red() - - await ctx.message.channel.send(embed=pep_embed) - - @command() - @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) - async def charinfo(self, ctx: Context, *, characters: str) -> None: - """Shows you information on up to 50 unicode characters.""" - match = re.match(r"<(a?):(\w+):(\d+)>", characters) - if match: - return await messages.send_denial( - ctx, - "**Non-Character Detected**\n" - "Only unicode characters can be processed, but a custom Discord emoji " - "was found. Please remove it and try again." - ) - - if len(characters) > 50: - return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") - - def get_info(char: str) -> Tuple[str, str]: - digit = f"{ord(char):x}" - if len(digit) <= 4: - u_code = f"\\u{digit:>04}" - else: - u_code = f"\\U{digit:>08}" - url = f"https://www.compart.com/en/unicode/U+{digit:>04}" - name = f"[{unicodedata.name(char, '')}]({url})" - info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" - return info, u_code - - char_list, raw_list = zip(*(get_info(c) for c in characters)) - embed = Embed().set_author(name="Character Info") - - if len(characters) > 1: - # Maximum length possible is 502 out of 1024, so there's no need to truncate. - embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) - - await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) - - @command() - async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: - """ - Show the Zen of Python. - - Without any arguments, the full Zen will be produced. - If an integer is provided, the line with that index will be produced. - If a string is provided, the line which matches best will be produced. - """ - embed = Embed( - colour=Colour.blurple(), - title="The Zen of Python", - description=ZEN_OF_PYTHON - ) - - if search_value is None: - embed.title += ", by Tim Peters" - await ctx.send(embed=embed) - return - - zen_lines = ZEN_OF_PYTHON.splitlines() - - # handle if it's an index int - if isinstance(search_value, int): - upper_bound = len(zen_lines) - 1 - lower_bound = -1 * upper_bound - if not (lower_bound <= search_value <= upper_bound): - raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") - - embed.title += f" (line {search_value % len(zen_lines)}):" - embed.description = zen_lines[search_value] - await ctx.send(embed=embed) - return - - # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead - # exact word. - for i, line in enumerate(zen_lines): - for word in line.split(): - if word.lower() == search_value.lower(): - embed.title += f" (line {i}):" - embed.description = line - await ctx.send(embed=embed) - return - - # handle if it's a search string and not exact word - matcher = difflib.SequenceMatcher(None, search_value.lower()) - - best_match = "" - match_index = 0 - best_ratio = 0 - - for index, line in enumerate(zen_lines): - matcher.set_seq2(line.lower()) - - # the match ratio needs to be adjusted because, naturally, - # longer lines will have worse ratios than shorter lines when - # fuzzy searching for keywords. this seems to work okay. - adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() - - if adjusted_ratio > best_ratio: - best_ratio = adjusted_ratio - best_match = line - match_index = index - - if not best_match: - raise BadArgument("I didn't get a match! Please try again with a different search term.") - - embed.title += f" (line {match_index}):" - embed.description = best_match - await ctx.send(embed=embed) - - @command(aliases=("poll",)) - @with_role(*MODERATION_ROLES) - async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: - """ - Build a quick voting poll with matching reactions with the provided options. - - A maximum of 20 options can be provided, as Discord supports a max of 20 - reactions on a single message. - """ - if len(title) > 256: - raise BadArgument("The title cannot be longer than 256 characters.") - if len(options) < 2: - raise BadArgument("Please provide at least 2 options.") - if len(options) > 20: - raise BadArgument("I can only handle 20 options!") - - codepoint_start = 127462 # represents "regional_indicator_a" unicode value - options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} - embed = Embed(title=title, description="\n".join(options.values())) - message = await ctx.send(embed=embed) - for reaction in options: - await message.add_reaction(reaction) - - async def send_pep_zero(self, ctx: Context) -> None: - """Send information about PEP 0.""" - pep_embed = Embed( - title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", - description="[Link](https://www.python.org/dev/peps/)" - ) - pep_embed.set_thumbnail(url=ICON_URL) - pep_embed.add_field(name="Status", value="Active") - pep_embed.add_field(name="Created", value="13-Jul-2000") - pep_embed.add_field(name="Type", value="Informational") - - await ctx.send(embed=pep_embed) - - -def setup(bot: Bot) -> None: - """Load the Utils cog.""" - bot.add_cog(Utils(bot)) diff --git a/bot/cogs/utils/__init__.py b/bot/cogs/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/utils/bot.py b/bot/cogs/utils/bot.py new file mode 100644 index 000000000..71ed54f60 --- /dev/null +++ b/bot/cogs/utils/bot.py @@ -0,0 +1,385 @@ +import ast +import logging +import re +import time +from typing import Optional, Tuple + +from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord.ext.commands import Cog, Context, command, group + +from bot.bot import Bot +from bot.cogs.filters.token_remover import TokenRemover +from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs +from bot.decorators import with_role +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +RE_MARKDOWN = re.compile(r'([*_~`|>])') + + +class BotCog(Cog, name="Bot"): + """Bot information commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + # Stores allowed channels plus epoch time since last call. + self.channel_cooldowns = { + Channels.python_discussion: 0, + } + + # These channels will also work, but will not be subject to cooldown + self.channel_whitelist = ( + Channels.bot_commands, + ) + + # Stores improperly formatted Python codeblock message ids and the corresponding bot message + self.codeblock_message_ids = {} + + @group(invoke_without_command=True, name="bot", hidden=True) + @with_role(Roles.verified) + async def botinfo_group(self, ctx: Context) -> None: + """Bot informational commands.""" + await ctx.send_help(ctx.command) + + @botinfo_group.command(name='about', aliases=('info',), hidden=True) + @with_role(Roles.verified) + async def about_command(self, ctx: Context) -> None: + """Get information about the bot.""" + embed = Embed( + description="A utility bot designed just for the Python server! Try `!help` for more info.", + url="https://github.com/python-discord/bot" + ) + + embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=URLs.bot_avatar + ) + + await ctx.send(embed=embed) + + @command(name='echo', aliases=('print',)) + @with_role(*MODERATION_ROLES) + async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Repeat the given message in either a specified channel or the current channel.""" + if channel is None: + await ctx.send(text) + else: + await channel.send(text) + + @command(name='embed') + @with_role(*MODERATION_ROLES) + async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Send the input within an embed to either a specified channel or the current channel.""" + embed = Embed(description=text) + + if channel is None: + await ctx.send(embed=embed) + else: + await channel.send(embed=embed) + + def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: + """ + Strip msg in order to find Python code. + + Tries to strip out Python code out of msg and returns the stripped block or + None if the block is a valid Python codeblock. + """ + if msg.count("\n") >= 3: + # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. + if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: + log.trace( + "Someone wrote a message that was already a " + "valid Python syntax highlighted code block. No action taken." + ) + return None + + else: + # Stripping backticks from every line of the message. + log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") + content = "" + for line in msg.splitlines(keepends=True): + content += line.strip("`") + + content = content.strip() + + # Remove "Python" or "Py" from start of the message if it exists. + log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") + pycode = False + if content.lower().startswith("python"): + content = content[6:] + pycode = True + elif content.lower().startswith("py"): + content = content[2:] + pycode = True + + if pycode: + content = content.splitlines(keepends=True) + + # Check if there might be code in the first line, and preserve it. + first_line = content[0] + if " " in content[0]: + first_space = first_line.index(" ") + content[0] = first_line[first_space:] + content = "".join(content) + + # If there's no code we can just get rid of the first line. + else: + content = "".join(content[1:]) + + # Strip it again to remove any leading whitespace. This is neccessary + # if the first line of the message looked like ```python + old = content.strip() + + # Strips REPL code out of the message if there is any. + content, repl_code = self.repl_stripping(old) + if old != content: + return (content, old), repl_code + + # Try to apply indentation fixes to the code. + content = self.fix_indentation(content) + + # Check if the code contains backticks, if it does ignore the message. + if "`" in content: + log.trace("Detected ` inside the code, won't reply") + return None + else: + log.trace(f"Returning message.\n\n{content}\n\n") + return (content,), repl_code + + def fix_indentation(self, msg: str) -> str: + """Attempts to fix badly indented code.""" + def unindent(code: str, skip_spaces: int = 0) -> str: + """Unindents all code down to the number of spaces given in skip_spaces.""" + final = "" + current = code[0] + leading_spaces = 0 + + # Get numbers of spaces before code in the first line. + while current == " ": + current = code[leading_spaces + 1] + leading_spaces += 1 + leading_spaces -= skip_spaces + + # If there are any, remove that number of spaces from every line. + if leading_spaces > 0: + for line in code.splitlines(keepends=True): + line = line[leading_spaces:] + final += line + return final + else: + return code + + # Apply fix for "all lines are overindented" case. + msg = unindent(msg) + + # If the first line does not end with a colon, we can be + # certain the next line will be on the same indentation level. + # + # If it does end with a colon, we will need to indent all successive + # lines one additional level. + first_line = msg.splitlines()[0] + code = "".join(msg.splitlines(keepends=True)[1:]) + if not first_line.endswith(":"): + msg = f"{first_line}\n{unindent(code)}" + else: + msg = f"{first_line}\n{unindent(code, 4)}" + return msg + + def repl_stripping(self, msg: str) -> Tuple[str, bool]: + """ + Strip msg in order to extract Python code out of REPL output. + + Tries to strip out REPL Python code out of msg and returns the stripped msg. + + Returns True for the boolean if REPL code was found in the input msg. + """ + final = "" + for line in msg.splitlines(keepends=True): + if line.startswith(">>>") or line.startswith("..."): + final += line[4:] + log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") + if not final: + log.trace(f"Found no REPL code in \n\n{msg}\n\n") + return msg, False + else: + log.trace(f"Found REPL code in \n\n{msg}\n\n") + return final.rstrip(), True + + def has_bad_ticks(self, msg: Message) -> bool: + """Check to see if msg contains ticks that aren't '`'.""" + not_backticks = [ + "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", + "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", + "\u3003\u3003\u3003" + ] + + return msg.content[:3] in not_backticks + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Detect poorly formatted Python code in new messages. + + If poorly formatted code is detected, send the user a helpful message explaining how to do + properly formatted Python syntax highlighting codeblocks. + """ + is_help_channel = ( + getattr(msg.channel, "category", None) + and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) + ) + parse_codeblock = ( + ( + is_help_channel + or msg.channel.id in self.channel_cooldowns + or msg.channel.id in self.channel_whitelist + ) + and not msg.author.bot + and len(msg.content.splitlines()) > 3 + and not TokenRemover.find_token_in_message(msg) + ) + + if parse_codeblock: # no token in the msg + on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 + if not on_cooldown or DEBUG_MODE: + try: + if self.has_bad_ticks(msg): + ticks = msg.content[:3] + content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) + if content is None: + return + + content, repl_code = content + + if len(content) == 2: + content = content[1] + else: + content = content[0] + + space_left = 204 + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto = ( + "It looks like you are trying to paste code into this channel.\n\n" + "You seem to be using the wrong symbols to indicate where the codeblock should start. " + f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" + "**Here is an example of how it should look:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + else: + howto = "" + content = self.codeblock_stripping(msg.content, False) + if content is None: + return + + content, repl_code = content + # Attempts to parse the message into an AST node. + # Invalid Python code will raise a SyntaxError. + tree = ast.parse(content[0]) + + # Multiple lines of single words could be interpreted as expressions. + # This check is to avoid all nodes being parsed as expressions. + # (e.g. words over multiple lines) + if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: + # Shorten the code to 10 lines and/or 204 characters. + space_left = 204 + if content and repl_code: + content = content[1] + else: + content = content[0] + + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto += ( + "It looks like you're trying to paste code into this channel.\n\n" + "Discord has support for Markdown, which allows you to post code with full " + "syntax highlighting. Please use these whenever you paste code, as this " + "helps improve the legibility and makes it easier for us to help you.\n\n" + f"**To do this, use the following method:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + log.debug(f"{msg.author} posted something that needed to be put inside python code " + "blocks. Sending the user some instructions.") + else: + log.trace("The code consists only of expressions, not sending instructions") + + if howto != "": + # Increase amount of codeblock correction in stats + self.bot.stats.incr("codeblock_corrections") + howto_embed = Embed(description=howto) + bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) + self.codeblock_message_ids[msg.id] = bot_message.id + + self.bot.loop.create_task( + wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) + ) + else: + return + + if msg.channel.id not in self.channel_whitelist: + self.channel_cooldowns[msg.channel.id] = time.time() + + except SyntaxError: + log.trace( + f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " + "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " + f"The message that was posted was:\n\n{msg.content}\n\n" + ) + + @Cog.listener() + async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: + """Check to see if an edited message (previously called out) still contains poorly formatted code.""" + if ( + # Checks to see if the message was called out by the bot + payload.message_id not in self.codeblock_message_ids + # Makes sure that there is content in the message + or payload.data.get("content") is None + # Makes sure there's a channel id in the message payload + or payload.data.get("channel_id") is None + ): + return + + # Retrieve channel and message objects for use later + channel = self.bot.get_channel(int(payload.data.get("channel_id"))) + user_message = await channel.fetch_message(payload.message_id) + + # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None + has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) + + # If the message is fixed, delete the bot message and the entry from the id dictionary + if has_fixed_codeblock is None: + bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) + await bot_message.delete() + del self.codeblock_message_ids[payload.message_id] + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + + +def setup(bot: Bot) -> None: + """Load the Bot cog.""" + bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/utils/clean.py b/bot/cogs/utils/clean.py new file mode 100644 index 000000000..f436e531a --- /dev/null +++ b/bot/cogs/utils/clean.py @@ -0,0 +1,272 @@ +import logging +import random +import re +from typing import Iterable, Optional + +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext import commands +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import ( + Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES +) +from bot.decorators import with_role + +log = logging.getLogger(__name__) + + +class Clean(Cog): + """ + A cog that allows messages to be deleted in bulk, while applying various filters. + + You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a + specific regular expression. + + The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be + used to view the messages in the Discord dark theme style. + """ + + def __init__(self, bot: Bot): + self.bot = bot + self.cleaning = False + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def _clean_messages( + self, + amount: int, + ctx: Context, + channels: Iterable[TextChannel], + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + until_message: Optional[Message] = None, + ) -> None: + """A helper function that does the actual message cleaning.""" + def predicate_bots_only(message: Message) -> bool: + """Return True if the message was sent by a bot.""" + return message.author.bot + + def predicate_specific_user(message: Message) -> bool: + """Return True if the message was sent by the user provided in the _clean_messages call.""" + return message.author == user + + def predicate_regex(message: Message) -> bool: + """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" + content = [message.content] + + # Add the content for all embed attributes + for embed in message.embeds: + content.append(embed.title) + content.append(embed.description) + content.append(embed.footer.text) + content.append(embed.author.name) + for field in embed.fields: + content.append(field.name) + content.append(field.value) + + # Get rid of empty attributes and turn it into a string + content = [attr for attr in content if attr] + content = "\n".join(content) + + # Now let's see if there's a regex match + if not content: + return False + else: + return bool(re.search(regex.lower(), content.lower())) + + # Is this an acceptable amount of messages to clean? + if amount > CleanMessages.message_limit: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description=f"You cannot clean more than {CleanMessages.message_limit} messages." + ) + await ctx.send(embed=embed) + return + + # Are we already performing a clean? + if self.cleaning: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description="Please wait for the currently ongoing clean operation to complete." + ) + await ctx.send(embed=embed) + return + + # Set up the correct predicate + if bots_only: + predicate = predicate_bots_only # Delete messages from bots + elif user: + predicate = predicate_specific_user # Delete messages from specific user + elif regex: + predicate = predicate_regex # Delete messages that match regex + else: + predicate = None # Delete all messages + + # Default to using the invoking context's channel + if not channels: + channels = [ctx.channel] + + # Delete the invocation first + self.mod_log.ignore(Event.message_delete, ctx.message.id) + await ctx.message.delete() + + messages = [] + message_ids = [] + self.cleaning = True + + # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. + for channel in channels: + async for message in channel.history(limit=amount): + + # If at any point the cancel command is invoked, we should stop. + if not self.cleaning: + return + + # If we are looking for specific message. + if until_message: + + # we could use ID's here however in case if the message we are looking for gets deleted, + # we won't have a way to figure that out thus checking for datetime should be more reliable + if message.created_at < until_message.created_at: + # means we have found the message until which we were supposed to be deleting. + break + + # Since we will be using `delete_messages` method of a TextChannel and we need message objects to + # use it as well as to send logs we will start appending messages here instead adding them from + # purge. + messages.append(message) + + # If the message passes predicate, let's save it. + if predicate is None or predicate(message): + message_ids.append(message.id) + + self.cleaning = False + + # Now let's delete the actual messages with purge. + self.mod_log.ignore(Event.message_delete, *message_ids) + for channel in channels: + if until_message: + for i in range(0, len(messages), 100): + # while purge automatically handles the amount of messages + # delete_messages only allows for up to 100 messages at once + # thus we need to paginate the amount to always be <= 100 + await channel.delete_messages(messages[i:i + 100]) + else: + messages += await channel.purge(limit=amount, check=predicate) + + # Reverse the list to restore chronological order + if messages: + messages = reversed(messages) + log_url = await self.mod_log.upload_log(messages, ctx.author.id) + else: + # Can't build an embed, nothing to clean! + embed = Embed( + color=Colour(Colours.soft_red), + description="No matching messages could be found." + ) + await ctx.send(embed=embed, delete_after=10) + return + + # Build the embed and send it + target_channels = ", ".join(channel.mention for channel in channels) + + message = ( + f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" + f"A log of the deleted messages can be found [here]({log_url})." + ) + + await self.mod_log.send_log_message( + icon_url=Icons.message_bulk_delete, + colour=Colour(Colours.soft_red), + title="Bulk message delete", + text=message, + channel_id=Channels.mod_log, + ) + + @group(invoke_without_command=True, name="clean", aliases=["purge"]) + @with_role(*MODERATION_ROLES) + async def clean_group(self, ctx: Context) -> None: + """Commands for cleaning messages in channels.""" + await ctx.send_help(ctx.command) + + @clean_group.command(name="user", aliases=["users"]) + @with_role(*MODERATION_ROLES) + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, user=user, channels=channels) + + @clean_group.command(name="all", aliases=["everything"]) + @with_role(*MODERATION_ROLES) + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, channels=channels) + + @clean_group.command(name="bots", aliases=["bot"]) + @with_role(*MODERATION_ROLES) + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, bots_only=True, channels=channels) + + @clean_group.command(name="regex", aliases=["word", "expression"]) + @with_role(*MODERATION_ROLES) + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, regex=regex, channels=channels) + + @clean_group.command(name="message", aliases=["messages"]) + @with_role(*MODERATION_ROLES) + async def clean_message(self, ctx: Context, message: Message) -> None: + """Delete all messages until certain message, stop cleaning after hitting the `message`.""" + await self._clean_messages( + CleanMessages.message_limit, + ctx, + channels=[message.channel], + until_message=message + ) + + @clean_group.command(name="stop", aliases=["cancel", "abort"]) + @with_role(*MODERATION_ROLES) + async def clean_cancel(self, ctx: Context) -> None: + """If there is an ongoing cleaning process, attempt to immediately cancel it.""" + self.cleaning = False + + embed = Embed( + color=Colour.blurple(), + description="Clean interrupted." + ) + await ctx.send(embed=embed, delete_after=10) + + +def setup(bot: Bot) -> None: + """Load the Clean cog.""" + bot.add_cog(Clean(bot)) diff --git a/bot/cogs/utils/eval.py b/bot/cogs/utils/eval.py new file mode 100644 index 000000000..eb8bfb1cf --- /dev/null +++ b/bot/cogs/utils/eval.py @@ -0,0 +1,202 @@ +import contextlib +import inspect +import logging +import pprint +import re +import textwrap +import traceback +from io import StringIO +from typing import Any, Optional, Tuple + +import discord +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role +from bot.interpreter import Interpreter + +log = logging.getLogger(__name__) + + +class CodeEval(Cog): + """Owner and admin feature that evaluates code and returns the result to the channel.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.env = {} + self.ln = 0 + self.stdout = StringIO() + + self.interpreter = Interpreter(bot) + + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: + """Format the eval output into a string & attempt to format it into an Embed.""" + self._ = out + + res = "" + + # Erase temp input we made + if inp.startswith("_ = "): + inp = inp[4:] + + # Get all non-empty lines + lines = [line for line in inp.split("\n") if line.strip()] + if len(lines) != 1: + lines += [""] + + # Create the input dialog + for i, line in enumerate(lines): + if i == 0: + # Start dialog + start = f"In [{self.ln}]: " + + else: + # Indent the 3 dots correctly; + # Normally, it's something like + # In [X]: + # ...: + # + # But if it's + # In [XX]: + # ...: + # + # You can see it doesn't look right. + # This code simply indents the dots + # far enough to align them. + # we first `str()` the line number + # then we get the length + # and use `str.rjust()` + # to indent it. + start = "...: ".rjust(len(str(self.ln)) + 7) + + if i == len(lines) - 2: + if line.startswith("return"): + line = line[6:].strip() + + # Combine everything + res += (start + line + "\n") + + self.stdout.seek(0) + text = self.stdout.read() + self.stdout.close() + self.stdout = StringIO() + + if text: + res += (text + "\n") + + if out is None: + # No output, return the input statement + return (res, None) + + res += f"Out[{self.ln}]: " + + if isinstance(out, discord.Embed): + # We made an embed? Send that as embed + res += "" + res = (res, out) + + else: + if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): + # Leave out the traceback message + out = "\n" + "\n".join(out.split("\n")[1:]) + + if isinstance(out, str): + pretty = out + else: + pretty = pprint.pformat(out, compact=True, width=60) + + if pretty != str(out): + # We're using the pretty version, start on the next line + res += "\n" + + if pretty.count("\n") > 20: + # Text too long, shorten + li = pretty.split("\n") + + pretty = ("\n".join(li[:3]) # First 3 lines + + "\n ...\n" # Ellipsis to indicate removed lines + + "\n".join(li[-3:])) # last 3 lines + + # Add the output + res += pretty + res = (res, None) + + return res # Return (text, embed) + + async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: + """Eval the input code string & send an embed to the invoking context.""" + self.ln += 1 + + if code.startswith("exit"): + self.ln = 0 + self.env = {} + return await ctx.send("```Reset history!```") + + env = { + "message": ctx.message, + "author": ctx.message.author, + "channel": ctx.channel, + "guild": ctx.guild, + "ctx": ctx, + "self": self, + "bot": self.bot, + "inspect": inspect, + "discord": discord, + "contextlib": contextlib + } + + self.env.update(env) + + # Ignore this code, it works + code_ = """ +async def func(): # (None,) -> Any + try: + with contextlib.redirect_stdout(self.stdout): +{0} + if '_' in locals(): + if inspect.isawaitable(_): + _ = await _ + return _ + finally: + self.env.update(locals()) +""".format(textwrap.indent(code, ' ')) + + try: + exec(code_, self.env) # noqa: B102,S102 + func = self.env['func'] + res = await func() + + except Exception: + res = traceback.format_exc() + + out, embed = self._format(code, res) + await ctx.send(f"```py\n{out}```", embed=embed) + + @group(name='internal', aliases=('int',)) + @with_role(Roles.owners, Roles.admins) + async def internal_group(self, ctx: Context) -> None: + """Internal commands. Top secret!""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @internal_group.command(name='eval', aliases=('e',)) + @with_role(Roles.admins, Roles.owners) + async def eval(self, ctx: Context, *, code: str) -> None: + """Run eval in a REPL-like format.""" + code = code.strip("`") + if re.match('py(thon)?\n', code): + code = "\n".join(code.split("\n")[1:]) + + if not re.search( # Check if it's an expression + r"^(return|import|for|while|def|class|" + r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( + code.split("\n")) == 1: + code = "_ = " + code + + await self._eval(ctx, code) + + +def setup(bot: Bot) -> None: + """Load the CodeEval cog.""" + bot.add_cog(CodeEval(bot)) diff --git a/bot/cogs/utils/extensions.py b/bot/cogs/utils/extensions.py new file mode 100644 index 000000000..365f198ff --- /dev/null +++ b/bot/cogs/utils/extensions.py @@ -0,0 +1,236 @@ +import functools +import logging +import typing as t +from enum import Enum +from pkgutil import iter_modules + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group + +from bot.bot import Bot +from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + +UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} +EXTENSIONS = frozenset( + ext.name + for ext in iter_modules(("bot/cogs",), "bot.cogs.") + if ext.name[-1] != "_" +) + + +class Action(Enum): + """Represents an action to perform on an extension.""" + + # Need to be partial otherwise they are considered to be function definitions. + LOAD = functools.partial(Bot.load_extension) + UNLOAD = functools.partial(Bot.unload_extension) + RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if "." not in argument: + argument = f"bot.cogs.{argument}" + + if argument in EXTENSIONS: + return argument + else: + raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): + """Extension management commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) + async def extensions_group(self, ctx: Context) -> None: + """Load, unload, reload, and list loaded extensions.""" + await ctx.send_help(ctx.command) + + @extensions_group.command(name="load", aliases=("l",)) + async def load_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Load extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "*" in extensions or "**" in extensions: + extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + + msg = self.batch_manage(Action.LOAD, *extensions) + await ctx.send(msg) + + @extensions_group.command(name="unload", aliases=("ul",)) + async def unload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Unload currently loaded extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + + if blacklisted: + msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" + else: + if "*" in extensions or "**" in extensions: + extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + + msg = self.batch_manage(Action.UNLOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="reload", aliases=("r",)) + async def reload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Reload extensions given their fully qualified or unqualified names. + + If an extension fails to be reloaded, it will be rolled-back to the prior working state. + + If '\*' is given as the name, all currently loaded extensions will be reloaded. + If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "**" in extensions: + extensions = EXTENSIONS + elif "*" in extensions: + extensions = set(self.bot.extensions.keys()) | set(extensions) + extensions.remove("*") + + msg = self.batch_manage(Action.RELOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="list", aliases=("all",)) + async def list_command(self, ctx: Context) -> None: + """ + Get a list of all extensions, including their loaded status. + + Grey indicates that the extension is unloaded. + Green indicates that the extension is currently loaded. + """ + embed = Embed() + lines = [] + + embed.colour = Colour.blurple() + embed.set_author( + name="Extensions List", + url=URLs.github_bot_repo, + icon_url=URLs.bot_avatar + ) + + for ext in sorted(list(EXTENSIONS)): + if ext in self.bot.extensions: + status = Emojis.status_online + else: + status = Emojis.status_offline + + ext = ext.rsplit(".", 1)[1] + lines.append(f"{status} {ext}") + + log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") + await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) + + def batch_manage(self, action: Action, *extensions: str) -> str: + """ + Apply an action to multiple extensions and return a message with the results. + + If only one extension is given, it is deferred to `manage()`. + """ + if len(extensions) == 1: + msg, _ = self.manage(action, extensions[0]) + return msg + + verb = action.name.lower() + failures = {} + + for extension in extensions: + _, error = self.manage(action, extension) + if error: + failures[extension] = error + + emoji = ":x:" if failures else ":ok_hand:" + msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + + if failures: + failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) + msg += f"\nFailures:```{failures}```" + + log.debug(f"Batch {verb}ed extensions.") + + return msg + + def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: + """Apply an action to an extension and return the status message and any error message.""" + verb = action.name.lower() + error_msg = None + + try: + action.value(self.bot, ext) + except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): + if action is Action.RELOAD: + # When reloading, just load the extension if it was not loaded. + return self.manage(Action.LOAD, ext) + + msg = f":x: Extension `{ext}` is already {verb}ed." + log.debug(msg[4:]) + except Exception as e: + if hasattr(e, "original"): + e = e.original + + log.exception(f"Extension '{ext}' failed to {verb}.") + + error_msg = f"{e.__class__.__name__}: {e}" + msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" + else: + msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." + log.debug(msg[10:]) + + return msg, error_msg + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators and core developers to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Handle BadArgument errors locally to prevent the help command from showing.""" + if isinstance(error, commands.BadArgument): + await ctx.send(str(error)) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Extensions cog.""" + bot.add_cog(Extensions(bot)) diff --git a/bot/cogs/utils/jams.py b/bot/cogs/utils/jams.py new file mode 100644 index 000000000..b3102db2f --- /dev/null +++ b/bot/cogs/utils/jams.py @@ -0,0 +1,150 @@ +import logging +import typing as t + +from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role +from discord.ext import commands +from more_itertools import unique_everseen + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role + +log = logging.getLogger(__name__) + +MAX_CHANNELS = 50 +CATEGORY_NAME = "Code Jam" + + +class CodeJams(commands.Cog): + """Manages the code-jam related parts of our server.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command() + @with_role(Roles.admins) + async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: + """ + Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. + + The first user passed will always be the team leader. + """ + # Ignore duplicate members + members = list(unique_everseen(members)) + + # We had a little issue during Code Jam 4 here, the greedy converter did it's job + # and ignored anything which wasn't a valid argument which left us with teams of + # two members or at some times even 1 member. This fixes that by checking that there + # are always 3 members in the members list. + if len(members) < 3: + await ctx.send( + ":no_entry_sign: One of your arguments was invalid\n" + f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" + " members" + ) + return + + team_channel = await self.create_channels(ctx.guild, team_name, members) + await self.add_roles(ctx.guild, members) + + await ctx.send( + f":ok_hand: Team created: {team_channel}\n" + f"**Team Leader:** {members[0].mention}\n" + f"**Team Members:** {' '.join(member.mention for member in members[1:])}" + ) + + async def get_category(self, guild: Guild) -> CategoryChannel: + """ + Return a code jam category. + + If all categories are full or none exist, create a new category. + """ + for category in guild.categories: + # Need 2 available spaces: one for the text channel and one for voice. + if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: + return category + + return await self.create_category(guild) + + @staticmethod + async def create_category(guild: Guild) -> CategoryChannel: + """Create a new code jam category and return it.""" + log.info("Creating a new code jam category.") + + category_overwrites = { + guild.default_role: PermissionOverwrite(read_messages=False), + guild.me: PermissionOverwrite(read_messages=True) + } + + return await guild.create_category_channel( + CATEGORY_NAME, + overwrites=category_overwrites, + reason="It's code jam time!" + ) + + @staticmethod + def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: + """Get code jam team channels permission overwrites.""" + # First member is always the team leader + team_channel_overwrites = { + members[0]: PermissionOverwrite( + manage_messages=True, + read_messages=True, + manage_webhooks=True, + connect=True + ), + guild.default_role: PermissionOverwrite(read_messages=False, connect=False), + guild.get_role(Roles.verified): PermissionOverwrite( + read_messages=False, + connect=False + ) + } + + # Rest of members should just have read_messages + for member in members[1:]: + team_channel_overwrites[member] = PermissionOverwrite( + read_messages=True, + connect=True + ) + + return team_channel_overwrites + + async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: + """Create team text and voice channels. Return the mention for the text channel.""" + # Get permission overwrites and category + team_channel_overwrites = self.get_overwrites(members, guild) + code_jam_category = await self.get_category(guild) + + # Create a text channel for the team + team_channel = await guild.create_text_channel( + team_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + # Create a voice channel for the team + team_voice_name = " ".join(team_name.split("-")).title() + + await guild.create_voice_channel( + team_voice_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + return team_channel.mention + + @staticmethod + async def add_roles(guild: Guild, members: t.List[Member]) -> None: + """Assign team leader and jammer roles.""" + # Assign team leader role + await members[0].add_roles(guild.get_role(Roles.team_leaders)) + + # Assign rest of roles + jammer_role = guild.get_role(Roles.jammers) + for member in members: + await member.add_roles(jammer_role) + + +def setup(bot: Bot) -> None: + """Load the CodeJams cog.""" + bot.add_cog(CodeJams(bot)) diff --git a/bot/cogs/utils/reminders.py b/bot/cogs/utils/reminders.py new file mode 100644 index 000000000..670493bcf --- /dev/null +++ b/bot/cogs/utils/reminders.py @@ -0,0 +1,427 @@ +import asyncio +import logging +import random +import textwrap +import typing as t +from datetime import datetime, timedelta +from operator import itemgetter + +import discord +from dateutil.parser import isoparse +from dateutil.relativedelta import relativedelta +from discord.ext.commands import Cog, Context, Greedy, group + +from bot.bot import Bot +from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES +from bot.converters import Duration +from bot.pagination import LinePaginator +from bot.utils.checks import without_role_check +from bot.utils.messages import send_denial +from bot.utils.scheduling import Scheduler +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +WHITELISTED_CHANNELS = Guild.reminder_whitelist +MAXIMUM_REMINDERS = 5 + +Mentionable = t.Union[discord.Member, discord.Role] + + +class Reminders(Cog): + """Provide in-channel reminder functionality.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_reminders()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + async def reschedule_reminders(self) -> None: + """Get all current reminders from the API and reschedule them.""" + await self.bot.wait_until_guild_available() + response = await self.bot.api_client.get( + 'bot/reminders', + params={'active': 'true'} + ) + + now = datetime.utcnow() + + for reminder in response: + is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + if not is_valid: + continue + + remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) + + # If the reminder is already overdue ... + if remind_at < now: + late = relativedelta(now, remind_at) + await self.send_reminder(reminder, late) + else: + self.schedule_reminder(reminder) + + def ensure_valid_reminder( + self, + reminder: dict, + cancel_task: bool = True + ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" + user = self.bot.get_user(reminder['author']) + channel = self.bot.get_channel(reminder['channel_id']) + is_valid = True + if not user or not channel: + is_valid = False + log.info( + f"Reminder {reminder['id']} invalid: " + f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." + ) + asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + + return is_valid, user, channel + + @staticmethod + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: str, + delivery_dt: t.Optional[datetime], + ) -> None: + """Send an embed confirming the reminder change was made successfully.""" + embed = discord.Embed() + embed.colour = discord.Colour.green() + embed.title = random.choice(POSITIVE_REPLIES) + embed.description = on_success + + footer_str = f"ID: {reminder_id}" + if delivery_dt: + # Reminder deletion will have a `None` `delivery_dt` + footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + + embed.set_footer(text=footer_str) + + await ctx.send(embed=embed) + + @staticmethod + async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: + """ + Returns whether or not the list of mentions is allowed. + + Conditions: + - Role reminders are Mods+ + - Reminders for other users are Helpers+ + + If mentions aren't allowed, also return the type of mention(s) disallowed. + """ + if without_role_check(ctx, *STAFF_ROLES): + return False, "members/roles" + elif without_role_check(ctx, *MODERATION_ROLES): + return all(isinstance(mention, discord.Member) for mention in mentions), "roles" + else: + return True, "" + + @staticmethod + async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: + """ + Filter mentions to see if the user can mention, and sends a denial if not allowed. + + Returns whether or not the validation is successful. + """ + mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) + + if not mentions or mentions_allowed: + return True + else: + await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") + return False + + def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: + """Converts Role and Member ids to their corresponding objects if possible.""" + guild = self.bot.get_guild(Guild.id) + for mention_id in mention_ids: + if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): + yield mentionable + + def schedule_reminder(self, reminder: dict) -> None: + """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" + reminder_id = reminder["id"] + reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) + + async def _remind() -> None: + await self.send_reminder(reminder) + + log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") + await self._delete_reminder(reminder_id) + + self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) + + async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: + """Delete a reminder from the database, given its ID, and cancel the running task.""" + await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) + + if cancel_task: + # Now we can remove it from the schedule list + self.scheduler.cancel(reminder_id) + + async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: + """ + Edits a reminder in the database given the ID and payload. + + Returns the edited reminder. + """ + # Send the request to update the reminder in the database + reminder = await self.bot.api_client.patch( + 'bot/reminders/' + str(reminder_id), + json=payload + ) + return reminder + + async def _reschedule_reminder(self, reminder: dict) -> None: + """Reschedule a reminder object.""" + log.trace(f"Cancelling old task #{reminder['id']}") + self.scheduler.cancel(reminder["id"]) + + log.trace(f"Scheduling new task #{reminder['id']}") + self.schedule_reminder(reminder) + + async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: + """Send the reminder.""" + is_valid, user, channel = self.ensure_valid_reminder(reminder) + if not is_valid: + return + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.set_author( + icon_url=Icons.remind_blurple, + name="It has arrived!" + ) + + embed.description = f"Here's your reminder: `{reminder['content']}`." + + if reminder.get("jump_url"): # keep backward compatibility + embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" + + if late: + embed.colour = discord.Colour.red() + embed.set_author( + icon_url=Icons.remind_red, + name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" + ) + + additional_mentions = ' '.join( + mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) + ) + + await channel.send( + content=f"{user.mention} {additional_mentions}", + embed=embed + ) + await self._delete_reminder(reminder["id"]) + + @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) + async def remind_group( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """Commands for managing your reminders.""" + await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) + + @remind_group.command(name="new", aliases=("add", "create")) + async def new_reminder( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """ + Set yourself a simple reminder. + + Expiration is parsed per: http://strftime.org/ + """ + # If the user is not staff, we need to verify whether or not to make a reminder at all. + if without_role_check(ctx, *STAFF_ROLES): + + # If they don't have permission to set a reminder in this channel + if ctx.channel.id not in WHITELISTED_CHANNELS: + await send_denial(ctx, "Sorry, you can't do that here!") + return + + # Get their current active reminders + active_reminders = await self.bot.api_client.get( + 'bot/reminders', + params={ + 'author__id': str(ctx.author.id) + } + ) + + # Let's limit this, so we don't get 10 000 + # reminders from kip or something like that :P + if len(active_reminders) > MAXIMUM_REMINDERS: + await send_denial(ctx, "You have too many active reminders!") + return + + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + + # Now we can attempt to actually set the reminder. + reminder = await self.bot.api_client.post( + 'bot/reminders', + json={ + 'author': ctx.author.id, + 'channel_id': ctx.message.channel.id, + 'jump_url': ctx.message.jump_url, + 'content': content, + 'expiration': expiration.isoformat(), + 'mentions': mention_ids, + } + ) + + now = datetime.utcnow() - timedelta(seconds=1) + humanized_delta = humanize_delta(relativedelta(expiration, now)) + mention_string = ( + f"Your reminder will arrive in {humanized_delta} " + f"and will mention {len(mentions)} other(s)!" + ) + + # Confirm to the user that it worked. + await self._send_confirmation( + ctx, + on_success=mention_string, + reminder_id=reminder["id"], + delivery_dt=expiration, + ) + + self.schedule_reminder(reminder) + + @remind_group.command(name="list") + async def list_reminders(self, ctx: Context) -> None: + """View a paginated embed of all reminders for your user.""" + # Get all the user's reminders from the database. + data = await self.bot.api_client.get( + 'bot/reminders', + params={'author__id': str(ctx.author.id)} + ) + + now = datetime.utcnow() + + # Make a list of tuples so it can be sorted by time. + reminders = sorted( + ( + (rem['content'], rem['expiration'], rem['id'], rem['mentions']) + for rem in data + ), + key=itemgetter(1) + ) + + lines = [] + + for content, remind_at, id_, mentions in reminders: + # Parse and humanize the time, make it pretty :D + remind_datetime = isoparse(remind_at).replace(tzinfo=None) + time = humanize_delta(relativedelta(remind_datetime, now)) + + mentions = ", ".join( + # Both Role and User objects have the `name` attribute + mention.name for mention in self.get_mentionables(mentions) + ) + mention_string = f"\n**Mentions:** {mentions}" if mentions else "" + + text = textwrap.dedent(f""" + **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} + {content} + """).strip() + + lines.append(text) + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.title = f"Reminders for {ctx.author}" + + # Remind the user that they have no reminders :^) + if not lines: + embed.description = "No active reminders could be found." + await ctx.send(embed=embed) + return + + # Construct the embed and paginate it. + embed.colour = discord.Colour.blurple() + + await LinePaginator.paginate( + lines, + ctx, embed, + max_lines=3, + empty=True + ) + + @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) + async def edit_reminder_group(self, ctx: Context) -> None: + """Commands for modifying your current reminders.""" + await ctx.send_help(ctx.command) + + @edit_reminder_group.command(name="duration", aliases=("time",)) + async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: + """ + Edit one of your reminder's expiration. + + Expiration is parsed per: http://strftime.org/ + """ + await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) + + @edit_reminder_group.command(name="content", aliases=("reason",)) + async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: + """Edit one of your reminder's content.""" + await self.edit_reminder(ctx, id_, {"content": content}) + + @edit_reminder_group.command(name="mentions", aliases=("pings",)) + async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: + """Edit one of your reminder's mentions.""" + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) + + async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: + """Edits a reminder with the given payload, then sends a confirmation message.""" + reminder = await self._edit_reminder(id_, payload) + + # Parse the reminder expiration back into a datetime + expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) + + # Send a confirmation message to the channel + await self._send_confirmation( + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, + ) + await self._reschedule_reminder(reminder) + + @remind_group.command("delete", aliases=("remove", "cancel")) + async def delete_reminder(self, ctx: Context, id_: int) -> None: + """Delete one of your active reminders.""" + await self._delete_reminder(id_) + await self._send_confirmation( + ctx, + on_success="That reminder has been deleted successfully!", + reminder_id=id_, + delivery_dt=None, + ) + + +def setup(bot: Bot) -> None: + """Load the Reminders cog.""" + bot.add_cog(Reminders(bot)) diff --git a/bot/cogs/utils/snekbox.py b/bot/cogs/utils/snekbox.py new file mode 100644 index 000000000..52c8b6f88 --- /dev/null +++ b/bot/cogs/utils/snekbox.py @@ -0,0 +1,349 @@ +import asyncio +import contextlib +import datetime +import logging +import re +import textwrap +from functools import partial +from signal import Signals +from typing import Optional, Tuple + +from discord import HTTPException, Message, NotFound, Reaction, User +from discord.ext.commands import Cog, Context, command, guild_only + +from bot.bot import Bot +from bot.constants import Categories, Channels, Roles, URLs +from bot.decorators import in_whitelist +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") +FORMATTED_CODE_REGEX = re.compile( + r"^\s*" # any leading whitespace from the beginning of the string + r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)" # match the exact same delimiter from the start again + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive +) +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL # "." also matches newlines +) + +MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) +EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) +EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) + +SIGKILL = 9 + +REEVAL_EMOJI = '\U0001f501' # :repeat: +REEVAL_TIMEOUT = 30 + + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_eval(self, code: str) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the eval output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + + url = URLs.paste_service.format(key="documents") + try: + async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: + data = await resp.json() + + if "key" in data: + return URLs.paste_service.format(key=data["key"]) + except Exception: + # 400 (Bad Request) means there are too many characters + log.exception("Failed to upload full output to paste service!") + + @staticmethod + def prepare_input(code: str) -> str: + """Extract code from the Markdown, format it, and insert it into the code template.""" + match = FORMATTED_CODE_REGEX.fullmatch(code) + if match: + code, block, lang, delim = match.group("code", "block", "lang", "delim") + code = textwrap.dedent(code) + if block: + info = (f"'{lang}' highlighted" if lang else "plain") + " code block" + else: + info = f"{delim}-enclosed inline code" + log.trace(f"Extracted {info} for evaluation:\n{code}") + else: + code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) + log.trace( + f"Eval message contains unformatted or badly formatted code, " + f"stripping whitespace only:\n{code}" + ) + + return code + + @staticmethod + def get_results_message(results: dict) -> Tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + stdout, returncode = results["stdout"], results["returncode"] + msg = f"Your eval job has completed with return code {returncode}" + error = "" + + if returncode is None: + msg = "Your eval job has failed" + error = stdout.strip() + elif returncode == 128 + SIGKILL: + msg = "Your eval job timed out or ran out of memory" + elif returncode == 255: + msg = "Your eval job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + try: + name = Signals(returncode - 128).name + msg = f"{msg} ({name})" + except ValueError: + pass + + return msg, error + + @staticmethod + def get_status_emoji(results: dict) -> str: + """Return an emoji corresponding to the status code or lack of output in result.""" + if not results["stdout"].strip(): # No output + return ":warning:" + elif results["returncode"] == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + + async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + """ + Format the output and return a tuple of the formatted output and a URL to the full output. + + Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters + and upload the full output to a paste service. + """ + log.trace("Formatting output...") + + output = output.rstrip("\n") + original_output = output # To be uploaded to a pasting service if needed + paste_link = None + + if "<@" in output: + output = output.replace("<@", "<@\u200B") # Zero-width space + + if " 0: + output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] + output = output[:11] # Limiting to only 11 lines + output = "\n".join(output) + + if lines > 10: + truncated = True + if len(output) >= 1000: + output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + else: + output = f"{output}\n... (truncated - too many lines)" + elif len(output) >= 1000: + truncated = True + output = f"{output[:1000]}\n... (truncated - too long)" + + if truncated: + paste_link = await self.upload_output(original_output) + + output = output or "[No output]" + + return output, paste_link + + async def send_eval(self, ctx: Context, code: str) -> Message: + """ + Evaluate code, format it, and send the output to the corresponding channel. + + Return the bot response. + """ + async with ctx.typing(): + results = await self.post_eval(code) + msg, error = self.get_results_message(results) + + if error: + output, paste_link = error, None + else: + output, paste_link = await self.format_output(results["stdout"]) + + icon = self.get_status_emoji(results) + msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + # Collect stats of eval fails + successes + if icon == ":x:": + self.bot.stats.incr("snekbox.python.fail") + else: + self.bot.stats.incr("snekbox.python.success") + + filter_cog = self.bot.get_cog("Filtering") + filter_triggered = False + if filter_cog: + filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + if filter_triggered: + response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + else: + response = await ctx.send(msg) + self.bot.loop.create_task( + wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) + ) + + log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + return response + + async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + """ + Check if the eval session should continue. + + Return the new code to evaluate or None if the eval session should be terminated. + """ + _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + + with contextlib.suppress(NotFound): + try: + _, new_message = await self.bot.wait_for( + 'message_edit', + check=_predicate_eval_message_edit, + timeout=REEVAL_TIMEOUT + ) + await ctx.message.add_reaction(REEVAL_EMOJI) + await self.bot.wait_for( + 'reaction_add', + check=_predicate_emoji_reaction, + timeout=10 + ) + + code = await self.get_code(new_message) + await ctx.message.clear_reactions() + with contextlib.suppress(HTTPException): + await response.delete() + + except asyncio.TimeoutError: + await ctx.message.clear_reactions() + return None + + return code + + async def get_code(self, message: Message) -> Optional[str]: + """ + Return the code from `message` to be evaluated. + + If the message is an invocation of the eval command, return the first argument or None if it + doesn't exist. Otherwise, return the full content of the message. + """ + log.trace(f"Getting context for message {message.id}.") + new_ctx = await self.bot.get_context(message) + + if new_ctx.command is self.eval_command: + log.trace(f"Message {message.id} invokes eval command.") + split = message.content.split(maxsplit=1) + code = split[1] if len(split) > 1 else None + else: + log.trace(f"Message {message.id} does not invoke eval command.") + code = message.content + + return code + + @command(name="eval", aliases=("e",)) + @guild_only() + @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) + async def eval_command(self, ctx: Context, *, code: str = None) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + if ctx.author.id in self.jobs: + await ctx.send( + f"{ctx.author.mention} You've already got a job running - " + "please wait for it to finish!" + ) + return + + if not code: # None or empty string + await ctx.send_help(ctx.command) + return + + if Roles.helpers in (role.id for role in ctx.author.roles): + self.bot.stats.incr("snekbox_usages.roles.helpers") + else: + self.bot.stats.incr("snekbox_usages.roles.developers") + + if ctx.channel.category_id == Categories.help_in_use: + self.bot.stats.incr("snekbox_usages.channels.help") + elif ctx.channel.id == Channels.bot_commands: + self.bot.stats.incr("snekbox_usages.channels.bot_commands") + else: + self.bot.stats.incr("snekbox_usages.channels.topical") + + log.info(f"Received code from {ctx.author} for evaluation:\n{code}") + + while True: + self.jobs[ctx.author.id] = datetime.datetime.now() + code = self.prepare_input(code) + try: + response = await self.send_eval(ctx, code) + finally: + del self.jobs[ctx.author.id] + + code = await self.continue_eval(ctx, response) + if not code: + break + log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + + +def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: + """Return True if the edited message is the context message and the content was indeed modified.""" + return new_msg.id == ctx.message.id and old_msg.content != new_msg.content + + +def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI + + +def setup(bot: Bot) -> None: + """Load the Snekbox cog.""" + bot.add_cog(Snekbox(bot)) diff --git a/bot/cogs/utils/utils.py b/bot/cogs/utils/utils.py new file mode 100644 index 000000000..d96abbd5a --- /dev/null +++ b/bot/cogs/utils/utils.py @@ -0,0 +1,265 @@ +import difflib +import logging +import re +import unicodedata +from email.parser import HeaderParser +from io import StringIO +from typing import Tuple, Union + +from discord import Colour, Embed, utils +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils import messages + +log = logging.getLogger(__name__) + +ZEN_OF_PYTHON = """\ +Beautiful is better than ugly. +Explicit is better than implicit. +Simple is better than complex. +Complex is better than complicated. +Flat is better than nested. +Sparse is better than dense. +Readability counts. +Special cases aren't special enough to break the rules. +Although practicality beats purity. +Errors should never pass silently. +Unless explicitly silenced. +In the face of ambiguity, refuse the temptation to guess. +There should be one-- and preferably only one --obvious way to do it. +Although that way may not be obvious at first unless you're Dutch. +Now is better than never. +Although never is often better than *right* now. +If the implementation is hard to explain, it's a bad idea. +If the implementation is easy to explain, it may be a good idea. +Namespaces are one honking great idea -- let's do more of those! +""" + +ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + + +class Utils(Cog): + """A selection of utilities which don't have a clear category.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.base_pep_url = "http://www.python.org/dev/peps/pep-" + self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" + + @command(name='pep', aliases=('get_pep', 'p')) + async def pep_command(self, ctx: Context, pep_number: str) -> None: + """Fetches information about a PEP and sends it to the channel.""" + if pep_number.isdigit(): + pep_number = int(pep_number) + else: + await ctx.send_help(ctx.command) + return + + # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. + if pep_number == 0: + return await self.send_pep_zero(ctx) + + possible_extensions = ['.txt', '.rst'] + found_pep = False + for extension in possible_extensions: + # Attempt to fetch the PEP + pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" + log.trace(f"Requesting PEP {pep_number} with {pep_url}") + response = await self.bot.http_session.get(pep_url) + + if response.status == 200: + log.trace("PEP found") + found_pep = True + + pep_content = await response.text() + + # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 + pep_header = HeaderParser().parse(StringIO(pep_content)) + + # Assemble the embed + pep_embed = Embed( + title=f"**PEP {pep_number} - {pep_header['Title']}**", + description=f"[Link]({self.base_pep_url}{pep_number:04})", + ) + + pep_embed.set_thumbnail(url=ICON_URL) + + # Add the interesting information + fields_to_check = ("Status", "Python-Version", "Created", "Type") + for field in fields_to_check: + # Check for a PEP metadata field that is present but has an empty value + # embed field values can't contain an empty string + if pep_header.get(field, ""): + pep_embed.add_field(name=field, value=pep_header[field]) + + elif response.status != 404: + # any response except 200 and 404 is expected + found_pep = True # actually not, but it's easier to display this way + log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " + f"{response.status}.\n{response.text}") + + error_message = "Unexpected HTTP error during PEP search. Please let us know." + pep_embed = Embed(title="Unexpected error", description=error_message) + pep_embed.colour = Colour.red() + break + + if not found_pep: + log.trace("PEP was not found") + not_found = f"PEP {pep_number} does not exist." + pep_embed = Embed(title="PEP not found", description=not_found) + pep_embed.colour = Colour.red() + + await ctx.message.channel.send(embed=pep_embed) + + @command() + @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) + async def charinfo(self, ctx: Context, *, characters: str) -> None: + """Shows you information on up to 50 unicode characters.""" + match = re.match(r"<(a?):(\w+):(\d+)>", characters) + if match: + return await messages.send_denial( + ctx, + "**Non-Character Detected**\n" + "Only unicode characters can be processed, but a custom Discord emoji " + "was found. Please remove it and try again." + ) + + if len(characters) > 50: + return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") + + def get_info(char: str) -> Tuple[str, str]: + digit = f"{ord(char):x}" + if len(digit) <= 4: + u_code = f"\\u{digit:>04}" + else: + u_code = f"\\U{digit:>08}" + url = f"https://www.compart.com/en/unicode/U+{digit:>04}" + name = f"[{unicodedata.name(char, '')}]({url})" + info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" + return info, u_code + + char_list, raw_list = zip(*(get_info(c) for c in characters)) + embed = Embed().set_author(name="Character Info") + + if len(characters) > 1: + # Maximum length possible is 502 out of 1024, so there's no need to truncate. + embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) + + await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) + + @command() + async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: + """ + Show the Zen of Python. + + Without any arguments, the full Zen will be produced. + If an integer is provided, the line with that index will be produced. + If a string is provided, the line which matches best will be produced. + """ + embed = Embed( + colour=Colour.blurple(), + title="The Zen of Python", + description=ZEN_OF_PYTHON + ) + + if search_value is None: + embed.title += ", by Tim Peters" + await ctx.send(embed=embed) + return + + zen_lines = ZEN_OF_PYTHON.splitlines() + + # handle if it's an index int + if isinstance(search_value, int): + upper_bound = len(zen_lines) - 1 + lower_bound = -1 * upper_bound + if not (lower_bound <= search_value <= upper_bound): + raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") + + embed.title += f" (line {search_value % len(zen_lines)}):" + embed.description = zen_lines[search_value] + await ctx.send(embed=embed) + return + + # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead + # exact word. + for i, line in enumerate(zen_lines): + for word in line.split(): + if word.lower() == search_value.lower(): + embed.title += f" (line {i}):" + embed.description = line + await ctx.send(embed=embed) + return + + # handle if it's a search string and not exact word + matcher = difflib.SequenceMatcher(None, search_value.lower()) + + best_match = "" + match_index = 0 + best_ratio = 0 + + for index, line in enumerate(zen_lines): + matcher.set_seq2(line.lower()) + + # the match ratio needs to be adjusted because, naturally, + # longer lines will have worse ratios than shorter lines when + # fuzzy searching for keywords. this seems to work okay. + adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() + + if adjusted_ratio > best_ratio: + best_ratio = adjusted_ratio + best_match = line + match_index = index + + if not best_match: + raise BadArgument("I didn't get a match! Please try again with a different search term.") + + embed.title += f" (line {match_index}):" + embed.description = best_match + await ctx.send(embed=embed) + + @command(aliases=("poll",)) + @with_role(*MODERATION_ROLES) + async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: + """ + Build a quick voting poll with matching reactions with the provided options. + + A maximum of 20 options can be provided, as Discord supports a max of 20 + reactions on a single message. + """ + if len(title) > 256: + raise BadArgument("The title cannot be longer than 256 characters.") + if len(options) < 2: + raise BadArgument("Please provide at least 2 options.") + if len(options) > 20: + raise BadArgument("I can only handle 20 options!") + + codepoint_start = 127462 # represents "regional_indicator_a" unicode value + options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} + embed = Embed(title=title, description="\n".join(options.values())) + message = await ctx.send(embed=embed) + for reaction in options: + await message.add_reaction(reaction) + + async def send_pep_zero(self, ctx: Context) -> None: + """Send information about PEP 0.""" + pep_embed = Embed( + title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", + description="[Link](https://www.python.org/dev/peps/)" + ) + pep_embed.set_thumbnail(url=ICON_URL) + pep_embed.add_field(name="Status", value="Active") + pep_embed.add_field(name="Created", value="13-Jul-2000") + pep_embed.add_field(name="Type", value="Informational") + + await ctx.send(embed=pep_embed) + + +def setup(bot: Bot) -> None: + """Load the Utils cog.""" + bot.add_cog(Utils(bot)) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py deleted file mode 100644 index ae156cf70..000000000 --- a/bot/cogs/verification.py +++ /dev/null @@ -1,191 +0,0 @@ -import logging -from contextlib import suppress - -from discord import Colour, Forbidden, Message, NotFound, Object -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.decorators import in_whitelist, without_role -from bot.utils.checks import InWhitelistCheckFailure, without_role_check - -log = logging.getLogger(__name__) - -WELCOME_MESSAGE = f""" -Hello! Welcome to the server, and thanks for verifying yourself! - -For your records, these are the documents you accepted: - -`1)` Our rules, here: -`2)` Our privacy policy, here: - you can find information on how to have \ -your information removed here as well. - -Feel free to review them at any point! - -Additionally, if you'd like to receive notifications for the announcements \ -we post in <#{constants.Channels.announcements}> -from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ -to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. - -If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ -<#{constants.Channels.bot_commands}>. -""" - -BOT_MESSAGE_DELETE_DELAY = 10 - - -class Verification(Cog): - """User verification and role self-management.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Check new message event for messages to the checkpoint channel & process.""" - if message.channel.id != constants.Channels.verification: - return # Only listen for #checkpoint messages - - if message.author.bot: - # They're a bot, delete their message after the delay. - await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) - return - - # if a user mentions a role or guild member - # alert the mods in mod-alerts channel - if message.mentions or message.role_mentions: - log.debug( - f"{message.author} mentioned one or more users " - f"and/or roles in {message.channel.name}" - ) - - embed_text = ( - f"{message.author.mention} sent a message in " - f"{message.channel.mention} that contained user and/or role mentions." - f"\n\n**Original message:**\n>>> {message.content}" - ) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=constants.Icons.filtering, - colour=Colour(constants.Colours.soft_red), - title=f"User/Role mentioned in {message.channel.name}", - text=embed_text, - thumbnail=message.author.avatar_url_as(static_format="png"), - channel_id=constants.Channels.mod_alerts, - ) - - ctx: Context = await self.bot.get_context(message) - if ctx.command is not None and ctx.command.name == "accept": - return - - if any(r.id == constants.Roles.verified for r in ctx.author.roles): - log.info( - f"{ctx.author} posted '{ctx.message.content}' " - "in the verification channel, but is already verified." - ) - return - - log.debug( - f"{ctx.author} posted '{ctx.message.content}' in the verification " - "channel. We are providing instructions how to verify." - ) - await ctx.send( - f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " - f"and gain access to the rest of the server.", - delete_after=20 - ) - - log.trace(f"Deleting the message posted by {ctx.author}") - with suppress(NotFound): - await ctx.message.delete() - - @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) - @without_role(constants.Roles.verified) - @in_whitelist(channels=(constants.Channels.verification,)) - async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Accept our rules and gain access to the rest of the server.""" - log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") - await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") - try: - await ctx.author.send(WELCOME_MESSAGE) - except Forbidden: - log.info(f"Sending welcome message failed for {ctx.author}.") - finally: - log.trace(f"Deleting accept message by {ctx.author}.") - with suppress(NotFound): - self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) - await ctx.message.delete() - - @command(name='subscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Subscribe to announcement notifications by assigning yourself the role.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if has_role: - await ctx.send(f"{ctx.author.mention} You're already subscribed!") - return - - log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") - await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", - ) - - @command(name='unsubscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Unsubscribe from announcement notifications by removing the role from yourself.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if not has_role: - await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") - return - - log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") - await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." - ) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InWhitelistCheckFailure.""" - if isinstance(error, InWhitelistCheckFailure): - error.handled = True - - @staticmethod - def bot_check(ctx: Context) -> bool: - """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): - return ctx.command.name == "accept" - else: - return True - - -def setup(bot: Bot) -> None: - """Load the Verification cog.""" - bot.add_cog(Verification(bot)) diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py deleted file mode 100644 index 69d118df6..000000000 --- a/bot/cogs/watchchannels/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from bot.bot import Bot -from .bigbrother import BigBrother -from .talentpool import TalentPool - - -def setup(bot: Bot) -> None: - """Load the BigBrother and TalentPool cogs.""" - bot.add_cog(BigBrother(bot)) - bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py deleted file mode 100644 index 4d27a6333..000000000 --- a/bot/cogs/watchchannels/bigbrother.py +++ /dev/null @@ -1,165 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.utils import post_infraction -from bot.constants import Channels, MODERATION_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from .watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class BigBrother(WatchChannel, Cog, name="Big Brother"): - """Monitors users by relaying their messages to a watch channel to assist with moderation.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.big_brother_logs, - webhook_id=Webhooks.big_brother, - api_endpoint='bot/infractions', - api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, - logger=log - ) - - @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def bigbrother_group(self, ctx: Context) -> None: - """Monitors users by relaying their messages to the Big Brother watch channel.""" - await ctx.send_help(ctx.command) - - @bigbrother_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored by Big Brother. - - The optional kwarg `oldest_first` can be used to order the list by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @bigbrother_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows Big Brother monitored users ordered by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @bigbrother_group.command(name='watch', aliases=('w',)) - @with_role(*MODERATION_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#big-brother` channel. - - A `reason` for adding the user to Big Brother is required and will be displayed - in the header when relaying messages of this user to the watchchannel. - """ - await self.apply_watch(ctx, user, reason) - - @bigbrother_group.command(name='unwatch', aliases=('uw',)) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Stop relaying messages by the given `user`.""" - await self.apply_unwatch(ctx, user, reason) - - async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: - """ - Add `user` to watched users and apply a watch infraction with `reason`. - - A message indicating the result of the operation is sent to `ctx`. - The message will include `user`'s previous watch infraction history, if it exists. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched.") - return - - response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) - - if response is not None: - self.watched_users[user.id] = response - msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - 'type': 'watch', - 'ordering': '-inserted_at' - } - ) - - if len(history) > 1: - total = f"({len(history) // 2} previous infractions in total)" - end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") - start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - else: - msg = ":x: Failed to post the infraction: response was empty." - - await ctx.send(msg) - - async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: - """ - Remove `user` from watched users and mark their infraction as inactive with `reason`. - - If `send_message` is True, a message indicating the result of the operation is sent to - `ctx`. - """ - active_watches = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - if active_watches: - log.trace("Active watches for user found. Attempting to remove.") - [infraction] = active_watches - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{infraction['id']}", - json={'active': False} - ) - - await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) - - self._remove_user(user.id) - - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"Perma-banned user {user} was unwatched.") - return - log.trace("User is not banned. Sending message to channel") - message = f":white_check_mark: Messages sent by {user} will no longer be relayed." - - else: - log.trace("No active watches found for user.") - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"{user} was not on the watch list; no removal necessary.") - return - log.trace("User is not perma banned. Send the error message.") - message = ":x: The specified user is currently not being watched." - - await ctx.send(message) diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py deleted file mode 100644 index 89256e92e..000000000 --- a/bot/cogs/watchchannels/talentpool.py +++ /dev/null @@ -1,264 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord import Color, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils import time -from .watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class TalentPool(WatchChannel, Cog, name="Talentpool"): - """Relays messages of helper candidates to a watch channel to observe them.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.talent_pool, - webhook_id=Webhooks.talent_pool, - api_endpoint='bot/nominations', - api_default_params={'active': 'true', 'ordering': '-inserted_at'}, - logger=log, - ) - - @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_group(self, ctx: Context) -> None: - """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" - await ctx.send_help(ctx.command) - - @nomination_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored in the talent pool. - - The optional kwarg `oldest_first` can be used to order the list by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @nomination_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows talent pool monitored users ordered by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) - @with_role(*STAFF_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#talent-pool` channel. - - A `reason` for adding the user to the talent pool is required and will be displayed - in the header when relaying messages of this user to the channel. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): - await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update the user cache; can't add {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched in the talent pool") - return - - # Manual request with `raise_for_status` as False because we want the actual response - session = self.bot.api_client.session - url = self.bot.api_client._url_for(self.api_endpoint) - kwargs = { - 'json': { - 'actor': ctx.author.id, - 'reason': reason, - 'user': user.id - }, - 'raise_for_status': False, - } - async with session.post(url, **kwargs) as resp: - response_data = await resp.json() - - if resp.status == 400 and response_data.get('user', False): - await ctx.send(":x: The specified user can't be found in the database tables") - return - else: - resp.raise_for_status() - - self.watched_users[user.id] = response_data - msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - "ordering": "-inserted_at" - } - ) - - if history: - total = f"({len(history)} previous nominations in total)" - start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" - end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - - await ctx.send(msg) - - @nomination_group.command(name='history', aliases=('info', 'search')) - @with_role(*MODERATION_ROLES) - async def history_command(self, ctx: Context, user: FetchedMember) -> None: - """Shows the specified user's nomination history.""" - result = await self.bot.api_client.get( - self.api_endpoint, - params={ - 'user__id': str(user.id), - 'ordering': "-active,-inserted_at" - } - ) - if not result: - await ctx.send(":warning: This user has never been nominated") - return - - embed = Embed( - title=f"Nominations for {user.display_name} `({user.id})`", - color=Color.blue() - ) - lines = [self._nomination_to_string(nomination) for nomination in result] - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - @nomination_group.command(name='unwatch', aliases=('end', )) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Ends the active nomination of the specified user with the given reason. - - Providing a `reason` is required. - """ - active_nomination = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - - if not active_nomination: - await ctx.send(":x: The specified user does not have an active nomination") - return - - [nomination] = active_nomination - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination['id']}", - json={'end_reason': reason, 'active': False} - ) - await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") - self._remove_user(user.id) - - @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_edit_group(self, ctx: Context) -> None: - """Commands to edit nominations.""" - await ctx.send_help(ctx.command) - - @nomination_edit_group.command(name='reason') - @with_role(*MODERATION_ROLES) - async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: - """ - Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. - - If the nomination is active, the reason for nominating the user will be edited; - If the nomination is no longer active, the reason for ending the nomination will be edited instead. - """ - try: - nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") - except ResponseCodeError as e: - if e.response.status == 404: - self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") - await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") - return - else: - raise - - field = "reason" if nomination["active"] else "end_reason" - - self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination_id}", - json={field: reason} - ) - - await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") - - def _nomination_to_string(self, nomination_object: dict) -> str: - """Creates a string representation of a nomination.""" - guild = self.bot.get_guild(Guild.id) - - actor_id = nomination_object["actor"] - actor = guild.get_member(actor_id) - - active = nomination_object["active"] - log.debug(active) - log.debug(type(nomination_object["inserted_at"])) - - start_date = time.format_infraction(nomination_object["inserted_at"]) - if active: - lines = textwrap.dedent( - f""" - =============== - Status: **Active** - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - else: - end_date = time.format_infraction(nomination_object["ended_at"]) - lines = textwrap.dedent( - f""" - =============== - Status: Inactive - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - - End date: {end_date} - Unwatch reason: {nomination_object["end_reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - - return lines.strip() diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py deleted file mode 100644 index 044077350..000000000 --- a/bot/cogs/watchchannels/watchchannel.py +++ /dev/null @@ -1,348 +0,0 @@ -import asyncio -import logging -import re -import textwrap -from abc import abstractmethod -from collections import defaultdict, deque -from dataclasses import dataclass -from typing import Optional - -import dateutil.parser -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - -URL_RE = re.compile(r"(https?://[^\s]+)") - - -@dataclass -class MessageHistory: - """Represents a watch channel's message history.""" - - last_author: Optional[int] = None - last_channel: Optional[int] = None - message_count: int = 0 - - -class WatchChannel(metaclass=CogABCMeta): - """ABC with functionality for relaying users' messages to a certain channel.""" - - @abstractmethod - def __init__( - self, - bot: Bot, - destination: int, - webhook_id: int, - api_endpoint: str, - api_default_params: dict, - logger: logging.Logger - ) -> None: - self.bot = bot - - self.destination = destination # E.g., Channels.big_brother_logs - self.webhook_id = webhook_id # E.g., Webhooks.big_brother - self.api_endpoint = api_endpoint # E.g., 'bot/infractions' - self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} - self.log = logger # Logger of the child cog for a correct name in the logs - - self._consume_task = None - self.watched_users = defaultdict(dict) - self.message_queue = defaultdict(lambda: defaultdict(deque)) - self.consumption_queue = {} - self.retries = 5 - self.retry_delay = 10 - self.channel = None - self.webhook = None - self.message_history = MessageHistory() - - self._start = self.bot.loop.create_task(self.start_watchchannel()) - - @property - def modlog(self) -> ModLog: - """Provides access to the ModLog cog for alert purposes.""" - return self.bot.get_cog("ModLog") - - @property - def consuming_messages(self) -> bool: - """Checks if a consumption task is currently running.""" - if self._consume_task is None: - return False - - if self._consume_task.done(): - exc = self._consume_task.exception() - if exc: - self.log.exception( - "The message queue consume task has failed with:", - exc_info=exc - ) - return False - - return True - - async def start_watchchannel(self) -> None: - """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_guild_available() - - try: - self.channel = await self.bot.fetch_channel(self.destination) - except HTTPException: - self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - if self.channel is None or self.webhook is None: - self.log.error("Failed to start the watch channel; unloading the cog.") - - message = textwrap.dedent( - f""" - An error occurred while loading the text channel or webhook. - - TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} - Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} - - The Cog has been unloaded. - """ - ) - - await self.modlog.send_log_message( - title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", - text=message, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - if not await self.fetch_user_cache(): - await self.modlog.send_log_message( - title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", - text="Could not retrieve the list of watched users from the API and messages will not be relayed.", - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - async def fetch_user_cache(self) -> bool: - """ - Fetches watched users from the API and updates the watched user cache accordingly. - - This function returns `True` if the update succeeded. - """ - try: - data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except ResponseCodeError as err: - self.log.exception("Failed to fetch the watched users from the API", exc_info=err) - return False - - self.watched_users = defaultdict(dict) - - for entry in data: - user_id = entry.pop('user') - self.watched_users[user_id] = entry - - return True - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Queues up messages sent by watched users.""" - if msg.author.id in self.watched_users: - if not self.consuming_messages: - self._consume_task = self.bot.loop.create_task(self.consume_messages()) - - self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") - self.message_queue[msg.author.id][msg.channel.id].append(msg) - - async def consume_messages(self, delay_consumption: bool = True) -> None: - """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) - - self.log.trace("Started consuming the message queue") - - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() - - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() - - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) - - self.consumption_queue.clear() - - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") - - async def webhook_send( - self, - content: Optional[str] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - embed: Optional[Embed] = None, - ) -> None: - """Sends a message to the webhook with the specified kwargs.""" - username = messages.sub_clyde(username) - try: - await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send a message to the webhook", - exc_info=exc - ) - - async def relay_message(self, msg: Message) -> None: - """Relays the message to the relevant watch channel.""" - limit = BigBrotherConfig.header_message_limit - - if ( - msg.author.id != self.message_history.last_author - or msg.channel.id != self.message_history.last_channel - or self.message_history.message_count >= limit - ): - self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) - - await self.send_header(msg) - - cleaned_content = msg.clean_content - - if cleaned_content: - # Put all non-media URLs in a code block to prevent embeds - media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} - for url in URL_RE.findall(cleaned_content): - if url not in media_urls: - cleaned_content = cleaned_content.replace(url, f"`{url}`") - await self.webhook_send( - cleaned_content, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - - if msg.attachments: - try: - await messages.send_attachments(msg, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.webhook_send( - embed=e, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send an attachment to the webhook", - exc_info=exc - ) - - self.message_history.message_count += 1 - - async def send_header(self, msg: Message) -> None: - """Sends a header embed with information about the relayed messages to the watch channel.""" - user_id = msg.author.id - - guild = self.bot.get_guild(GuildConfig.id) - actor = guild.get_member(self.watched_users[user_id]['actor']) - actor = actor.display_name if actor else self.watched_users[user_id]['actor'] - - inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - - reason = self.watched_users[user_id]['reason'] - - if isinstance(msg.channel, DMChannel): - # If a watched user DMs the bot there won't be a channel name or jump URL - # This could technically include a GroupChannel but bot's can't be in those - message_jump = "via DM" - else: - message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" - - footer = f"Added {time_delta} by {actor} | Reason: {reason}" - embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) - - await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) - - async def list_watched_users( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Gives an overview of the watched user list for this channel. - - The optional kwarg `oldest_first` orders the list by oldest entry. - - The optional kwarg `update_cache` specifies whether the cache should - be refreshed by polling the API. - """ - if update_cache: - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") - update_cache = False - - lines = [] - for user_id, user_data in self.watched_users.items(): - inserted_at = user_data['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - lines.append(f"• <@{user_id}> (added {time_delta})") - - if oldest_first: - lines.reverse() - - lines = lines or ("There's nothing here yet.",) - - embed = Embed( - title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", - color=Color.blue() - ) - await LinePaginator.paginate(lines, ctx, embed, empty=False) - - @staticmethod - def _get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) - time_delta = time_since(date_time, precision="minutes", max_units=1) - - return time_delta - - def _remove_user(self, user_id: int) -> None: - """Removes a user from a watch channel.""" - self.watched_users.pop(user_id, None) - self.message_queue.pop(user_id, None) - self.consumption_queue.pop(user_id, None) - - def cog_unload(self) -> None: - """Takes care of unloading the cog and canceling the consumption task.""" - self.log.trace("Unloading the cog") - if self._consume_task and not self._consume_task.done(): - self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) diff --git a/bot/cogs/webhook_remover.py b/bot/cogs/webhook_remover.py deleted file mode 100644 index 5812da87c..000000000 --- a/bot/cogs/webhook_remover.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import re - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Event, Icons - -WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) - -ALERT_MESSAGE_TEMPLATE = ( - "{user}, looks like you posted a Discord webhook URL. Therefore, your " - "message has been removed. Your webhook may have been **compromised** so " - "please re-create the webhook **immediately**. If you believe this was " - "mistake, please let us know." -) - -log = logging.getLogger(__name__) - - -class WebhookRemover(Cog): - """Scan messages to detect Discord webhooks links.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get current instance of `ModLog`.""" - return self.bot.get_cog("ModLog") - - async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: - """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" - # Don't log this, due internal delete, not by user. Will make different entry. - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") - return - - await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) - - message = ( - f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " - f"to #{msg.channel}. Webhook URL was `{redacted_url}`" - ) - log.debug(message) - - # Send entry to moderation alerts. - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Discord webhook URL removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts - ) - - self.bot.stats.incr("tokens.removed_webhooks") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Check if a Discord webhook URL is in `message`.""" - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - matches = WEBHOOK_URL_RE.search(msg.content) - if matches: - await self.delete_and_respond(msg, matches[1] + "xxx") - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """Check if a Discord webhook URL is in the edited message `after`.""" - await self.on_message(after) - - -def setup(bot: Bot) -> None: - """Load `WebhookRemover` cog.""" - bot.add_cog(WebhookRemover(bot)) diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py deleted file mode 100644 index e6cae3bb8..000000000 --- a/bot/cogs/wolfram.py +++ /dev/null @@ -1,280 +0,0 @@ -import logging -from io import BytesIO -from typing import Callable, List, Optional, Tuple -from urllib import parse - -import discord -from dateutil.relativedelta import relativedelta -from discord import Embed -from discord.ext import commands -from discord.ext.commands import BucketType, Cog, Context, check, group - -from bot.bot import Bot -from bot.constants import Colours, STAFF_ROLES, Wolfram -from bot.pagination import ImagePaginator -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -APPID = Wolfram.key -DEFAULT_OUTPUT_FORMAT = "JSON" -QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" -WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" - -MAX_PODS = 20 - -# Allows for 10 wolfram calls pr user pr day -usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) - -# Allows for max api requests / days in month per day for the entire guild (Temporary) -guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) - - -async def send_embed( - ctx: Context, - message_txt: str, - colour: int = Colours.soft_red, - footer: str = None, - img_url: str = None, - f: discord.File = None -) -> None: - """Generate & send a response embed with Wolfram as the author.""" - embed = Embed(colour=colour) - embed.description = message_txt - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - if footer: - embed.set_footer(text=footer) - - if img_url: - embed.set_image(url=img_url) - - await ctx.send(embed=embed, file=f) - - -def custom_cooldown(*ignore: List[int]) -> Callable: - """ - Implement per-user and per-guild cooldowns for requests to the Wolfram API. - - A list of roles may be provided to ignore the per-user cooldown - """ - async def predicate(ctx: Context) -> bool: - if ctx.invoked_with == 'help': - # if the invoked command is help we don't want to increase the ratelimits since it's not actually - # invoking the command/making a request, so instead just check if the user/guild are on cooldown. - guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown - if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored - return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 - return guild_cooldown - - user_bucket = usercd.get_bucket(ctx.message) - - if all(role.id not in ignore for role in ctx.author.roles): - user_rate = user_bucket.update_rate_limit() - - if user_rate: - # Can't use api; cause: member limit - delta = relativedelta(seconds=int(user_rate)) - cooldown = humanize_delta(delta) - message = ( - "You've used up your limit for Wolfram|Alpha requests.\n" - f"Cooldown: {cooldown}" - ) - await send_embed(ctx, message) - return False - - guild_bucket = guildcd.get_bucket(ctx.message) - guild_rate = guild_bucket.update_rate_limit() - - # Repr has a token attribute to read requests left - log.debug(guild_bucket) - - if guild_rate: - # Can't use api; cause: guild limit - message = ( - "The max limit of requests for the server has been reached for today.\n" - f"Cooldown: {int(guild_rate)}" - ) - await send_embed(ctx, message) - return False - - return True - return check(predicate) - - -async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: - """Get the Wolfram API pod pages for the provided query.""" - async with ctx.channel.typing(): - url_str = parse.urlencode({ - "input": query, - "appid": APPID, - "output": DEFAULT_OUTPUT_FORMAT, - "format": "image,plaintext" - }) - request_url = QUERY.format(request="query", data=url_str) - - async with bot.http_session.get(request_url) as response: - json = await response.json(content_type='text/plain') - - result = json["queryresult"] - - if result["error"]: - # API key not set up correctly - if result["error"]["msg"] == "Invalid appid": - message = "Wolfram API key is invalid or missing." - log.warning( - "API key seems to be missing, or invalid when " - f"processing a wolfram request: {url_str}, Response: {json}" - ) - await send_embed(ctx, message) - return - - message = "Something went wrong internally with your request, please notify staff!" - log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") - await send_embed(ctx, message) - return - - if not result["success"]: - message = f"I couldn't find anything for {query}." - await send_embed(ctx, message) - return - - if not result["numpods"]: - message = "Could not find any results." - await send_embed(ctx, message) - return - - pods = result["pods"] - pages = [] - for pod in pods[:MAX_PODS]: - subs = pod.get("subpods") - - for sub in subs: - title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") - img = sub["img"]["src"] - pages.append((title, img)) - return pages - - -class Wolfram(Cog): - """Commands for interacting with the Wolfram|Alpha API.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_command(self, ctx: Context, *, query: str) -> None: - """Requests all answers on a single image, sends an image of all related pods.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="simple", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - image_bytes = await response.read() - - f = discord.File(BytesIO(image_bytes), filename="image.png") - image_url = "attachment://image.png" - - if status == 501: - message = "Failed to get response" - footer = "" - color = Colours.soft_red - elif status == 400: - message = "No input found" - footer = "" - color = Colours.soft_red - elif status == 403: - message = "Wolfram API key is invalid or missing." - footer = "" - color = Colours.soft_red - else: - message = "" - footer = "View original for a bigger picture." - color = Colours.soft_orange - - # Sends a "blank" embed if no request is received, unsure how to fix - await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) - - @wolfram_command.command(name="page", aliases=("pa", "p")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - embed = Embed() - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - embed.colour = Colours.soft_orange - - await ImagePaginator.paginate(pages, ctx, embed) - - @wolfram_command.command(name="cut", aliases=("c",)) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - if len(pages) >= 2: - page = pages[1] - else: - page = pages[0] - - await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) - - @wolfram_command.command(name="short", aliases=("sh", "s")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: - """Requests an answer to a simple question.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="result", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - response_text = await response.text() - - if status == 501: - message = "Failed to get response" - color = Colours.soft_red - elif status == 400: - message = "No input found" - color = Colours.soft_red - elif response_text == "Error 1: Invalid appid": - message = "Wolfram API key is invalid or missing." - color = Colours.soft_red - else: - message = response_text - color = Colours.soft_orange - - await send_embed(ctx, message, color) - - -def setup(bot: Bot) -> None: - """Load the Wolfram cog.""" - bot.add_cog(Wolfram(bot)) diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index da4e92ccc..df38090fb 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -2,7 +2,7 @@ import textwrap import unittest from unittest.mock import AsyncMock, Mock, patch -from bot.cogs.moderation.infractions import Infractions +from bot.cogs.moderation.infraction.infractions import Infractions from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index 70aea2bab..84d036405 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -6,7 +6,7 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs.sync.syncers import Syncer, _Diff +from bot.cogs.backend.sync import Syncer, _Diff from tests import helpers diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 120bc991d..ea7d090ba 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -5,8 +5,8 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs import sync -from bot.cogs.sync.syncers import Syncer +from bot.cogs.backend import sync +from bot.cogs.backend.sync import Syncer from tests import helpers from tests.base import CommandTestCase diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 79eee98f4..888c49ca8 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -3,7 +3,7 @@ from unittest import mock import discord -from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from bot.cogs.backend.sync import RoleSyncer, _Diff, _Role from tests import helpers diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 002a947ad..71f4b134c 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from bot.cogs.backend.sync import UserSyncer, _Diff, _User from tests import helpers diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index ecb7abf00..b00211f47 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, Mock from discord import NotFound -from bot.cogs import antimalware +from bot.cogs.filters import antimalware from bot.constants import Channels, STAFF_ROLES from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py index ce5472c71..8a3d8d02e 100644 --- a/tests/bot/cogs/test_antispam.py +++ b/tests/bot/cogs/test_antispam.py @@ -1,6 +1,6 @@ import unittest -from bot.cogs import antispam +from bot.cogs.filters import antispam class AntispamConfigurationValidationTests(unittest.TestCase): diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 79c0e0ad3..305a2bad9 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -6,7 +6,7 @@ import unittest.mock import discord from bot import constants -from bot.cogs import information +from bot.cogs.info import information from bot.utils.checks import InWhitelistCheckFailure from tests import helpers diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index 9d1a62f7e..82679f69c 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock from discord.ext.commands import NoPrivateMessage -from bot.cogs import security +from bot.cogs.filters import security from tests.helpers import MockBot, MockContext diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 343e37db9..c7bac3ab3 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -6,8 +6,8 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, pat from discord.ext import commands from bot import constants -from bot.cogs import snekbox -from bot.cogs.snekbox import Snekbox +from bot.cogs.utils import snekbox +from bot.cogs.utils.snekbox import Snekbox from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3349caa73..e33f3af38 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -6,9 +6,9 @@ from unittest.mock import MagicMock from discord import Colour, NotFound from bot import constants -from bot.cogs import token_remover +from bot.cogs.filters import token_remover +from bot.cogs.filters.token_remover import Token, TokenRemover from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import Token, TokenRemover from tests.helpers import MockBot, MockMessage, autospec -- cgit v1.2.3 From b224d46d68699ece3382cd333df7ede9e9a62e02 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 12 Aug 2020 14:31:56 -0700 Subject: Restructure tests and fix broken tests The cog tests structure should mirror the structure of the cogs folder. Fix some import/patch paths which broke due to the restructure. --- tests/bot/cogs/backend/__init__.py | 0 tests/bot/cogs/backend/sync/__init__.py | 0 tests/bot/cogs/backend/sync/test_base.py | 404 ++++++++++++++ tests/bot/cogs/backend/sync/test_cog.py | 415 +++++++++++++++ tests/bot/cogs/backend/sync/test_roles.py | 157 ++++++ tests/bot/cogs/backend/sync/test_users.py | 158 ++++++ tests/bot/cogs/backend/test_logging.py | 32 ++ tests/bot/cogs/filters/__init__.py | 0 tests/bot/cogs/filters/test_antimalware.py | 165 ++++++ tests/bot/cogs/filters/test_antispam.py | 35 ++ tests/bot/cogs/filters/test_security.py | 54 ++ tests/bot/cogs/filters/test_token_remover.py | 310 +++++++++++ tests/bot/cogs/info/__init__.py | 0 tests/bot/cogs/info/test_information.py | 584 +++++++++++++++++++++ tests/bot/cogs/moderation/infraction/__init__.py | 0 .../cogs/moderation/infraction/test_infractions.py | 55 ++ tests/bot/cogs/moderation/test_incidents.py | 4 +- tests/bot/cogs/moderation/test_infractions.py | 55 -- tests/bot/cogs/moderation/test_slowmode.py | 111 ++++ tests/bot/cogs/sync/__init__.py | 0 tests/bot/cogs/sync/test_base.py | 404 -------------- tests/bot/cogs/sync/test_cog.py | 415 --------------- tests/bot/cogs/sync/test_roles.py | 157 ------ tests/bot/cogs/sync/test_users.py | 158 ------ tests/bot/cogs/test_antimalware.py | 165 ------ tests/bot/cogs/test_antispam.py | 35 -- tests/bot/cogs/test_information.py | 584 --------------------- tests/bot/cogs/test_jams.py | 173 ------ tests/bot/cogs/test_logging.py | 32 -- tests/bot/cogs/test_security.py | 54 -- tests/bot/cogs/test_slowmode.py | 111 ---- tests/bot/cogs/test_snekbox.py | 409 --------------- tests/bot/cogs/test_token_remover.py | 310 ----------- tests/bot/cogs/utils/__init__.py | 0 tests/bot/cogs/utils/test_jams.py | 173 ++++++ tests/bot/cogs/utils/test_snekbox.py | 409 +++++++++++++++ 36 files changed, 3064 insertions(+), 3064 deletions(-) create mode 100644 tests/bot/cogs/backend/__init__.py create mode 100644 tests/bot/cogs/backend/sync/__init__.py create mode 100644 tests/bot/cogs/backend/sync/test_base.py create mode 100644 tests/bot/cogs/backend/sync/test_cog.py create mode 100644 tests/bot/cogs/backend/sync/test_roles.py create mode 100644 tests/bot/cogs/backend/sync/test_users.py create mode 100644 tests/bot/cogs/backend/test_logging.py create mode 100644 tests/bot/cogs/filters/__init__.py create mode 100644 tests/bot/cogs/filters/test_antimalware.py create mode 100644 tests/bot/cogs/filters/test_antispam.py create mode 100644 tests/bot/cogs/filters/test_security.py create mode 100644 tests/bot/cogs/filters/test_token_remover.py create mode 100644 tests/bot/cogs/info/__init__.py create mode 100644 tests/bot/cogs/info/test_information.py create mode 100644 tests/bot/cogs/moderation/infraction/__init__.py create mode 100644 tests/bot/cogs/moderation/infraction/test_infractions.py delete mode 100644 tests/bot/cogs/moderation/test_infractions.py create mode 100644 tests/bot/cogs/moderation/test_slowmode.py delete mode 100644 tests/bot/cogs/sync/__init__.py delete mode 100644 tests/bot/cogs/sync/test_base.py delete mode 100644 tests/bot/cogs/sync/test_cog.py delete mode 100644 tests/bot/cogs/sync/test_roles.py delete mode 100644 tests/bot/cogs/sync/test_users.py delete mode 100644 tests/bot/cogs/test_antimalware.py delete mode 100644 tests/bot/cogs/test_antispam.py delete mode 100644 tests/bot/cogs/test_information.py delete mode 100644 tests/bot/cogs/test_jams.py delete mode 100644 tests/bot/cogs/test_logging.py delete mode 100644 tests/bot/cogs/test_security.py delete mode 100644 tests/bot/cogs/test_slowmode.py delete mode 100644 tests/bot/cogs/test_snekbox.py delete mode 100644 tests/bot/cogs/test_token_remover.py create mode 100644 tests/bot/cogs/utils/__init__.py create mode 100644 tests/bot/cogs/utils/test_jams.py create mode 100644 tests/bot/cogs/utils/test_snekbox.py (limited to 'tests') diff --git a/tests/bot/cogs/backend/__init__.py b/tests/bot/cogs/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/backend/sync/__init__.py b/tests/bot/cogs/backend/sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/backend/sync/test_base.py b/tests/bot/cogs/backend/sync/test_base.py new file mode 100644 index 000000000..0d0a8299d --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_base.py @@ -0,0 +1,404 @@ +import asyncio +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.backend.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): + """Syncer subclass with mocks for abstract methods for testing purposes.""" + + name = "test" + _get_diff = mock.AsyncMock() + _sync = mock.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): + """Tests for the syncer base class.""" + + def setUp(self): + self.bot = helpers.MockBot() + + def test_instantiation_fails_without_abstract_methods(self): + """The class must have abstract methods implemented.""" + with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): + """Tests for sending the sync confirmation prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + + def mock_get_channel(self): + """Fixture to return a mock channel and message for when `get_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + mock_channel.send.return_value = mock_message + self.bot.get_channel.return_value = mock_channel + + return mock_channel, mock_message + + def mock_fetch_channel(self): + """Fixture to return a mock channel and message for when `fetch_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + self.bot.get_channel.return_value = None + mock_channel.send.return_value = mock_message + self.bot.fetch_channel.return_value = mock_channel + + return mock_channel, mock_message + + async def test_send_prompt_edits_and_returns_message(self): + """The given message should be edited to display the prompt and then should be returned.""" + msg = helpers.MockMessage() + ret_val = await self.syncer._send_prompt(msg) + + msg.edit.assert_called_once() + self.assertIn("content", msg.edit.call_args[1]) + self.assertEqual(ret_val, msg) + + async def test_send_prompt_gets_dev_core_channel(self): + """The dev-core channel should be retrieved if an extant message isn't given.""" + subtests = ( + (self.bot.get_channel, self.mock_get_channel), + (self.bot.fetch_channel, self.mock_fetch_channel), + ) + + for method, mock_ in subtests: + with self.subTest(method=method, msg=mock_.__name__): + mock_() + await self.syncer._send_prompt() + + method.assert_called_once_with(constants.Channels.dev_core) + + async def test_send_prompt_returns_none_if_channel_fetch_fails(self): + """None should be returned if there's an HTTPException when fetching the channel.""" + self.bot.get_channel.return_value = None + self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + + ret_val = await self.syncer._send_prompt() + + self.assertIsNone(ret_val) + + async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): + """A new message mentioning core devs should be sent and returned if message isn't given.""" + for mock_ in (self.mock_get_channel, self.mock_fetch_channel): + with self.subTest(msg=mock_.__name__): + mock_channel, mock_message = mock_() + ret_val = await self.syncer._send_prompt() + + mock_channel.send.assert_called_once() + self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) + self.assertEqual(ret_val, mock_message) + + async def test_send_prompt_adds_reactions(self): + """The message should have reactions for confirmation added.""" + extant_message = helpers.MockMessage() + subtests = ( + (extant_message, lambda: (None, extant_message)), + (None, self.mock_get_channel), + (None, self.mock_fetch_channel), + ) + + for message_arg, mock_ in subtests: + subtest_msg = "Extant message" if mock_.__name__ == "" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, asyncio.TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): # pragma: no cover + with self.subTest(size=size): + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover + with self.subTest(msg=msg): + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/cogs/backend/sync/test_cog.py b/tests/bot/cogs/backend/sync/test_cog.py new file mode 100644 index 000000000..199747051 --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_cog.py @@ -0,0 +1,415 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.backend import sync +from bot.cogs.backend.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + self.role_syncer_patcher = mock.patch( + "bot.cogs.backend.sync.syncers.RoleSyncer", + autospec=Syncer, + spec_set=True + ) + self.user_syncer_patcher = mock.patch( + "bot.cogs.backend.sync.syncers.UserSyncer", + autospec=Syncer, + spec_set=True + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = sync.Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task = mock.MagicMock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + sync.Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + + self.guild_id_patcher = mock.patch("bot.cogs.backend.sync.cog.constants.Guild.id", 5) + self.guild_id = self.guild_id_patcher.start() + + self.guild = helpers.MockGuild(id=self.guild_id) + self.other_guild = helpers.MockGuild(id=0) + + def tearDown(self): + self.guild_id_patcher.stop() + + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data, guild=self.guild) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + async def test_sync_cog_on_guild_role_create_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_create(role) + self.bot.api_client.post.assert_not_awaited() + + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99, guild=self.guild) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_delete(role) + self.bot.api_client.delete.assert_not_awaited() + + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data, guild=self.guild) + after_role = helpers.MockRole(**after_role_data, guild=self.guild) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + async def test_sync_cog_on_guild_role_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_update(role, role) + self.bot.api_client.put.assert_not_awaited() + + async def test_sync_cog_on_member_remove(self): + """Member should be patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember(guild=self.guild) + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + json={"in_guild": False} + ) + + async def test_sync_cog_on_member_remove_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_remove(member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles, guild=self.guild) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, json=data) + + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + async def test_sync_cog_on_member_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_update(member, member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args.args[0], after_user.id) + self.assertIn("json", call_args.kwargs) + + self.assertIn("ignore_404", call_args.kwargs) + self.assertTrue(call_args.kwargs["ignore_404"]) + + json = call_args.kwargs["json"] + self.assertIn(api_field, json) + self.assertEqual(json[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + guild=self.guild, + ) + + data = { + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_join(member) + self.bot.api_client.post.assert_not_awaited() + self.bot.api_client.put.assert_not_awaited() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/backend/sync/test_roles.py b/tests/bot/cogs/backend/sync/test_roles.py new file mode 100644 index 000000000..cc2e51c7f --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_roles.py @@ -0,0 +1,157 @@ +import unittest +from unittest import mock + +import discord + +from bot.cogs.backend.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/backend/sync/test_users.py b/tests/bot/cogs/backend/sync/test_users.py new file mode 100644 index 000000000..490ea9e06 --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_users.py @@ -0,0 +1,158 @@ +import unittest +from unittest import mock + +from bot.cogs.backend.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers + + +def fake_user(**kwargs): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_sets_in_guild_false_for_leaving_users(self): + """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/backend/test_logging.py b/tests/bot/cogs/backend/test_logging.py new file mode 100644 index 000000000..c867773e2 --- /dev/null +++ b/tests/bot/cogs/backend/test_logging.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import patch + +from bot import constants +from bot.cogs.backend.logging import Logging +from tests.helpers import MockBot, MockTextChannel + + +class LoggingTests(unittest.IsolatedAsyncioTestCase): + """Test cases for connected login.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Logging(self.bot) + self.dev_log = MockTextChannel(id=1234, name="dev-log") + + @patch("bot.cogs.backend.logging.DEBUG_MODE", False) + async def test_debug_mode_false(self): + """Should send connected message to dev-log.""" + self.bot.get_channel.return_value = self.dev_log + + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) + self.dev_log.send.assert_awaited_once() + + @patch("bot.cogs.backend.logging.DEBUG_MODE", True) + async def test_debug_mode_true(self): + """Should not send anything to dev-log.""" + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_not_called() diff --git a/tests/bot/cogs/filters/__init__.py b/tests/bot/cogs/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/filters/test_antimalware.py b/tests/bot/cogs/filters/test_antimalware.py new file mode 100644 index 000000000..b00211f47 --- /dev/null +++ b/tests/bot/cogs/filters/test_antimalware.py @@ -0,0 +1,165 @@ +import unittest +from unittest.mock import AsyncMock, Mock + +from discord import NotFound + +from bot.cogs.filters import antimalware +from bot.constants import Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + + +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.bot.filter_list_cache = { + "FILE_FORMAT.True": { + ".first": {}, + ".second": {}, + ".third": {}, + } + } + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] + + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename="python.first") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.guild = None + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_called_once() + + async def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extension_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + + async def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + self.message.delete.assert_called_once() + + async def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (self.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_antispam.py b/tests/bot/cogs/filters/test_antispam.py new file mode 100644 index 000000000..8a3d8d02e --- /dev/null +++ b/tests/bot/cogs/filters/test_antispam.py @@ -0,0 +1,35 @@ +import unittest + +from bot.cogs.filters import antispam + + +class AntispamConfigurationValidationTests(unittest.TestCase): + """Tests validation of the antispam cog configuration.""" + + def test_default_antispam_config_is_valid(self): + """The default antispam configuration is valid.""" + validation_errors = antispam.validate_config() + self.assertEqual(validation_errors, {}) + + def test_unknown_rule_returns_error(self): + """Configuring an unknown rule returns an error.""" + self.assertEqual( + antispam.validate_config({'invalid-rule': {}}), + {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} + ) + + def test_missing_keys_returns_error(self): + """Not configuring required keys returns an error.""" + keys = (('interval', 'max'), ('max', 'interval')) + for configured_key, unconfigured_key in keys: + with self.subTest( + configured_key=configured_key, + unconfigured_key=unconfigured_key + ): + config = {'burst': {configured_key: 10}} + error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" + + self.assertEqual( + antispam.validate_config(config), + {'burst': error} + ) diff --git a/tests/bot/cogs/filters/test_security.py b/tests/bot/cogs/filters/test_security.py new file mode 100644 index 000000000..82679f69c --- /dev/null +++ b/tests/bot/cogs/filters/test_security.py @@ -0,0 +1,54 @@ +import unittest +from unittest.mock import MagicMock + +from discord.ext.commands import NoPrivateMessage + +from bot.cogs.filters import security +from tests.helpers import MockBot, MockContext + + +class SecurityCogTests(unittest.TestCase): + """Tests the `Security` cog.""" + + def setUp(self): + """Attach an instance of the cog to the class for tests.""" + self.bot = MockBot() + self.cog = security.Security(self.bot) + self.ctx = MockContext() + + def test_check_additions(self): + """The cog should add its checks after initialization.""" + self.bot.check.assert_any_call(self.cog.check_on_guild) + self.bot.check.assert_any_call(self.cog.check_not_bot) + + def test_check_not_bot_returns_false_for_humans(self): + """The bot check should return `True` when invoked with human authors.""" + self.ctx.author.bot = False + self.assertTrue(self.cog.check_not_bot(self.ctx)) + + def test_check_not_bot_returns_true_for_robots(self): + """The bot check should return `False` when invoked with robotic authors.""" + self.ctx.author.bot = True + self.assertFalse(self.cog.check_not_bot(self.ctx)) + + def test_check_on_guild_raises_when_outside_of_guild(self): + """When invoked outside of a guild, `check_on_guild` should cause an error.""" + self.ctx.guild = None + + with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): + self.cog.check_on_guild(self.ctx) + + def test_check_on_guild_returns_true_inside_of_guild(self): + """When invoked inside of a guild, `check_on_guild` should return `True`.""" + self.ctx.guild = "lemon's lemonade stand" + self.assertTrue(self.cog.check_on_guild(self.ctx)) + + +class SecurityCogLoadTests(unittest.TestCase): + """Tests loading the `Security` cog.""" + + def test_security_cog_load(self): + """Setup of the extension should call add_cog.""" + bot = MagicMock() + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_token_remover.py b/tests/bot/cogs/filters/test_token_remover.py new file mode 100644 index 000000000..5c527ed94 --- /dev/null +++ b/tests/bot/cogs/filters/test_token_remover.py @@ -0,0 +1,310 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock + +from discord import Colour, NotFound + +from bot import constants +from bot.cogs.filters import token_remover +from bot.cogs.filters.token_remover import Token, TokenRemover +from bot.cogs.moderation import ModLog +from tests.helpers import MockBot, MockMessage, autospec + + +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): + """Tests the `TokenRemover` cog.""" + + def setUp(self): + """Adds the cog, a bot, and a message to the instance for usage in tests.""" + self.bot = MockBot() + self.cog = TokenRemover(bot=self.bot) + + self.msg = MockMessage(id=555, content="hello world") + self.msg.channel.mention = "#lemonade-stand" + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + self.msg.author.avatar_url_as.return_value = "picture-lemon.png" + + def test_is_valid_user_id_valid(self): + """Should consider user IDs valid if they decode entirely to ASCII digits.""" + ids = ( + "NDcyMjY1OTQzMDYyNDEzMzMy", + "NDc1MDczNjI5Mzk5NTQ3OTA0", + "NDY3MjIzMjMwNjUwNzc3NjQx", + ) + + for user_id in ids: + with self.subTest(user_id=user_id): + result = TokenRemover.is_valid_user_id(user_id) + self.assertTrue(result) + + def test_is_valid_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for user_id, msg in ids: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_user_id(user_id) + self.assertFalse(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for timestamp, msg in timestamps: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertFalse(result) + + def test_mod_log_property(self): + """The `mod_log` property should ask the bot to return the `ModLog` cog.""" + self.bot.get_cog.return_value = 'lemon' + self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) + self.bot.get_cog.assert_called_once_with('ModLog') + + async def test_on_message_edit_uses_on_message(self): + """The edit listener should delegate handling of the message to the normal listener.""" + self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + + await self.cog.on_message_edit(MockMessage(), self.msg) + self.cog.on_message.assert_awaited_once_with(self.msg) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_takes_action(self, find_token_in_message, take_action): + """Should take action if a valid token is found when a message is sent.""" + cog = TokenRemover(self.bot) + found_token = "foobar" + find_token_in_message.return_value = found_token + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_awaited_once_with(cog, self.msg, found_token) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): + """Shouldn't take action if a valid token isn't found when a message is sent.""" + cog = TokenRemover(self.bot) + find_token_in_message.return_value = False + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_not_awaited() + + @autospec(TokenRemover, "find_token_in_message") + async def test_on_message_ignores_dms_bots(self, find_token_in_message): + """Shouldn't parse a message if it is a DM or authored by a bot.""" + cog = TokenRemover(self.bot) + dm_msg = MockMessage(guild=None) + bot_msg = MockMessage(author=MagicMock(bot=True)) + + for msg in (dm_msg, bot_msg): + await cog.on_message(msg) + find_token_in_message.assert_not_called() + + @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") + def test_find_token_no_matches(self, token_re): + """None should be returned if the regex matches no tokens in a message.""" + token_re.finditer.return_value = () + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.filters.token_remover", "Token") + @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") + def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """The first match with a valid user ID and timestamp should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertEqual(tokens[1], return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.filters.token_remover", "Token") + @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") + def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """None should be returned if no matches have valid user IDs or timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + is_valid_id.return_value = False + is_valid_timestamp.return_value = False + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.findall(token) + self.assertEqual(len(results), 0) + + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, these tokens have been invalidated. + tokens = ( + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" + + results = token_remover.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual((token_1, token_2), results) + + @autospec("bot.cogs.filters.token_remover", "LOG_MESSAGE") + def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + log_message.format.return_value = "Howdy" + + return_value = TokenRemover.format_log_message(self.msg, token) + + self.assertEqual(return_value, log_message.format.return_value) + log_message.format.assert_called_once_with( + author=self.msg.author, + author_id=self.msg.author.id, + channel=self.msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac="x" * len(token.hmac), + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + @autospec("bot.cogs.filters.token_remover", "log") + @autospec(TokenRemover, "format_log_message") + async def test_take_action(self, format_log_message, logger, mod_log_property): + """Should delete the message and send a mod log.""" + cog = TokenRemover(self.bot) + mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) + token = mock.create_autospec(Token, spec_set=True, instance=True) + log_msg = "testing123" + + mod_log_property.return_value = mod_log + format_log_message.return_value = log_msg + + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + + format_log_message.assert_called_once_with(self.msg, token) + logger.debug.assert_called_with(log_msg) + self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + + mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) + mod_log.send_log_message.assert_called_once_with( + icon_url=constants.Icons.token_removed, + colour=Colour(constants.Colours.soft_red), + title="Token removed!", + text=log_msg, + thumbnail=self.msg.author.avatar_url_as.return_value, + channel_id=constants.Channels.mod_alerts + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + async def test_take_action_delete_failure(self, mod_log_property): + """Shouldn't send any messages if the token message can't be deleted.""" + cog = TokenRemover(self.bot) + mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) + self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) + + token = mock.create_autospec(Token, spec_set=True, instance=True) + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_not_awaited() + + +class TokenRemoverExtensionTests(unittest.TestCase): + """Tests for the token_remover extension.""" + + @autospec("bot.cogs.filters.token_remover", "TokenRemover") + def test_extension_setup(self, cog): + """The TokenRemover cog should be added.""" + bot = MockBot() + token_remover.setup(bot) + + cog.assert_called_once_with(bot) + bot.add_cog.assert_called_once() + self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/cogs/info/__init__.py b/tests/bot/cogs/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/info/test_information.py b/tests/bot/cogs/info/test_information.py new file mode 100644 index 000000000..895a8328e --- /dev/null +++ b/tests/bot/cogs/info/test_information.py @@ -0,0 +1,584 @@ +import asyncio +import textwrap +import unittest +import unittest.mock + +import discord + +from bot import constants +from bot.cogs.info import information +from bot.utils.checks import InWhitelistCheckFailure +from tests import helpers + +COG_PATH = "bot.cogs.info.information.Information" + + +class InformationCogTests(unittest.TestCase): + """Tests the Information cog.""" + + @classmethod + def setUpClass(cls): + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = helpers.MockBot() + + self.cog = information.Information(self.bot) + + self.ctx = helpers.MockContext() + self.ctx.author.roles.append(self.moderator_role) + + def test_roles_command_command(self): + """Test if the `role_info` command correctly returns the `moderator_role`.""" + self.ctx.guild.roles.append(self.moderator_role) + + self.cog.roles_info.can_run = unittest.mock.AsyncMock() + self.cog.roles_info.can_run.return_value = True + + coroutine = self.cog.roles_info.callback(self.cog, self.ctx) + + self.assertIsNone(asyncio.run(coroutine)) + self.ctx.send.assert_called_once() + + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + + self.assertEqual(embed.title, "Role information (Total 1 role)") + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") + + def test_role_info_command(self): + """Tests the `role info` command.""" + dummy_role = helpers.MockRole( + name="Dummy", + id=112233445566778899, + colour=discord.Colour.blurple(), + position=10, + members=[self.ctx.author], + permissions=discord.Permissions(0) + ) + + admin_role = helpers.MockRole( + name="Admins", + id=998877665544332211, + colour=discord.Colour.red(), + position=3, + members=[self.ctx.author], + permissions=discord.Permissions(0), + ) + + self.ctx.guild.roles.append([dummy_role, admin_role]) + + self.cog.role_info.can_run = unittest.mock.AsyncMock() + self.cog.role_info.can_run.return_value = True + + coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) + + self.assertIsNone(asyncio.run(coroutine)) + + self.assertEqual(self.ctx.send.call_count, 2) + + (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list + + dummy_embed = dummy_kwargs["embed"] + admin_embed = admin_kwargs["embed"] + + self.assertEqual(dummy_embed.title, "Dummy info") + self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) + + self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) + self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") + self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") + self.assertEqual(dummy_embed.fields[3].value, "1") + self.assertEqual(dummy_embed.fields[4].value, "10") + self.assertEqual(dummy_embed.fields[5].value, "0") + + self.assertEqual(admin_embed.title, "Admins info") + self.assertEqual(admin_embed.colour, discord.Colour.red()) + + @unittest.mock.patch('bot.cogs.info.information.time_since') + def test_server_info_command(self, time_since_patch): + time_since_patch.return_value = '2 days ago' + + self.ctx.guild = helpers.MockGuild( + features=('lemons', 'apples'), + region="The Moon", + roles=[self.moderator_role], + channels=[ + discord.TextChannel( + state={}, + guild=self.ctx.guild, + data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} + ), + discord.CategoryChannel( + state={}, + guild=self.ctx.guild, + data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} + ), + discord.VoiceChannel( + state={}, + guild=self.ctx.guild, + data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} + ) + ], + members=[ + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), + ], + member_count=1_234, + icon_url='a-lemon.jpg', + ) + + coroutine = self.cog.server_info.callback(self.cog, self.ctx) + self.assertIsNone(asyncio.run(coroutine)) + + time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual( + embed.description, + textwrap.dedent( + f""" + **Server information** + Created: {time_since_patch.return_value} + Voice region: {self.ctx.guild.region} + Features: {', '.join(self.ctx.guild.features)} + + **Channel counts** + Category channels: 1 + Text channels: 1 + Voice channels: 1 + Staff channels: 0 + + **Member counts** + Members: {self.ctx.guild.member_count:,} + Staff members: 0 + Roles: {len(self.ctx.guild.roles)} + + **Member statuses** + {constants.Emojis.status_online} 2 + {constants.Emojis.status_idle} 1 + {constants.Emojis.status_dnd} 4 + {constants.Emojis.status_offline} 3 + """ + ) + ) + self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') + + +class UserInfractionHelperMethodTests(unittest.TestCase): + """Tests for the helper methods of the `!user` command.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + self.member = helpers.MockMember(id=1234) + + def test_user_command_helper_method_get_requests(self): + """The helper methods should form the correct get requests.""" + test_values = ( + { + "helper_method": self.cog.basic_user_infraction_counts, + "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.expanded_user_infraction_counts, + "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.user_nomination_counts, + "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), + }, + ) + + for test_value in test_values: + helper_method = test_value["helper_method"] + endpoint, params = test_value["expected_args"] + + with self.subTest(method=helper_method, endpoint=endpoint, params=params): + asyncio.run(helper_method(self.member)) + self.bot.api_client.get.assert_called_once_with(endpoint, params=params) + self.bot.api_client.get.reset_mock() + + def _method_subtests(self, method, test_values, default_header): + """Helper method that runs the subtests for the different helper methods.""" + for test_value in test_values: + api_response = test_value["api response"] + expected_lines = test_value["expected_lines"] + + with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): + self.bot.api_client.get.return_value = api_response + + expected_output = "\n".join(default_header + expected_lines) + actual_output = asyncio.run(method(self.member)) + + self.assertEqual(expected_output, actual_output) + + def test_basic_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list both the total and active number of non-hidden infractions.""" + test_values = ( + # No infractions means zero counts + { + "api response": [], + "expected_lines": ["Total: 0", "Active: 0"], + }, + # Simple, single-infraction dictionaries + { + "api response": [{"type": "ban", "active": True}], + "expected_lines": ["Total: 1", "Active: 1"], + }, + { + "api response": [{"type": "ban", "active": False}], + "expected_lines": ["Total: 1", "Active: 0"], + }, + # Multiple infractions with various `active` status + { + "api response": [ + {"type": "ban", "active": True}, + {"type": "kick", "active": False}, + {"type": "ban", "active": True}, + {"type": "ban", "active": False}, + ], + "expected_lines": ["Total: 4", "Active: 2"], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) + + def test_expanded_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list the total and active number of all infractions split by infraction type.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never received an infraction."], + }, + # Shows non-hidden inactive infraction as expected + { + "api response": [{"type": "kick", "active": False, "hidden": False}], + "expected_lines": ["Kicks: 1"], + }, + # Shows non-hidden active infraction as expected + { + "api response": [{"type": "mute", "active": True, "hidden": False}], + "expected_lines": ["Mutes: 1 (1 active)"], + }, + # Shows hidden inactive infraction as expected + { + "api response": [{"type": "superstar", "active": False, "hidden": True}], + "expected_lines": ["Superstars: 1"], + }, + # Shows hidden active infraction as expected + { + "api response": [{"type": "ban", "active": True, "hidden": True}], + "expected_lines": ["Bans: 1 (1 active)"], + }, + # Correctly displays tally of multiple infractions of mixed properties in alphabetical order + { + "api response": [ + {"type": "kick", "active": False, "hidden": True}, + {"type": "ban", "active": True, "hidden": True}, + {"type": "superstar", "active": True, "hidden": True}, + {"type": "mute", "active": True, "hidden": True}, + {"type": "ban", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + {"type": "note", "active": False, "hidden": True}, + {"type": "warn", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + ], + "expected_lines": [ + "Bans: 2 (1 active)", + "Kicks: 1", + "Mutes: 1 (1 active)", + "Notes: 3", + "Superstars: 1 (1 active)", + "Warns: 1", + ], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) + + def test_user_nomination_counts_returns_correct_strings(self): + """The method should list the number of active and historical nominations for the user.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never been nominated."], + }, + { + "api response": [{'active': True}], + "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], + }, + { + "api response": [{'active': True}, {'active': False}], + "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], + }, + { + "api response": [{'active': False}], + "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], + }, + { + "api response": [{'active': False}, {'active': False}], + "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], + }, + + ) + + header = ["**Nominations**"] + + self._method_subtests(self.cog.user_nomination_counts, test_values, header) + + +@unittest.mock.patch("bot.cogs.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) +@unittest.mock.patch("bot.cogs.info.information.constants.MODERATION_CHANNELS", new=[50]) +class UserEmbedTests(unittest.TestCase): + """Tests for the creation of the `!user` embed.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): + """The embed should use the string representation of the user if they don't have a nick.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = None + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Mr. Hemlock") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_nick_in_title_if_available(self): + """The embed should use the nick if it's available.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = "Cat lover" + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_ignores_everyone_role(self): + """Created `!user` embeds should not contain mention of the @everyone-role.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + admins_role = helpers.MockRole(name='Admins') + admins_role.colour = 100 + + # A `MockMember` has the @Everyone role by default; we add the Admins to that. + user = helpers.MockMember(roles=[admins_role], top_role=admins_role) + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertIn("&Admins", embed.description) + self.assertNotIn("&Everyone", embed.description) + + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): + """The embed should contain expanded infractions and nomination info in mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "expanded infractions info" + nomination_counts.return_value = "nomination info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + nomination_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + expanded infractions info + + nomination info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): + """The embed should contain only basic infraction data outside of mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "basic infractions info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + basic infractions info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): + """The embed should be created with the colour of the top role, if a top role is available.""" + ctx = helpers.MockContext() + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): + """The embed should be created with a blurple colour if the user has no assigned roles.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour.blurple()) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): + """The embed thumbnail should be set to the user's avatar in `png` format.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + user.avatar_url_as.return_value = "avatar url" + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + user.avatar_url_as.assert_called_once_with(static_format="png") + self.assertEqual(embed.thumbnail.url, "avatar url") + + +@unittest.mock.patch("bot.cogs.info.information.constants") +class UserCommandTests(unittest.TestCase): + """Tests for the `!user` command.""" + + def setUp(self): + """Set up steps executed before each test is run.""" + self.bot = helpers.MockBot() + self.cog = information.Information(self.bot) + + self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) + self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) + self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) + + self.author = helpers.MockMember(id=1, name="syntaxaire") + self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) + self.target = helpers.MockMember(id=3, name="__fluzz__") + + def test_regular_member_cannot_target_another_member(self, constants): + """A regular user should not be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.author) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") + + def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): + """A regular user should not be able to use this command outside of bot-commands.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) + + msg = "Sorry, but you may only use this command within <#50>." + with self.assertRaises(InWhitelistCheckFailure, msg=msg): + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): + """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): + """A user should target itself with `!user` when a `user` argument was not provided.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): + """Staff members should be able to bypass the bot-commands channel restriction.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.moderator) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_moderators_can_target_another_member(self, create_embed, constants): + """A moderator should be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + create_embed.assert_called_once_with(ctx, self.target) + ctx.send.assert_called_once() diff --git a/tests/bot/cogs/moderation/infraction/__init__.py b/tests/bot/cogs/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/moderation/infraction/test_infractions.py b/tests/bot/cogs/moderation/infraction/test_infractions.py new file mode 100644 index 000000000..a79042557 --- /dev/null +++ b/tests/bot/cogs/moderation/infraction/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infraction.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason truncation.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Infractions(self.bot) + self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) + self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) + self.guild = MockGuild(id=4567) + self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + + @patch("bot.cogs.moderation.infraction.utils.get_active_infraction") + @patch("bot.cogs.moderation.infraction.utils.post_infraction") + async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): + """Should truncate reason for `ctx.guild.ban`.""" + get_active_mock.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.bot.get_cog.return_value = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.ctx.guild.ban = Mock() + + await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) + self.ctx.guild.ban.assert_called_once_with( + self.target, + reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), + delete_message_days=0 + ) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value + ) + + @patch("bot.cogs.moderation.infraction.utils.post_infraction") + async def test_apply_kick_reason_truncation(self, post_infraction_mock): + """Should truncate reason for `Member.kick`.""" + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.target.kick = Mock() + + await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) + self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + ) diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 435a1cd51..5e4d90251 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, call, patch import aiohttp import discord -from bot.cogs.moderation import Incidents, incidents +from bot.cogs.moderation import incidents from bot.constants import Colours from tests.helpers import ( MockAsyncWebhook, @@ -290,7 +290,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): Note that this will not schedule `crawl_incidents` in the background, as everything is being mocked. The `crawl_task` attribute will end up being None. """ - self.cog_instance = Incidents(MockBot()) + self.cog_instance = incidents.Incidents(MockBot()) @patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py deleted file mode 100644 index df38090fb..000000000 --- a/tests/bot/cogs/moderation/test_infractions.py +++ /dev/null @@ -1,55 +0,0 @@ -import textwrap -import unittest -from unittest.mock import AsyncMock, Mock, patch - -from bot.cogs.moderation.infraction.infractions import Infractions -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole - - -class TruncationTests(unittest.IsolatedAsyncioTestCase): - """Tests for ban and kick command reason truncation.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Infractions(self.bot) - self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) - self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) - self.guild = MockGuild(id=4567) - self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - - @patch("bot.cogs.moderation.utils.get_active_infraction") - @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): - """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = None - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.bot.get_cog.return_value = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.ctx.guild.ban = Mock() - - await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) - self.ctx.guild.ban.assert_called_once_with( - self.target, - reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), - delete_message_days=0 - ) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value - ) - - @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_kick_reason_truncation(self, post_infraction_mock): - """Should truncate reason for `Member.kick`.""" - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.target.kick = Mock() - - await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) - self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value - ) diff --git a/tests/bot/cogs/moderation/test_slowmode.py b/tests/bot/cogs/moderation/test_slowmode.py new file mode 100644 index 000000000..f442814c8 --- /dev/null +++ b/tests/bot/cogs/moderation/test_slowmode.py @@ -0,0 +1,111 @@ +import unittest +from unittest import mock + +from dateutil.relativedelta import relativedelta + +from bot.cogs.moderation.slowmode import Slowmode +from bot.constants import Emojis +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SlowmodeTests(unittest.IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Slowmode(self.bot) + self.ctx = MockContext() + + async def test_get_slowmode_no_channel(self) -> None: + """Get slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) + + await self.cog.get_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") + + async def test_get_slowmode_with_channel(self) -> None: + """Get slowmode with a given channel.""" + text_channel = MockTextChannel(name='python-language', slowmode_delay=2) + + await self.cog.get_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + + async def test_set_slowmode_no_channel(self) -> None: + """Set slowmode without a given channel.""" + test_cases = ( + ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), + ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), + ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + self.ctx.channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + + if edited: + self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + self.ctx.channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_set_slowmode_with_channel(self) -> None: + """Set slowmode with a given channel.""" + test_cases = ( + ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), + ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), + ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + text_channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + + if edited: + text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + text_channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_reset_slowmode_no_channel(self) -> None: + """Reset slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) + + await self.cog.reset_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' + ) + + async def test_reset_slowmode_with_channel(self) -> None: + """Reset slowmode with a given channel.""" + text_channel = MockTextChannel(name='meta', slowmode_delay=1) + + await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' + ) + + @mock.patch("bot.cogs.moderation.slowmode.with_role_check") + @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/sync/__init__.py b/tests/bot/cogs/sync/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py deleted file mode 100644 index 84d036405..000000000 --- a/tests/bot/cogs/sync/test_base.py +++ /dev/null @@ -1,404 +0,0 @@ -import asyncio -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend.sync import Syncer, _Diff -from tests import helpers - - -class TestSyncer(Syncer): - """Syncer subclass with mocks for abstract methods for testing purposes.""" - - name = "test" - _get_diff = mock.AsyncMock() - _sync = mock.AsyncMock() - - -class SyncerBaseTests(unittest.TestCase): - """Tests for the syncer base class.""" - - def setUp(self): - self.bot = helpers.MockBot() - - def test_instantiation_fails_without_abstract_methods(self): - """The class must have abstract methods implemented.""" - with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): - Syncer(self.bot) - - -class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): - """Tests for sending the sync confirmation prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - - def mock_get_channel(self): - """Fixture to return a mock channel and message for when `get_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - mock_channel.send.return_value = mock_message - self.bot.get_channel.return_value = mock_channel - - return mock_channel, mock_message - - def mock_fetch_channel(self): - """Fixture to return a mock channel and message for when `fetch_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - self.bot.get_channel.return_value = None - mock_channel.send.return_value = mock_message - self.bot.fetch_channel.return_value = mock_channel - - return mock_channel, mock_message - - async def test_send_prompt_edits_and_returns_message(self): - """The given message should be edited to display the prompt and then should be returned.""" - msg = helpers.MockMessage() - ret_val = await self.syncer._send_prompt(msg) - - msg.edit.assert_called_once() - self.assertIn("content", msg.edit.call_args[1]) - self.assertEqual(ret_val, msg) - - async def test_send_prompt_gets_dev_core_channel(self): - """The dev-core channel should be retrieved if an extant message isn't given.""" - subtests = ( - (self.bot.get_channel, self.mock_get_channel), - (self.bot.fetch_channel, self.mock_fetch_channel), - ) - - for method, mock_ in subtests: - with self.subTest(method=method, msg=mock_.__name__): - mock_() - await self.syncer._send_prompt() - - method.assert_called_once_with(constants.Channels.dev_core) - - async def test_send_prompt_returns_none_if_channel_fetch_fails(self): - """None should be returned if there's an HTTPException when fetching the channel.""" - self.bot.get_channel.return_value = None - self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") - - ret_val = await self.syncer._send_prompt() - - self.assertIsNone(ret_val) - - async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): - """A new message mentioning core devs should be sent and returned if message isn't given.""" - for mock_ in (self.mock_get_channel, self.mock_fetch_channel): - with self.subTest(msg=mock_.__name__): - mock_channel, mock_message = mock_() - ret_val = await self.syncer._send_prompt() - - mock_channel.send.assert_called_once() - self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) - self.assertEqual(ret_val, mock_message) - - async def test_send_prompt_adds_reactions(self): - """The message should have reactions for confirmation added.""" - extant_message = helpers.MockMessage() - subtests = ( - (extant_message, lambda: (None, extant_message)), - (None, self.mock_get_channel), - (None, self.mock_fetch_channel), - ) - - for message_arg, mock_ in subtests: - subtest_msg = "Extant message" if mock_.__name__ == "" else mock_.__name__ - - with self.subTest(msg=subtest_msg): - _, mock_message = mock_() - await self.syncer._send_prompt(message_arg) - - calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] - mock_message.add_reaction.assert_has_calls(calls) - - -class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): - """Tests for waiting for a sync confirmation reaction on the prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) - - @staticmethod - def get_message_reaction(emoji): - """Fixture to return a mock message an reaction from the given `emoji`.""" - message = helpers.MockMessage() - reaction = helpers.MockReaction(emoji=emoji, message=message) - - return message, reaction - - def test_reaction_check_for_valid_emoji_and_authors(self): - """Should return True if authors are identical or are a bot and a core dev, respectively.""" - user_subtests = ( - ( - helpers.MockMember(id=77), - helpers.MockMember(id=77), - "identical users", - ), - ( - helpers.MockMember(id=77, bot=True), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "bot author and core-dev reactor", - ), - ) - - for emoji in self.syncer._REACTION_EMOJIS: - for author, user, msg in user_subtests: - with self.subTest(author=author, user=user, emoji=emoji, msg=msg): - message, reaction = self.get_message_reaction(emoji) - ret_val = self.syncer._reaction_check(author, message, reaction, user) - - self.assertTrue(ret_val) - - def test_reaction_check_for_invalid_reactions(self): - """Should return False for invalid reaction events.""" - valid_emoji = self.syncer._REACTION_EMOJIS[0] - subtests = ( - ( - helpers.MockMember(id=77), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "users are not identical", - ), - ( - helpers.MockMember(id=77, bot=True), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43), - "reactor lacks the core-dev role", - ), - ( - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - "reactor is a bot", - ), - ( - helpers.MockMember(id=77), - helpers.MockMessage(id=95), - helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), - helpers.MockMember(id=77), - "messages are not identical", - ), - ( - helpers.MockMember(id=77), - *self.get_message_reaction("InVaLiD"), - helpers.MockMember(id=77), - "emoji is invalid", - ), - ) - - for *args, msg in subtests: - kwargs = dict(zip(("author", "message", "reaction", "user"), args)) - with self.subTest(**kwargs, msg=msg): - ret_val = self.syncer._reaction_check(*args) - self.assertFalse(ret_val) - - async def test_wait_for_confirmation(self): - """The message should always be edited and only return True if the emoji is a check mark.""" - subtests = ( - (constants.Emojis.check_mark, True, None), - ("InVaLiD", False, None), - (None, False, asyncio.TimeoutError), - ) - - for emoji, ret_val, side_effect in subtests: - for bot in (True, False): - with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): - # Set up mocks - message = helpers.MockMessage() - member = helpers.MockMember(bot=bot) - - self.bot.wait_for.reset_mock() - self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) - self.bot.wait_for.side_effect = side_effect - - # Call the function - actual_return = await self.syncer._wait_for_confirmation(member, message) - - # Perform assertions - self.bot.wait_for.assert_called_once() - self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) - - message.edit.assert_called_once() - kwargs = message.edit.call_args[1] - self.assertIn("content", kwargs) - - # Core devs should only be mentioned if the author is a bot. - if bot: - self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - else: - self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - - self.assertIs(actual_return, ret_val) - - -class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for main function orchestrating the sync.""" - - def setUp(self): - self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) - self.syncer = TestSyncer(self.bot) - - async def test_sync_respects_confirmation_result(self): - """The sync should abort if confirmation fails and continue if confirmed.""" - mock_message = helpers.MockMessage() - subtests = ( - (True, mock_message), - (False, None), - ) - - for confirmed, message in subtests: - with self.subTest(confirmed=confirmed): - self.syncer._sync.reset_mock() - self.syncer._get_diff.reset_mock() - - diff = _Diff({1, 2, 3}, {4, 5}, None) - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(confirmed, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - - if confirmed: - self.syncer._sync.assert_called_once_with(diff) - else: - self.syncer._sync.assert_not_called() - - async def test_sync_diff_size(self): - """The diff size should be correctly calculated.""" - subtests = ( - (6, _Diff({1, 2}, {3, 4}, {5, 6})), - (5, _Diff({1, 2, 3}, None, {4, 5})), - (0, _Diff(None, None, None)), - (0, _Diff(set(), set(), set())), - ) - - for size, diff in subtests: - with self.subTest(size=size, diff=diff): - self.syncer._get_diff.reset_mock() - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) - - async def test_sync_message_edited(self): - """The message should be edited if one was sent, even if the sync has an API error.""" - subtests = ( - (None, None, False), - (helpers.MockMessage(), None, True), - (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), - ) - - for message, side_effect, should_edit in subtests: - with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): - self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(True, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - if should_edit: - message.edit.assert_called_once() - self.assertIn("content", message.edit.call_args[1]) - - async def test_sync_confirmation_context_redirect(self): - """If ctx is given, a new message should be sent and author should be ctx's author.""" - mock_member = helpers.MockMember() - subtests = ( - (None, self.bot.user, None), - (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), - ) - - for ctx, author, message in subtests: - with self.subTest(ctx=ctx, author=author, message=message): - if ctx is not None: - ctx.send.return_value = message - - # Make sure `_get_diff` returns a MagicMock, not an AsyncMock - self.syncer._get_diff.return_value = mock.MagicMock() - - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild, ctx) - - if ctx is not None: - ctx.send.assert_called_once() - - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_small_diff(self): - """Should always return True and the given message if the diff size is too small.""" - author = helpers.MockMember() - expected_message = helpers.MockMessage() - - for size in (3, 2): # pragma: no cover - with self.subTest(size=size): - self.syncer._send_prompt = mock.AsyncMock() - self.syncer._wait_for_confirmation = mock.AsyncMock() - - coro = self.syncer._get_confirmation_result(size, author, expected_message) - result, actual_message = await coro - - self.assertTrue(result) - self.assertEqual(actual_message, expected_message) - self.syncer._send_prompt.assert_not_called() - self.syncer._wait_for_confirmation.assert_not_called() - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_large_diff(self): - """Should return True if confirmed and False if _send_prompt fails or aborted.""" - author = helpers.MockMember() - mock_message = helpers.MockMessage() - - subtests = ( - (True, mock_message, True, "confirmed"), - (False, None, False, "_send_prompt failed"), - (False, mock_message, False, "aborted"), - ) - - for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover - with self.subTest(msg=msg): - self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) - - coro = self.syncer._get_confirmation_result(4, author) - actual_result, actual_message = await coro - - self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None - self.assertIs(actual_result, expected_result) - self.assertEqual(actual_message, expected_message) - - if expected_message: - self.syncer._wait_for_confirmation.assert_called_once_with( - author, expected_message - ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py deleted file mode 100644 index ea7d090ba..000000000 --- a/tests/bot/cogs/sync/test_cog.py +++ /dev/null @@ -1,415 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend import sync -from bot.cogs.backend.sync import Syncer -from tests import helpers -from tests.base import CommandTestCase - - -class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): - """Tests for the sync extension.""" - - @staticmethod - def test_extension_setup(): - """The Sync cog should be added.""" - bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() - - -class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): - """Base class for Sync cog tests. Sets up patches for syncers.""" - - def setUp(self): - self.bot = helpers.MockBot() - - self.role_syncer_patcher = mock.patch( - "bot.cogs.sync.syncers.RoleSyncer", - autospec=Syncer, - spec_set=True - ) - self.user_syncer_patcher = mock.patch( - "bot.cogs.sync.syncers.UserSyncer", - autospec=Syncer, - spec_set=True - ) - self.RoleSyncer = self.role_syncer_patcher.start() - self.UserSyncer = self.user_syncer_patcher.start() - - self.cog = sync.Sync(self.bot) - - def tearDown(self): - self.role_syncer_patcher.stop() - self.user_syncer_patcher.stop() - - @staticmethod - def response_error(status: int) -> ResponseCodeError: - """Fixture to return a ResponseCodeError with the given status code.""" - response = mock.MagicMock() - response.status = status - - return ResponseCodeError(response) - - -class SyncCogTests(SyncCogTestCase): - """Tests for the Sync cog.""" - - @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. - self.RoleSyncer.reset_mock() - self.UserSyncer.reset_mock() - self.bot.loop.create_task = mock.MagicMock() - - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - sync.Sync(self.bot) - - self.RoleSyncer.assert_called_once_with(self.bot) - self.UserSyncer.assert_called_once_with(self.bot) - sync_guild.assert_called_once_with() - self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) - - async def test_sync_cog_sync_guild(self): - """Roles and users should be synced only if a guild is successfully retrieved.""" - for guild in (helpers.MockGuild(), None): - with self.subTest(guild=guild): - self.bot.reset_mock() - self.cog.role_syncer.reset_mock() - self.cog.user_syncer.reset_mock() - - self.bot.get_guild = mock.MagicMock(return_value=guild) - - await self.cog.sync_guild() - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.get_guild.assert_called_once_with(constants.Guild.id) - - if guild is None: - self.cog.role_syncer.sync.assert_not_called() - self.cog.user_syncer.sync.assert_not_called() - else: - self.cog.role_syncer.sync.assert_called_once_with(guild) - self.cog.user_syncer.sync.assert_called_once_with(guild) - - async def patch_user_helper(self, side_effect: BaseException) -> None: - """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" - self.bot.api_client.patch.reset_mock(side_effect=True) - self.bot.api_client.patch.side_effect = side_effect - - user_id, updated_information = 5, {"key": 123} - await self.cog.patch_user(user_id, updated_information) - - self.bot.api_client.patch.assert_called_once_with( - f"bot/users/{user_id}", - json=updated_information, - ) - - async def test_sync_cog_patch_user(self): - """A PATCH request should be sent and 404 errors ignored.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - await self.patch_user_helper(side_effect) - - async def test_sync_cog_patch_user_non_404(self): - """A PATCH request should be sent and the error raised if it's not a 404.""" - with self.assertRaises(ResponseCodeError): - await self.patch_user_helper(self.response_error(500)) - - -class SyncCogListenerTests(SyncCogTestCase): - """Tests for the listeners of the Sync cog.""" - - def setUp(self): - super().setUp() - self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - - self.guild_id_patcher = mock.patch("bot.cogs.sync.cog.constants.Guild.id", 5) - self.guild_id = self.guild_id_patcher.start() - - self.guild = helpers.MockGuild(id=self.guild_id) - self.other_guild = helpers.MockGuild(id=0) - - def tearDown(self): - self.guild_id_patcher.stop() - - async def test_sync_cog_on_guild_role_create(self): - """A POST request should be sent with the new role's data.""" - self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - role = helpers.MockRole(**role_data, guild=self.guild) - await self.cog.on_guild_role_create(role) - - self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) - - async def test_sync_cog_on_guild_role_create_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_create(role) - self.bot.api_client.post.assert_not_awaited() - - async def test_sync_cog_on_guild_role_delete(self): - """A DELETE request should be sent.""" - self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - - role = helpers.MockRole(id=99, guild=self.guild) - await self.cog.on_guild_role_delete(role) - - self.bot.api_client.delete.assert_called_once_with("bot/roles/99") - - async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_delete(role) - self.bot.api_client.delete.assert_not_awaited() - - async def test_sync_cog_on_guild_role_update(self): - """A PUT request should be sent if the colour, name, permissions, or position changes.""" - self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - subtests = ( - (True, ("colour", "name", "permissions", "position")), - (False, ("hoist", "mentionable")), - ) - - for should_put, attributes in subtests: - for attribute in attributes: - with self.subTest(should_put=should_put, changed_attribute=attribute): - self.bot.api_client.put.reset_mock() - - after_role_data = role_data.copy() - after_role_data[attribute] = 876 - - before_role = helpers.MockRole(**role_data, guild=self.guild) - after_role = helpers.MockRole(**after_role_data, guild=self.guild) - - await self.cog.on_guild_role_update(before_role, after_role) - - if should_put: - self.bot.api_client.put.assert_called_once_with( - f"bot/roles/{after_role.id}", - json=after_role_data - ) - else: - self.bot.api_client.put.assert_not_called() - - async def test_sync_cog_on_guild_role_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_update(role, role) - self.bot.api_client.put.assert_not_awaited() - - async def test_sync_cog_on_member_remove(self): - """Member should be patched to set in_guild as False.""" - self.assertTrue(self.cog.on_member_remove.__cog_listener__) - - member = helpers.MockMember(guild=self.guild) - await self.cog.on_member_remove(member) - - self.cog.patch_user.assert_called_once_with( - member.id, - json={"in_guild": False} - ) - - async def test_sync_cog_on_member_remove_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_remove(member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_member_update_roles(self): - """Members should be patched if their roles have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - # Roles are intentionally unsorted. - before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles, guild=self.guild) - after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - data = {"roles": sorted(role.id for role in after_member.roles)} - self.cog.patch_user.assert_called_once_with(after_member.id, json=data) - - async def test_sync_cog_on_member_update_other(self): - """Members should not be patched if other attributes have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - subtests = ( - ("activities", discord.Game("Pong"), discord.Game("Frogger")), - ("nick", "old nick", "new nick"), - ("status", discord.Status.online, discord.Status.offline), - ) - - for attribute, old_value, new_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) - after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - self.cog.patch_user.assert_not_called() - - async def test_sync_cog_on_member_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_update(member, member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_user_update(self): - """A user should be patched only if the name, discriminator, or avatar changes.""" - self.assertTrue(self.cog.on_user_update.__cog_listener__) - - before_data = { - "name": "old name", - "discriminator": "1234", - "bot": False, - } - - subtests = ( - (True, "name", "name", "new name", "new name"), - (True, "discriminator", "discriminator", "8765", 8765), - (False, "bot", "bot", True, True), - ) - - for should_patch, attribute, api_field, value, api_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - after_data = before_data.copy() - after_data[attribute] = value - before_user = helpers.MockUser(**before_data) - after_user = helpers.MockUser(**after_data) - - await self.cog.on_user_update(before_user, after_user) - - if should_patch: - self.cog.patch_user.assert_called_once() - - # Don't care if *all* keys are present; only the changed one is required - call_args = self.cog.patch_user.call_args - self.assertEqual(call_args.args[0], after_user.id) - self.assertIn("json", call_args.kwargs) - - self.assertIn("ignore_404", call_args.kwargs) - self.assertTrue(call_args.kwargs["ignore_404"]) - - json = call_args.kwargs["json"] - self.assertIn(api_field, json) - self.assertEqual(json[api_field], api_value) - else: - self.cog.patch_user.assert_not_called() - - async def on_member_join_helper(self, side_effect: Exception) -> dict: - """ - Helper to set `side_effect` for on_member_join and assert a PUT request was sent. - - The request data for the mock member is returned. All exceptions will be re-raised. - """ - member = helpers.MockMember( - discriminator="1234", - roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], - guild=self.guild, - ) - - data = { - "discriminator": int(member.discriminator), - "id": member.id, - "in_guild": True, - "name": member.name, - "roles": sorted(role.id for role in member.roles) - } - - self.bot.api_client.put.reset_mock(side_effect=True) - self.bot.api_client.put.side_effect = side_effect - - try: - await self.cog.on_member_join(member) - except Exception: - raise - finally: - self.bot.api_client.put.assert_called_once_with( - f"bot/users/{member.id}", - json=data - ) - - return data - - async def test_sync_cog_on_member_join(self): - """Should PUT user's data or POST it if the user doesn't exist.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - self.bot.api_client.post.reset_mock() - data = await self.on_member_join_helper(side_effect) - - if side_effect: - self.bot.api_client.post.assert_called_once_with("bot/users", json=data) - else: - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_non_404(self): - """ResponseCodeError should be re-raised if status code isn't a 404.""" - with self.assertRaises(ResponseCodeError): - await self.on_member_join_helper(self.response_error(500)) - - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_join(member) - self.bot.api_client.post.assert_not_awaited() - self.bot.api_client.put.assert_not_awaited() - - -class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): - """Tests for the commands in the Sync cog.""" - - async def test_sync_roles_command(self): - """sync() should be called on the RoleSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_roles_command.callback(self.cog, ctx) - - self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_sync_users_command(self): - """sync() should be called on the UserSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_users_command.callback(self.cog, ctx) - - self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_commands_require_admin(self): - """The sync commands should only run if the author has the administrator permission.""" - cmds = ( - self.cog.sync_group, - self.cog.sync_roles_command, - self.cog.sync_users_command, - ) - - for cmd in cmds: - with self.subTest(cmd=cmd): - await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py deleted file mode 100644 index 888c49ca8..000000000 --- a/tests/bot/cogs/sync/test_roles.py +++ /dev/null @@ -1,157 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot.cogs.backend.sync import RoleSyncer, _Diff, _Role -from tests import helpers - - -def fake_role(**kwargs): - """Fixture to return a dictionary representing a role with default values set.""" - kwargs.setdefault("id", 9) - kwargs.setdefault("name", "fake role") - kwargs.setdefault("colour", 7) - kwargs.setdefault("permissions", 0) - kwargs.setdefault("position", 55) - - return kwargs - - -class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between roles in the DB and roles in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - @staticmethod - def get_guild(*roles): - """Fixture to return a guild object with the given roles.""" - guild = helpers.MockGuild() - guild.roles = [] - - for role in roles: - mock_role = helpers.MockRole(**role) - mock_role.colour = discord.Colour(role["colour"]) - mock_role.permissions = discord.Permissions(role["permissions"]) - guild.roles.append(mock_role) - - return guild - - async def test_empty_diff_for_identical_roles(self): - """No differences should be found if the roles in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_roles(self): - """Only updated roles should be added to the 'updated' set of the diff.""" - updated_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] - guild = self.get_guild(updated_role, fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_Role(**updated_role)}, set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_roles(self): - """Only new roles should be added to the 'created' set of the diff.""" - new_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role(), new_role) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new_role)}, set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_deleted_roles(self): - """Only deleted roles should be added to the 'deleted' set of the diff.""" - deleted_role = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [fake_role(), deleted_role] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), {_Role(**deleted_role)}) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_deleted_roles(self): - """When roles are added, updated, and removed, all of them are returned properly.""" - new = fake_role(id=41, name="new") - updated = fake_role(id=71, name="updated") - deleted = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [ - fake_role(), - fake_role(id=71, name="updated name"), - deleted, - ] - guild = self.get_guild(fake_role(), new, updated) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) - - self.assertEqual(actual_diff, expected_diff) - - -class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync roles.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - async def test_sync_created_roles(self): - """Only POST requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(role_tuples, set(), set()) - await self.syncer._sync(diff) - - calls = [mock.call("bot/roles", json=role) for role in roles] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(roles)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_roles(self): - """Only PUT requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), role_tuples, set()) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_deleted_roles(self): - """Only DELETE requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), set(), role_tuples) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] - self.bot.api_client.delete.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py deleted file mode 100644 index 71f4b134c..000000000 --- a/tests/bot/cogs/sync/test_users.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest -from unittest import mock - -from bot.cogs.backend.sync import UserSyncer, _Diff, _User -from tests import helpers - - -def fake_user(**kwargs): - """Fixture to return a dictionary representing a user with default values set.""" - kwargs.setdefault("id", 43) - kwargs.setdefault("name", "bob the test man") - kwargs.setdefault("discriminator", 1337) - kwargs.setdefault("roles", (666,)) - kwargs.setdefault("in_guild", True) - - return kwargs - - -class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between users in the DB and users in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - @staticmethod - def get_guild(*members): - """Fixture to return a guild object with the given members.""" - guild = helpers.MockGuild() - guild.members = [] - - for member in members: - member = member.copy() - del member["in_guild"] - - mock_member = helpers.MockMember(**member) - mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] - - guild.members.append(mock_member) - - return guild - - async def test_empty_diff_for_no_users(self): - """When no users are given, an empty diff should be returned.""" - guild = self.get_guild() - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_identical_users(self): - """No differences should be found if the users in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_users(self): - """Only updated users should be added to the 'updated' set of the diff.""" - updated_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] - guild = self.get_guild(updated_user, fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**updated_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_users(self): - """Only new users should be added to the 'created' set of the diff.""" - new_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user(), new_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_sets_in_guild_false_for_leaving_users(self): - """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_leaving_users(self): - """When users are added, updated, and removed, all of them are returned properly.""" - new_user = fake_user(id=99, name="new") - updated_user = fake_user(id=55, name="updated") - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] - guild = self.get_guild(fake_user(), new_user, updated_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_db_users_not_in_guild(self): - """When the DB knows a user the guild doesn't, no difference is found.""" - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - -class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync users.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - async def test_sync_created_users(self): - """Only POST requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(user_tuples, set(), None) - await self.syncer._sync(diff) - - calls = [mock.call("bot/users", json=user) for user in users] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(users)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_users(self): - """Only PUT requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(set(), user_tuples, None) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(users)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py deleted file mode 100644 index b00211f47..000000000 --- a/tests/bot/cogs/test_antimalware.py +++ /dev/null @@ -1,165 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.cogs.filters import antimalware -from bot.constants import Channels, STAFF_ROLES -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): - """Test the AntiMalware cog.""" - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = MockBot() - self.bot.filter_list_cache = { - "FILE_FORMAT.True": { - ".first": {}, - ".second": {}, - ".third": {}, - } - } - self.cog = antimalware.AntiMalware(self.bot) - self.message = MockMessage() - self.whitelist = [".first", ".second", ".third"] - - async def test_message_with_allowed_attachment(self): - """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_message_without_attachment(self): - """Messages without attachments should result in no action.""" - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_direct_message_with_attachment(self): - """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.guild = None - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_message_with_illegal_extension_gets_deleted(self): - """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_called_once() - - async def test_message_send_by_staff(self): - """A message send by a member of staff should be ignored.""" - staff_role = MockRole(id=STAFF_ROLES[0]) - self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_python_file_redirect_embed_description(self): - """A message containing a .py file should result in an embed redirecting the user to our paste site""" - attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - - self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - - async def test_txt_file_redirect_embed_description(self): - """A message containing a .txt file should result in the correct embed.""" - attachment = MockAttachment(filename="python.txt") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.TXT_EMBED_DESCRIPTION = Mock() - antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - cmd_channel = self.bot.get_channel(Channels.bot_commands) - - self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) - antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - - async def test_other_disallowed_extension_embed_description(self): - """Test the description for a non .py/.txt disallowed extension.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - meta_channel = self.bot.get_channel(Channels.meta) - - self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( - joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", - meta_channel_mention=meta_channel.mention - ) - - async def test_removing_deleted_message_logs(self): - """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - self.message.delete.assert_called_once() - - async def test_message_with_illegal_attachment_logs(self): - """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - - async def test_get_disallowed_extensions(self): - """The return value should include all non-whitelisted extensions.""" - test_values = ( - ([], []), - (self.whitelist, []), - ([".first"], []), - ([".first", ".disallowed"], [".disallowed"]), - ([".disallowed"], [".disallowed"]), - ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), - ) - - for extensions, expected_disallowed_extensions in test_values: - with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog._get_disallowed_extensions(self.message) - self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.TestCase): - """Tests setup of the `AntiMalware` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py deleted file mode 100644 index 8a3d8d02e..000000000 --- a/tests/bot/cogs/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.cogs.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): - """Tests validation of the antispam cog configuration.""" - - def test_default_antispam_config_is_valid(self): - """The default antispam configuration is valid.""" - validation_errors = antispam.validate_config() - self.assertEqual(validation_errors, {}) - - def test_unknown_rule_returns_error(self): - """Configuring an unknown rule returns an error.""" - self.assertEqual( - antispam.validate_config({'invalid-rule': {}}), - {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} - ) - - def test_missing_keys_returns_error(self): - """Not configuring required keys returns an error.""" - keys = (('interval', 'max'), ('max', 'interval')) - for configured_key, unconfigured_key in keys: - with self.subTest( - configured_key=configured_key, - unconfigured_key=unconfigured_key - ): - config = {'burst': {configured_key: 10}} - error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - - self.assertEqual( - antispam.validate_config(config), - {'burst': error} - ) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py deleted file mode 100644 index 305a2bad9..000000000 --- a/tests/bot/cogs/test_information.py +++ /dev/null @@ -1,584 +0,0 @@ -import asyncio -import textwrap -import unittest -import unittest.mock - -import discord - -from bot import constants -from bot.cogs.info import information -from bot.utils.checks import InWhitelistCheckFailure -from tests import helpers - -COG_PATH = "bot.cogs.information.Information" - - -class InformationCogTests(unittest.TestCase): - """Tests the Information cog.""" - - @classmethod - def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = helpers.MockBot() - - self.cog = information.Information(self.bot) - - self.ctx = helpers.MockContext() - self.ctx.author.roles.append(self.moderator_role) - - def test_roles_command_command(self): - """Test if the `role_info` command correctly returns the `moderator_role`.""" - self.ctx.guild.roles.append(self.moderator_role) - - self.cog.roles_info.can_run = unittest.mock.AsyncMock() - self.cog.roles_info.can_run.return_value = True - - coroutine = self.cog.roles_info.callback(self.cog, self.ctx) - - self.assertIsNone(asyncio.run(coroutine)) - self.ctx.send.assert_called_once() - - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - - self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - - def test_role_info_command(self): - """Tests the `role info` command.""" - dummy_role = helpers.MockRole( - name="Dummy", - id=112233445566778899, - colour=discord.Colour.blurple(), - position=10, - members=[self.ctx.author], - permissions=discord.Permissions(0) - ) - - admin_role = helpers.MockRole( - name="Admins", - id=998877665544332211, - colour=discord.Colour.red(), - position=3, - members=[self.ctx.author], - permissions=discord.Permissions(0), - ) - - self.ctx.guild.roles.append([dummy_role, admin_role]) - - self.cog.role_info.can_run = unittest.mock.AsyncMock() - self.cog.role_info.can_run.return_value = True - - coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) - - self.assertIsNone(asyncio.run(coroutine)) - - self.assertEqual(self.ctx.send.call_count, 2) - - (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list - - dummy_embed = dummy_kwargs["embed"] - admin_embed = admin_kwargs["embed"] - - self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) - - self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) - self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") - self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") - self.assertEqual(dummy_embed.fields[3].value, "1") - self.assertEqual(dummy_embed.fields[4].value, "10") - self.assertEqual(dummy_embed.fields[5].value, "0") - - self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, discord.Colour.red()) - - @unittest.mock.patch('bot.cogs.information.time_since') - def test_server_info_command(self, time_since_patch): - time_since_patch.return_value = '2 days ago' - - self.ctx.guild = helpers.MockGuild( - features=('lemons', 'apples'), - region="The Moon", - roles=[self.moderator_role], - channels=[ - discord.TextChannel( - state={}, - guild=self.ctx.guild, - data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} - ), - discord.CategoryChannel( - state={}, - guild=self.ctx.guild, - data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} - ), - discord.VoiceChannel( - state={}, - guild=self.ctx.guild, - data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} - ) - ], - members=[ - *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), - *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), - *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), - *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), - ], - member_count=1_234, - icon_url='a-lemon.jpg', - ) - - coroutine = self.cog.server_info.callback(self.cog, self.ctx) - self.assertIsNone(asyncio.run(coroutine)) - - time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual( - embed.description, - textwrap.dedent( - f""" - **Server information** - Created: {time_since_patch.return_value} - Voice region: {self.ctx.guild.region} - Features: {', '.join(self.ctx.guild.features)} - - **Channel counts** - Category channels: 1 - Text channels: 1 - Voice channels: 1 - Staff channels: 0 - - **Member counts** - Members: {self.ctx.guild.member_count:,} - Staff members: 0 - Roles: {len(self.ctx.guild.roles)} - - **Member statuses** - {constants.Emojis.status_online} 2 - {constants.Emojis.status_idle} 1 - {constants.Emojis.status_dnd} 4 - {constants.Emojis.status_offline} 3 - """ - ) - ) - self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') - - -class UserInfractionHelperMethodTests(unittest.TestCase): - """Tests for the helper methods of the `!user` command.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - self.member = helpers.MockMember(id=1234) - - def test_user_command_helper_method_get_requests(self): - """The helper methods should form the correct get requests.""" - test_values = ( - { - "helper_method": self.cog.basic_user_infraction_counts, - "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.expanded_user_infraction_counts, - "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.user_nomination_counts, - "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), - }, - ) - - for test_value in test_values: - helper_method = test_value["helper_method"] - endpoint, params = test_value["expected_args"] - - with self.subTest(method=helper_method, endpoint=endpoint, params=params): - asyncio.run(helper_method(self.member)) - self.bot.api_client.get.assert_called_once_with(endpoint, params=params) - self.bot.api_client.get.reset_mock() - - def _method_subtests(self, method, test_values, default_header): - """Helper method that runs the subtests for the different helper methods.""" - for test_value in test_values: - api_response = test_value["api response"] - expected_lines = test_value["expected_lines"] - - with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): - self.bot.api_client.get.return_value = api_response - - expected_output = "\n".join(default_header + expected_lines) - actual_output = asyncio.run(method(self.member)) - - self.assertEqual(expected_output, actual_output) - - def test_basic_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list both the total and active number of non-hidden infractions.""" - test_values = ( - # No infractions means zero counts - { - "api response": [], - "expected_lines": ["Total: 0", "Active: 0"], - }, - # Simple, single-infraction dictionaries - { - "api response": [{"type": "ban", "active": True}], - "expected_lines": ["Total: 1", "Active: 1"], - }, - { - "api response": [{"type": "ban", "active": False}], - "expected_lines": ["Total: 1", "Active: 0"], - }, - # Multiple infractions with various `active` status - { - "api response": [ - {"type": "ban", "active": True}, - {"type": "kick", "active": False}, - {"type": "ban", "active": True}, - {"type": "ban", "active": False}, - ], - "expected_lines": ["Total: 4", "Active: 2"], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) - - def test_expanded_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list the total and active number of all infractions split by infraction type.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never received an infraction."], - }, - # Shows non-hidden inactive infraction as expected - { - "api response": [{"type": "kick", "active": False, "hidden": False}], - "expected_lines": ["Kicks: 1"], - }, - # Shows non-hidden active infraction as expected - { - "api response": [{"type": "mute", "active": True, "hidden": False}], - "expected_lines": ["Mutes: 1 (1 active)"], - }, - # Shows hidden inactive infraction as expected - { - "api response": [{"type": "superstar", "active": False, "hidden": True}], - "expected_lines": ["Superstars: 1"], - }, - # Shows hidden active infraction as expected - { - "api response": [{"type": "ban", "active": True, "hidden": True}], - "expected_lines": ["Bans: 1 (1 active)"], - }, - # Correctly displays tally of multiple infractions of mixed properties in alphabetical order - { - "api response": [ - {"type": "kick", "active": False, "hidden": True}, - {"type": "ban", "active": True, "hidden": True}, - {"type": "superstar", "active": True, "hidden": True}, - {"type": "mute", "active": True, "hidden": True}, - {"type": "ban", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - {"type": "note", "active": False, "hidden": True}, - {"type": "warn", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - ], - "expected_lines": [ - "Bans: 2 (1 active)", - "Kicks: 1", - "Mutes: 1 (1 active)", - "Notes: 3", - "Superstars: 1 (1 active)", - "Warns: 1", - ], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) - - def test_user_nomination_counts_returns_correct_strings(self): - """The method should list the number of active and historical nominations for the user.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never been nominated."], - }, - { - "api response": [{'active': True}], - "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], - }, - { - "api response": [{'active': True}, {'active': False}], - "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], - }, - { - "api response": [{'active': False}], - "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], - }, - { - "api response": [{'active': False}, {'active': False}], - "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], - }, - - ) - - header = ["**Nominations**"] - - self._method_subtests(self.cog.user_nomination_counts, test_values, header) - - -@unittest.mock.patch("bot.cogs.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) -@unittest.mock.patch("bot.cogs.information.constants.MODERATION_CHANNELS", new=[50]) -class UserEmbedTests(unittest.TestCase): - """Tests for the creation of the `!user` embed.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): - """The embed should use the string representation of the user if they don't have a nick.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = None - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Mr. Hemlock") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_nick_in_title_if_available(self): - """The embed should use the nick if it's available.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = "Cat lover" - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_ignores_everyone_role(self): - """Created `!user` embeds should not contain mention of the @everyone-role.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - admins_role = helpers.MockRole(name='Admins') - admins_role.colour = 100 - - # A `MockMember` has the @Everyone role by default; we add the Admins to that. - user = helpers.MockMember(roles=[admins_role], top_role=admins_role) - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertIn("&Admins", embed.description) - self.assertNotIn("&Everyone", embed.description) - - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): - """The embed should contain expanded infractions and nomination info in mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "expanded infractions info" - nomination_counts.return_value = "nomination info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - nomination_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - expanded infractions info - - nomination info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): - """The embed should contain only basic infraction data outside of mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "basic infractions info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - basic infractions info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): - """The embed should be created with the colour of the top role, if a top role is available.""" - ctx = helpers.MockContext() - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): - """The embed should be created with a blurple colour if the user has no assigned roles.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour.blurple()) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): - """The embed thumbnail should be set to the user's avatar in `png` format.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - user.avatar_url_as.return_value = "avatar url" - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - user.avatar_url_as.assert_called_once_with(static_format="png") - self.assertEqual(embed.thumbnail.url, "avatar url") - - -@unittest.mock.patch("bot.cogs.information.constants") -class UserCommandTests(unittest.TestCase): - """Tests for the `!user` command.""" - - def setUp(self): - """Set up steps executed before each test is run.""" - self.bot = helpers.MockBot() - self.cog = information.Information(self.bot) - - self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) - self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) - self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) - - self.author = helpers.MockMember(id=1, name="syntaxaire") - self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) - self.target = helpers.MockMember(id=3, name="__fluzz__") - - def test_regular_member_cannot_target_another_member(self, constants): - """A regular user should not be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.author) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") - - def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): - """A regular user should not be able to use this command outside of bot-commands.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) - - msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InWhitelistCheckFailure, msg=msg): - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): - """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): - """A user should target itself with `!user` when a `user` argument was not provided.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): - """Staff members should be able to bypass the bot-commands channel restriction.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.moderator) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_moderators_can_target_another_member(self, create_embed, constants): - """A moderator should be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - create_embed.assert_called_once_with(ctx, self.target) - ctx.send.assert_called_once() diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py deleted file mode 100644 index b4ad8535f..000000000 --- a/tests/bot/cogs/test_jams.py +++ /dev/null @@ -1,173 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, create_autospec - -from discord import CategoryChannel - -from bot.cogs import jams -from bot.constants import Roles -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel - - -def get_mock_category(channel_count: int, name: str) -> CategoryChannel: - """Return a mocked code jam category.""" - category = create_autospec(CategoryChannel, spec_set=True, instance=True) - category.name = name - category.channels = [MockTextChannel() for _ in range(channel_count)] - - return category - - -class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): - """Tests for `createteam` command.""" - - def setUp(self): - self.bot = MockBot() - self.admin_role = MockRole(name="Admins", id=Roles.admins) - self.command_user = MockMember([self.admin_role]) - self.guild = MockGuild([self.admin_role]) - self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) - self.cog = jams.CodeJams(self.bot) - - async def test_too_small_amount_of_team_members_passed(self): - """Should `ctx.send` and exit early when too small amount of members.""" - for case in (1, 2): - with self.subTest(amount_of_members=case): - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - self.ctx.reset_mock() - members = (MockMember() for _ in range(case)) - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_duplicate_members_provided(self): - """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - member = MockMember() - await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_result_sending(self): - """Should call `ctx.send` when everything goes right.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - members = [MockMember() for _ in range(5)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.cog.create_channels.assert_awaited_once() - self.cog.add_roles.assert_awaited_once() - self.ctx.send.assert_awaited_once() - - async def test_category_doesnt_exist(self): - """Should create a new code jam category.""" - subtests = ( - [], - [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], - [get_mock_category(jams.MAX_CHANNELS - 2, "other")], - ) - - for categories in subtests: - self.guild.reset_mock() - self.guild.categories = categories - - with self.subTest(categories=categories): - actual_category = await self.cog.get_category(self.guild) - - self.guild.create_category_channel.assert_awaited_once() - category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - - self.assertFalse(category_overwrites[self.guild.default_role].read_messages) - self.assertTrue(category_overwrites[self.guild.me].read_messages) - self.assertEqual(self.guild.create_category_channel.return_value, actual_category) - - async def test_category_channel_exist(self): - """Should not try to create category channel.""" - expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) - self.guild.categories = [ - get_mock_category(jams.MAX_CHANNELS - 2, "other"), - expected_category, - get_mock_category(0, jams.CATEGORY_NAME), - ] - - actual_category = await self.cog.get_category(self.guild) - self.assertEqual(expected_category, actual_category) - - async def test_channel_overwrites(self): - """Should have correct permission overwrites for users and roles.""" - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - overwrites = self.cog.get_overwrites(members, self.guild) - - # Leader permission overwrites - self.assertTrue(overwrites[leader].manage_messages) - self.assertTrue(overwrites[leader].read_messages) - self.assertTrue(overwrites[leader].manage_webhooks) - self.assertTrue(overwrites[leader].connect) - - # Other members permission overwrites - for member in members[1:]: - self.assertTrue(overwrites[member].read_messages) - self.assertTrue(overwrites[member].connect) - - # Everyone and verified role overwrite - self.assertFalse(overwrites[self.guild.default_role].read_messages) - self.assertFalse(overwrites[self.guild.default_role].connect) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) - - async def test_team_channels_creation(self): - """Should create new voice and text channel for team.""" - members = [MockMember() for _ in range(5)] - - self.cog.get_overwrites = MagicMock() - self.cog.get_category = AsyncMock() - self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") - actual = await self.cog.create_channels(self.guild, "my-team", members) - - self.assertEqual("foobar-channel", actual) - self.cog.get_overwrites.assert_called_once_with(members, self.guild) - self.cog.get_category.assert_awaited_once_with(self.guild) - - self.guild.create_text_channel.assert_awaited_once_with( - "my-team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - self.guild.create_voice_channel.assert_awaited_once_with( - "My Team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - - async def test_jam_roles_adding(self): - """Should add team leader role to leader and jam role to every team member.""" - leader_role = MockRole(name="Team Leader") - jam_role = MockRole(name="Jammer") - self.guild.get_role.side_effect = [leader_role, jam_role] - - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - await self.cog.add_roles(self.guild, members) - - leader.add_roles.assert_any_await(leader_role) - for member in members: - member.add_roles.assert_any_await(jam_role) - - -class CodeJamSetup(unittest.TestCase): - """Test for `setup` function of `CodeJam` cog.""" - - def test_setup(self): - """Should call `bot.add_cog`.""" - bot = MockBot() - jams.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_logging.py b/tests/bot/cogs/test_logging.py deleted file mode 100644 index 8a18fdcd6..000000000 --- a/tests/bot/cogs/test_logging.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot import constants -from bot.cogs.logging import Logging -from tests.helpers import MockBot, MockTextChannel - - -class LoggingTests(unittest.IsolatedAsyncioTestCase): - """Test cases for connected login.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Logging(self.bot) - self.dev_log = MockTextChannel(id=1234, name="dev-log") - - @patch("bot.cogs.logging.DEBUG_MODE", False) - async def test_debug_mode_false(self): - """Should send connected message to dev-log.""" - self.bot.get_channel.return_value = self.dev_log - - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) - self.dev_log.send.assert_awaited_once() - - @patch("bot.cogs.logging.DEBUG_MODE", True) - async def test_debug_mode_true(self): - """Should not send anything to dev-log.""" - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_not_called() diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py deleted file mode 100644 index 82679f69c..000000000 --- a/tests/bot/cogs/test_security.py +++ /dev/null @@ -1,54 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from discord.ext.commands import NoPrivateMessage - -from bot.cogs.filters import security -from tests.helpers import MockBot, MockContext - - -class SecurityCogTests(unittest.TestCase): - """Tests the `Security` cog.""" - - def setUp(self): - """Attach an instance of the cog to the class for tests.""" - self.bot = MockBot() - self.cog = security.Security(self.bot) - self.ctx = MockContext() - - def test_check_additions(self): - """The cog should add its checks after initialization.""" - self.bot.check.assert_any_call(self.cog.check_on_guild) - self.bot.check.assert_any_call(self.cog.check_not_bot) - - def test_check_not_bot_returns_false_for_humans(self): - """The bot check should return `True` when invoked with human authors.""" - self.ctx.author.bot = False - self.assertTrue(self.cog.check_not_bot(self.ctx)) - - def test_check_not_bot_returns_true_for_robots(self): - """The bot check should return `False` when invoked with robotic authors.""" - self.ctx.author.bot = True - self.assertFalse(self.cog.check_not_bot(self.ctx)) - - def test_check_on_guild_raises_when_outside_of_guild(self): - """When invoked outside of a guild, `check_on_guild` should cause an error.""" - self.ctx.guild = None - - with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): - self.cog.check_on_guild(self.ctx) - - def test_check_on_guild_returns_true_inside_of_guild(self): - """When invoked inside of a guild, `check_on_guild` should return `True`.""" - self.ctx.guild = "lemon's lemonade stand" - self.assertTrue(self.cog.check_on_guild(self.ctx)) - - -class SecurityCogLoadTests(unittest.TestCase): - """Tests loading the `Security` cog.""" - - def test_security_cog_load(self): - """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py deleted file mode 100644 index f442814c8..000000000 --- a/tests/bot/cogs/test_slowmode.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest -from unittest import mock - -from dateutil.relativedelta import relativedelta - -from bot.cogs.moderation.slowmode import Slowmode -from bot.constants import Emojis -from tests.helpers import MockBot, MockContext, MockTextChannel - - -class SlowmodeTests(unittest.IsolatedAsyncioTestCase): - - def setUp(self) -> None: - self.bot = MockBot() - self.cog = Slowmode(self.bot) - self.ctx = MockContext() - - async def test_get_slowmode_no_channel(self) -> None: - """Get slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) - - await self.cog.get_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") - - async def test_get_slowmode_with_channel(self) -> None: - """Get slowmode with a given channel.""" - text_channel = MockTextChannel(name='python-language', slowmode_delay=2) - - await self.cog.get_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') - - async def test_set_slowmode_no_channel(self) -> None: - """Set slowmode without a given channel.""" - test_cases = ( - ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), - ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), - ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - self.ctx.channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) - - if edited: - self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - self.ctx.channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_set_slowmode_with_channel(self) -> None: - """Set slowmode with a given channel.""" - test_cases = ( - ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), - ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), - ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - text_channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) - - if edited: - text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - text_channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_reset_slowmode_no_channel(self) -> None: - """Reset slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) - - await self.cog.reset_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' - ) - - async def test_reset_slowmode_with_channel(self) -> None: - """Reset slowmode with a given channel.""" - text_channel = MockTextChannel(name='meta', slowmode_delay=1) - - await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' - ) - - @mock.patch("bot.cogs.moderation.slowmode.with_role_check") - @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py deleted file mode 100644 index c7bac3ab3..000000000 --- a/tests/bot/cogs/test_snekbox.py +++ /dev/null @@ -1,409 +0,0 @@ -import asyncio -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch - -from discord.ext import commands - -from bot import constants -from bot.cogs.utils import snekbox -from bot.cogs.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser - - -class SnekboxTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Add mocked bot and cog to the instance.""" - self.bot = MockBot() - self.cog = Snekbox(bot=self.bot) - - async def test_post_eval(self): - """Post the eval code to the URLs.snekbox_eval_api endpoint.""" - resp = MagicMock() - resp.json = AsyncMock(return_value="return") - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual(await self.cog.post_eval("import random"), "return") - self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json={"input": "import random"}, - raise_for_status=True - ) - resp.json.assert_awaited_once() - - async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) - self.assertEqual(result, "too long to upload") - - async def test_upload_output(self): - """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True - ) - - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - log = logging.getLogger("bot.cogs.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - - def test_prepare_input(self): - cases = ( - ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), - ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), - ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), - ) - for case, expected, testname in cases: - with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) - - def test_get_results_message(self): - """Return error and message according to the eval result.""" - cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_signals: Mock): - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127', '') - ) - - @patch('bot.cogs.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = 'SIGTEST' - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127 (SIGTEST)', '') - ) - - def test_get_status_emoji(self): - """Return emoji according to the eval result.""" - cases = ( - (' ', -1, ':warning:'), - ('Hello world!', 0, ':white_check_mark:'), - ('Invalid beard size', -1, ':x:') - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - async def test_format_output(self): - """Test output formatting.""" - self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') - - too_many_lines = ( - '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' - '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' - ) - too_long_too_many_lines = ( - "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" - ) - - cases = ( - ('', ('[No output]', None), 'No output'), - ('My awesome output', ('My awesome output', None), 'One line output'), - ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), - (' CategoryChannel: + """Return a mocked code jam category.""" + category = create_autospec(CategoryChannel, spec_set=True, instance=True) + category.name = name + category.channels = [MockTextChannel() for _ in range(channel_count)] + + return category + + +class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): + """Tests for `createteam` command.""" + + def setUp(self): + self.bot = MockBot() + self.admin_role = MockRole(name="Admins", id=Roles.admins) + self.command_user = MockMember([self.admin_role]) + self.guild = MockGuild([self.admin_role]) + self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) + self.cog = jams.CodeJams(self.bot) + + async def test_too_small_amount_of_team_members_passed(self): + """Should `ctx.send` and exit early when too small amount of members.""" + for case in (1, 2): + with self.subTest(amount_of_members=case): + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + self.ctx.reset_mock() + members = (MockMember() for _ in range(case)) + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_duplicate_members_provided(self): + """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + member = MockMember() + await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_result_sending(self): + """Should call `ctx.send` when everything goes right.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + members = [MockMember() for _ in range(5)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.cog.create_channels.assert_awaited_once() + self.cog.add_roles.assert_awaited_once() + self.ctx.send.assert_awaited_once() + + async def test_category_doesnt_exist(self): + """Should create a new code jam category.""" + subtests = ( + [], + [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], + [get_mock_category(jams.MAX_CHANNELS - 2, "other")], + ) + + for categories in subtests: + self.guild.reset_mock() + self.guild.categories = categories + + with self.subTest(categories=categories): + actual_category = await self.cog.get_category(self.guild) + + self.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + + self.assertFalse(category_overwrites[self.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.guild.me].read_messages) + self.assertEqual(self.guild.create_category_channel.return_value, actual_category) + + async def test_category_channel_exist(self): + """Should not try to create category channel.""" + expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) + self.guild.categories = [ + get_mock_category(jams.MAX_CHANNELS - 2, "other"), + expected_category, + get_mock_category(0, jams.CATEGORY_NAME), + ] + + actual_category = await self.cog.get_category(self.guild) + self.assertEqual(expected_category, actual_category) + + async def test_channel_overwrites(self): + """Should have correct permission overwrites for users and roles.""" + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + overwrites = self.cog.get_overwrites(members, self.guild) + + # Leader permission overwrites + self.assertTrue(overwrites[leader].manage_messages) + self.assertTrue(overwrites[leader].read_messages) + self.assertTrue(overwrites[leader].manage_webhooks) + self.assertTrue(overwrites[leader].connect) + + # Other members permission overwrites + for member in members[1:]: + self.assertTrue(overwrites[member].read_messages) + self.assertTrue(overwrites[member].connect) + + # Everyone and verified role overwrite + self.assertFalse(overwrites[self.guild.default_role].read_messages) + self.assertFalse(overwrites[self.guild.default_role].connect) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) + + async def test_team_channels_creation(self): + """Should create new voice and text channel for team.""" + members = [MockMember() for _ in range(5)] + + self.cog.get_overwrites = MagicMock() + self.cog.get_category = AsyncMock() + self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") + actual = await self.cog.create_channels(self.guild, "my-team", members) + + self.assertEqual("foobar-channel", actual) + self.cog.get_overwrites.assert_called_once_with(members, self.guild) + self.cog.get_category.assert_awaited_once_with(self.guild) + + self.guild.create_text_channel.assert_awaited_once_with( + "my-team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + self.guild.create_voice_channel.assert_awaited_once_with( + "My Team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + + async def test_jam_roles_adding(self): + """Should add team leader role to leader and jam role to every team member.""" + leader_role = MockRole(name="Team Leader") + jam_role = MockRole(name="Jammer") + self.guild.get_role.side_effect = [leader_role, jam_role] + + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + await self.cog.add_roles(self.guild, members) + + leader.add_roles.assert_any_await(leader_role) + for member in members: + member.add_roles.assert_any_await(jam_role) + + +class CodeJamSetup(unittest.TestCase): + """Test for `setup` function of `CodeJam` cog.""" + + def test_setup(self): + """Should call `bot.add_cog`.""" + bot = MockBot() + jams.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/utils/test_snekbox.py b/tests/bot/cogs/utils/test_snekbox.py new file mode 100644 index 000000000..3e447f319 --- /dev/null +++ b/tests/bot/cogs/utils/test_snekbox.py @@ -0,0 +1,409 @@ +import asyncio +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch + +from discord.ext import commands + +from bot import constants +from bot.cogs.utils import snekbox +from bot.cogs.utils.snekbox import Snekbox +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + + async def test_post_eval(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( + constants.URLs.snekbox_eval_api, + json={"input": "import random"}, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LEN.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + self.assertEqual(result, "too long to upload") + + async def test_upload_output(self): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual( + await self.cog.upload_output("My awesome output"), + constants.URLs.paste_service.format(key=key) + ) + self.bot.http_session.post.assert_called_with( + constants.URLs.paste_service.format(key="documents"), + data="My awesome output", + raise_for_status=True + ) + + async def test_upload_output_gracefully_fallback_if_exception_during_request(self): + """Output upload gracefully fallback if the upload fail.""" + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + log = logging.getLogger("bot.cogs.utils.snekbox") + with self.assertLogs(logger=log, level='ERROR'): + await self.cog.upload_output('My awesome output!') + + async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): + """Output upload gracefully fallback if there is no key entry in the response body.""" + self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + + def test_prepare_input(self): + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual(self.cog.prepare_input(case), expected) + + def test_get_results_message(self): + """Return error and message according to the eval result.""" + cases = ( + ('ERROR', None, ('Your eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), + ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + @patch('bot.cogs.utils.snekbox.Signals', side_effect=ValueError) + def test_get_results_message_invalid_signal(self, mock_signals: Mock): + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127', '') + ) + + @patch('bot.cogs.utils.snekbox.Signals') + def test_get_results_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = 'SIGTEST' + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127 (SIGTEST)', '') + ) + + def test_get_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... (truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + (' Date: Wed, 12 Aug 2020 22:31:08 -0700 Subject: Prefix names of non-extension modules with _ This naming scheme will make them easy to distinguish from extensions. --- bot/cogs/backend/sync/__init__.py | 2 +- bot/cogs/backend/sync/_cog.py | 180 ++++++++ bot/cogs/backend/sync/_syncers.py | 347 +++++++++++++++ bot/cogs/backend/sync/cog.py | 180 -------- bot/cogs/backend/sync/syncers.py | 347 --------------- bot/cogs/moderation/__init__.py | 19 - bot/cogs/moderation/incidents.py | 5 + bot/cogs/moderation/infraction/_scheduler.py | 463 +++++++++++++++++++++ bot/cogs/moderation/infraction/_utils.py | 201 +++++++++ bot/cogs/moderation/infraction/infractions.py | 31 +- bot/cogs/moderation/infraction/management.py | 11 +- bot/cogs/moderation/infraction/scheduler.py | 463 --------------------- bot/cogs/moderation/infraction/superstarify.py | 29 +- bot/cogs/moderation/infraction/utils.py | 201 --------- bot/cogs/moderation/modlog.py | 5 + bot/cogs/moderation/silence.py | 5 + bot/cogs/moderation/watchchannels/__init__.py | 9 - bot/cogs/moderation/watchchannels/_watchchannel.py | 348 ++++++++++++++++ bot/cogs/moderation/watchchannels/bigbrother.py | 9 +- bot/cogs/moderation/watchchannels/talentpool.py | 7 +- bot/cogs/moderation/watchchannels/watchchannel.py | 348 ---------------- tests/bot/cogs/backend/sync/test_base.py | 2 +- tests/bot/cogs/backend/sync/test_cog.py | 15 +- tests/bot/cogs/backend/sync/test_roles.py | 2 +- tests/bot/cogs/backend/sync/test_users.py | 2 +- .../cogs/moderation/infraction/test_infractions.py | 6 +- 26 files changed, 1625 insertions(+), 1612 deletions(-) create mode 100644 bot/cogs/backend/sync/_cog.py create mode 100644 bot/cogs/backend/sync/_syncers.py delete mode 100644 bot/cogs/backend/sync/cog.py delete mode 100644 bot/cogs/backend/sync/syncers.py create mode 100644 bot/cogs/moderation/infraction/_scheduler.py create mode 100644 bot/cogs/moderation/infraction/_utils.py delete mode 100644 bot/cogs/moderation/infraction/scheduler.py delete mode 100644 bot/cogs/moderation/infraction/utils.py create mode 100644 bot/cogs/moderation/watchchannels/_watchchannel.py delete mode 100644 bot/cogs/moderation/watchchannels/watchchannel.py (limited to 'tests') diff --git a/bot/cogs/backend/sync/__init__.py b/bot/cogs/backend/sync/__init__.py index fe7df4e9b..fb640a1cf 100644 --- a/bot/cogs/backend/sync/__init__.py +++ b/bot/cogs/backend/sync/__init__.py @@ -1,5 +1,5 @@ from bot.bot import Bot -from .cog import Sync +from ._cog import Sync def setup(bot: Bot) -> None: diff --git a/bot/cogs/backend/sync/_cog.py b/bot/cogs/backend/sync/_cog.py new file mode 100644 index 000000000..b6068f328 --- /dev/null +++ b/bot/cogs/backend/sync/_cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from . import _syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = _syncers.RoleSyncer(self.bot) + self.user_syncer = _syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/_syncers.py b/bot/cogs/backend/sync/_syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/cogs/backend/sync/_syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/backend/sync/cog.py b/bot/cogs/backend/sync/cog.py deleted file mode 100644 index 274845a50..000000000 --- a/bot/cogs/backend/sync/cog.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from typing import Any, Dict - -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from . import syncers - -log = logging.getLogger(__name__) - - -class Sync(Cog): - """Captures relevant events and sends them to the site.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.role_syncer = syncers.RoleSyncer(self.bot) - self.user_syncer = syncers.UserSyncer(self.bot) - - self.bot.loop.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(constants.Guild.id) - if guild is None: - return - - for syncer in (self.role_syncer, self.user_syncer): - await syncer.sync(guild) - - async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: - """Send a PATCH request to partially update a user in the database.""" - try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) - except ResponseCodeError as e: - if e.response.status != 404: - raise - if not ignore_404: - log.warning("Unable to update user, got 404. Assuming race condition from join event.") - - @Cog.listener() - async def on_guild_role_create(self, role: Role) -> None: - """Adds newly create role to the database table over the API.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.post( - 'bot/roles', - json={ - 'colour': role.colour.value, - 'id': role.id, - 'name': role.name, - 'permissions': role.permissions.value, - 'position': role.position, - } - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: Role) -> None: - """Deletes role from the database when it's deleted from the guild.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - @Cog.listener() - async def on_guild_role_update(self, before: Role, after: Role) -> None: - """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild.id != constants.Guild.id: - return - - was_updated = ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ) - - if was_updated: - await self.bot.api_client.put( - f'bot/roles/{after.id}', - json={ - 'colour': after.colour.value, - 'id': after.id, - 'name': after.name, - 'permissions': after.permissions.value, - 'position': after.position, - } - ) - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """ - Adds a new user or updates existing user to the database when a member joins the guild. - - If the joining member is a user that is already known to the database (i.e., a user that - previously left), it will update the user's information. If the user is not yet known by - the database, the user is added. - """ - if member.guild.id != constants.Guild.id: - return - - packed = { - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': True, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - - got_error = False - - try: - # First try an update of the user to set the `in_guild` field and other - # fields that may have changed since the last time we've seen them. - await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) - - except ResponseCodeError as e: - # If we didn't get 404, something else broke - propagate it up. - if e.response.status != 404: - raise - - got_error = True # yikes - - if got_error: - # If we got `404`, the user is new. Create them. - await self.bot.api_client.post('bot/users', json=packed) - - @Cog.listener() - async def on_member_remove(self, member: Member) -> None: - """Set the in_guild field to False when a member leaves the guild.""" - if member.guild.id != constants.Guild.id: - return - - await self.patch_user(member.id, json={"in_guild": False}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Update the roles of the member in the database if a change is detected.""" - if after.guild.id != constants.Guild.id: - return - - if before.roles != after.roles: - updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, json=updated_information) - - @Cog.listener() - async def on_user_update(self, before: User, after: User) -> None: - """Update the user information in the database if a relevant change is detected.""" - attrs = ("name", "discriminator") - if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): - updated_information = { - "name": after.name, - "discriminator": int(after.discriminator), - } - # A 404 likely means the user is in another guild. - await self.patch_user(after.id, json=updated_information, ignore_404=True) - - @commands.group(name='sync') - @commands.has_permissions(administrator=True) - async def sync_group(self, ctx: Context) -> None: - """Run synchronizations between the bot and site manually.""" - - @sync_group.command(name='roles') - @commands.has_permissions(administrator=True) - async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) - - @sync_group.command(name='users') - @commands.has_permissions(administrator=True) - async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/syncers.py b/bot/cogs/backend/sync/syncers.py deleted file mode 100644 index f7ba811bc..000000000 --- a/bot/cogs/backend/sync/syncers.py +++ /dev/null @@ -1,347 +0,0 @@ -import abc -import asyncio -import logging -import typing as t -from collections import namedtuple -from functools import partial - -import discord -from discord import Guild, HTTPException, Member, Message, Reaction, User -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# These objects are declared as namedtuples because tuples are hashable, -# something that we make use of when diffing site roles against guild roles. -_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) -_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - - -class Syncer(abc.ABC): - """Base class for synchronising the database with objects in the Discord cache.""" - - _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " - _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @property - @abc.abstractmethod - def name(self) -> str: - """The name of the syncer; used in output messages and logging.""" - raise NotImplementedError # pragma: no cover - - async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: - """ - Send a prompt to confirm or abort a sync using reactions and return the sent message. - - If a message is given, it is edited to display the prompt and reactions. Otherwise, a new - message is sent to the dev-core channel and mentions the core developers role. If the - channel cannot be retrieved, return None. - """ - log.trace(f"Sending {self.name} sync confirmation prompt.") - - msg_content = ( - f'Possible cache issue while syncing {self.name}s. ' - f'More than {constants.Sync.max_diff} {self.name}s were changed. ' - f'React to confirm or abort the sync.' - ) - - # Send to core developers if it's an automatic sync. - if not message: - log.trace("Message not provided for confirmation; creating a new one in dev-core.") - channel = self.bot.get_channel(constants.Channels.dev_core) - - if not channel: - log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") - try: - channel = await self.bot.fetch_channel(constants.Channels.dev_core) - except HTTPException: - log.exception( - f"Failed to fetch channel for sending sync confirmation prompt; " - f"aborting {self.name} sync." - ) - return None - - allowed_roles = [discord.Object(constants.Roles.core_developers)] - message = await channel.send( - f"{self._CORE_DEV_MENTION}{msg_content}", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - else: - await message.edit(content=msg_content) - - # Add the initial reactions. - log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") - for emoji in self._REACTION_EMOJIS: - await message.add_reaction(emoji) - - return message - - def _reaction_check( - self, - author: Member, - message: Message, - reaction: Reaction, - user: t.Union[Member, User] - ) -> bool: - """ - Return True if the `reaction` is a valid confirmation or abort reaction on `message`. - - If the `author` of the prompt is a bot, then a reaction by any core developer will be - considered valid. Otherwise, the author of the reaction (`user`) will have to be the - `author` of the prompt. - """ - # For automatic syncs, check for the core dev role instead of an exact author - has_role = any(constants.Roles.core_developers == role.id for role in user.roles) - return ( - reaction.message.id == message.id - and not user.bot - and (has_role if author.bot else user == author) - and str(reaction.emoji) in self._REACTION_EMOJIS - ) - - async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: - """ - Wait for a confirmation reaction by `author` on `message` and return True if confirmed. - - Uses the `_reaction_check` function to determine if a reaction is valid. - - If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. - To acknowledge the reaction (or lack thereof), `message` will be edited. - """ - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - reaction = None - try: - log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") - reaction, _ = await self.bot.wait_for( - 'reaction_add', - check=partial(self._reaction_check, author, message), - timeout=constants.Sync.confirm_timeout - ) - except asyncio.TimeoutError: - # reaction will remain none thus sync will be aborted in the finally block below. - log.debug(f"The {self.name} syncer confirmation prompt timed out.") - - if str(reaction) == constants.Emojis.check_mark: - log.trace(f"The {self.name} syncer was confirmed.") - await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') - return True - else: - log.info(f"The {self.name} syncer was aborted or timed out!") - await message.edit( - content=f':warning: {mention}{self.name} sync aborted or timed out!' - ) - return False - - @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference between the cache of `guild` and the database.""" - raise NotImplementedError # pragma: no cover - - @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: - """Perform the API calls for synchronisation.""" - raise NotImplementedError # pragma: no cover - - async def _get_confirmation_result( - self, - diff_size: int, - author: Member, - message: t.Optional[Message] = None - ) -> t.Tuple[bool, t.Optional[Message]]: - """ - Prompt for confirmation and return a tuple of the result and the prompt message. - - `diff_size` is the size of the diff of the sync. If it is greater than - `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the - sync and the `message` is an extant message to edit to display the prompt. - - If confirmed or no confirmation was needed, the result is True. The returned message will - either be the given `message` or a new one which was created when sending the prompt. - """ - log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") - if diff_size > constants.Sync.max_diff: - message = await self._send_prompt(message) - if not message: - return False, None # Couldn't get channel. - - confirmed = await self._wait_for_confirmation(author, message) - if not confirmed: - return False, message # Sync aborted. - - return True, message - - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If the differences between the cache and the database are greater than - `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core - channel. The confirmation can be optionally redirect to `ctx` instead. - """ - log.info(f"Starting {self.name} syncer.") - - message = None - author = self.bot.user - if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") - author = ctx.author - - diff = await self._get_diff(guild) - diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict - totals = {k: len(v) for k, v in diff_dict.items() if v is not None} - diff_size = sum(totals.values()) - - confirmed, message = await self._get_confirmation_result(diff_size, author, message) - if not confirmed: - return - - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - try: - await self._sync(diff) - except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") - - # Don't show response text because it's probably some really long HTML. - results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" - else: - results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" - - if message: - await message.edit(content=content) - - -class RoleSyncer(Syncer): - """Synchronise the database with roles in the cache.""" - - name = "role" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of roles between the cache of `guild` and the database.""" - log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_roles = {_Role(**role_dict) for role_dict in roles} - guild_roles = { - _Role( - id=role.id, - name=role.name, - colour=role.colour.value, - permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in db_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # DB guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - db_roles - roles_to_create - roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} - - return _Diff(roles_to_create, roles_to_update, roles_to_delete) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the role cache of `guild`.""" - log.trace("Syncing created roles...") - for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) - - log.trace("Syncing updated roles...") - for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) - - log.trace("Syncing deleted roles...") - for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - -class UserSyncer(Syncer): - """Synchronise the database with users in the cache.""" - - name = "user" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of users between the cache of `guild` and the database.""" - log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } - - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - if guild_user is not None: - if db_user != guild_user: - users_to_update.add(guild_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return _Diff(users_to_create, users_to_update, None) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the user cache of `guild`.""" - log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index aad1f3c26..e69de29bb 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,19 +0,0 @@ -from bot.bot import Bot -from .incidents import Incidents -from .infraction.infractions import Infractions -from .infraction.management import ModManagement -from .infraction.superstarify import Superstarify -from .modlog import ModLog -from .silence import Silence -from .slowmode import Slowmode - - -def setup(bot: Bot) -> None: - """Load the Incidents, Infractions, ModManagement, ModLog, Silence, Slowmode and Superstarify cogs.""" - bot.add_cog(Incidents(bot)) - bot.add_cog(Infractions(bot)) - bot.add_cog(ModLog(bot)) - bot.add_cog(ModManagement(bot)) - bot.add_cog(Silence(bot)) - bot.add_cog(Slowmode(bot)) - bot.add_cog(Superstarify(bot)) diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 3605ab1d2..e49913552 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -405,3 +405,8 @@ class Incidents(Cog): """Pass `message` to `add_signals` if and only if it satisfies `is_incident`.""" if is_incident(message): await add_signals(message) + + +def setup(bot: Bot) -> None: + """Load the Incidents cog.""" + bot.add_cog(Incidents(bot)) diff --git a/bot/cogs/moderation/infraction/_scheduler.py b/bot/cogs/moderation/infraction/_scheduler.py new file mode 100644 index 000000000..33944a8db --- /dev/null +++ b/bot/cogs/moderation/infraction/_scheduler.py @@ -0,0 +1,463 @@ +import logging +import textwrap +import typing as t +from abc import abstractmethod +from datetime import datetime +from gettext import ngettext + +import dateutil.parser +import discord +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.constants import Colours, STAFF_CHANNELS +from bot.utils import time +from bot.utils.scheduling import Scheduler +from . import _utils +from ._utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class InfractionScheduler: + """Handles the application, pardoning, and expiration of infractions.""" + + def __init__(self, bot: Bot, supported_infractions: t.Container[str]): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + @property + def mod_log(self) -> ModLog: + """Get the currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: + """Schedule expiration for previous infractions.""" + await self.bot.wait_until_guild_available() + + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={'active': 'true'} + ) + for infraction in infractions: + if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: + self.schedule_expiration(infraction) + + async def reapply_infraction( + self, + infraction: _utils.Infraction, + apply_coro: t.Optional[t.Awaitable] + ) -> None: + """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" + # Calculate the time remaining, in seconds, for the mute. + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + delta = (expiry - datetime.utcnow()).total_seconds() + + # Mark as inactive if less than a minute remains. + if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) + await self.deactivate_infraction(infraction) + return + + # Allowing mod log since this is a passive action that should be logged. + await apply_coro + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + + async def apply_infraction( + self, + ctx: Context, + infraction: _utils.Infraction, + user: UserSnowflake, + action_coro: t.Optional[t.Awaitable] = None + ) -> None: + """Apply an infraction to the user, log the infraction, and optionally notify the user.""" + infr_type = infraction["type"] + icon = _utils.INFRACTION_ICONS[infr_type][0] + reason = infraction["reason"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") + + # Default values for the confirmation message and mod log. + confirm_msg = ":ok_hand: applied" + + # Specifying an expiry for a note or warning makes no sense. + if infr_type in ("note", "warning"): + expiry_msg = "" + else: + expiry_msg = f" until {expiry}" if expiry else " permanently" + + dm_result = "" + dm_log_text = "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" + log_title = "applied" + log_content = None + failed = False + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # This needs to happen before we apply the infraction, as the bot cannot + # send DMs to user that it doesn't share a guild with. If we were to + # apply kick/ban infractions first, this would mean that we'd make it + # impossible for us to deliver a DM. See python-discord/bot#982. + if not infraction["hidden"]: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + end_msg = "" + if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + else: + log.trace(f"Fetching total infraction count for {user}.") + + infractions = await self.bot.api_client.get( + "bot/infractions", + params={"user__id": str(user.id)} + ) + total = len(infractions) + end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" + + # Execute the necessary actions to apply the infraction on Discord. + if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") + try: + await action_coro + if expiry: + # Schedule the expiration of the infraction. + self.schedule_expiration(infraction) + except discord.HTTPException as e: + # Accordingly display that applying the infraction failed. + confirm_msg = ":x: failed to apply" + expiry_msg = "" + log_content = ctx.author.mention + log_title = "failed to apply" + + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + failed = True + + if failed: + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" + + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") + + # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=icon, + colour=Colours.soft_red, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {user.mention} (`{user.id}`) + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} + Reason: {reason} + """), + content=log_content, + footer=f"ID {infraction['id']}" + ) + + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + + async def pardon_infraction( + self, + ctx: Context, + infr_type: str, + user: UserSnowflake, + send_msg: bool = True + ) -> None: + """ + Prematurely end an infraction for a user and log the action in the mod log. + + If `send_msg` is True, then a pardoning confirmation message will be sent to + the context channel. Otherwise, no such message will be sent. + """ + log.trace(f"Pardoning {infr_type} infraction for {user}.") + + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") + response = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': user.id + } + ) + + if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") + await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") + return + + # Deactivate the infraction and cancel its scheduled expiration task. + log_text = await self.deactivate_infraction(response[0], send_log=False) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["Actor"] = str(ctx.message.author) + log_content = None + id_ = response[0]['id'] + footer = f"ID: {id_}" + + # If multiple active infractions were found, mark them as inactive in the database + # and cancel their expiration tasks. + if len(response) > 1: + log.info( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) + + footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" + + log_note = f"Found multiple **active** {infr_type} infractions in the database." + if "Note" in log_text: + log_text["Note"] = f" {log_note}" + else: + log_text["Note"] = log_note + + # deactivate_infraction() is not called again because: + # 1. Discord cannot store multiple active bans or assign multiples of the same role + # 2. It would send a pardon DM for each active infraction, which is redundant + for infraction in response[1:]: + id_ = infraction['id'] + try: + # Mark infraction as inactive in the database. + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError: + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") + # This is simpler and cleaner than trying to concatenate all the errors. + log_text["Failure"] = "See bot's logs for details." + + # Cancel pending expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Accordingly display whether the user was successfully notified via DM. + dm_emoji = "" + if log_text.get("DM") == "Sent": + dm_emoji = ":incoming_envelope: " + elif "DM" in log_text: + dm_emoji = f"{constants.Emojis.failmail} " + + # Accordingly display whether the pardon failed. + if "Failure" in log_text: + confirm_msg = ":x: failed to pardon" + log_title = "pardon failed" + log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") + else: + confirm_msg = ":ok_hand: pardoned" + log_title = "pardoned" + + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + + # Send a confirmation message to the invoking context. + if send_msg: + log.trace(f"Sending infraction #{id_} pardon confirmation message.") + await ctx.send( + f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " + f"{log_text.get('Failure', '')}" + ) + + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + + # Send a log message to the mod log. + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[infr_type][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=footer, + content=log_content, + ) + + async def deactivate_infraction( + self, + infraction: _utils.Infraction, + send_log: bool = True + ) -> t.Dict[str, str]: + """ + Deactivate an active infraction and return a dictionary of lines to send in a mod log. + + The infraction is removed from Discord, marked as inactive in the database, and has its + expiration task cancelled. If `send_log` is True, a mod log is sent for the + deactivation of the infraction. + + Infractions of unsupported types will raise a ValueError. + """ + guild = self.bot.get_guild(constants.Guild.id) + mod_role = guild.get_role(constants.Roles.moderators) + user_id = infraction["user"] + actor = infraction["actor"] + type_ = infraction["type"] + id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] + + log.info(f"Marking infraction #{id_} as inactive (expired).") + + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + + log_content = None + log_text = { + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, + } + + try: + log.trace("Awaiting the pardon action coroutine.") + returned_log = await self._pardon_action(infraction) + + if returned_log is not None: + log_text = {**log_text, **returned_log} # Merge the logs together + else: + raise ValueError( + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" + ) + except discord.Forbidden: + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") + log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" + log_content = mod_role.mention + except discord.HTTPException as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." + log_content = mod_role.mention + + # Check if the user is currently being watched by Big Brother. + try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + + active_watch = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "watch", + "user__id": user_id + } + ) + + log_text["Watching"] = "Yes" if active_watch else "No" + except ResponseCodeError: + log.exception(f"Failed to fetch watch status for user {user_id}") + log_text["Watching"] = "Unknown - failed to fetch watch status." + + try: + # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_line = f"API request failed with code {e.status}." + log_content = mod_role.mention + + # Append to an existing failure message if possible + if "Failure" in log_text: + log_text["Failure"] += f" {log_line}" + else: + log_text["Failure"] = log_line + + # Cancel the expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Send a log message to the mod log. + if send_log: + log_title = "expiration failed" if "Failure" in log_text else "expired" + + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + + log.trace(f"Sending deactivation mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[type_][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=f"ID: {id_}", + content=log_content, + ) + + return log_text + + @abstractmethod + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + raise NotImplementedError + + def schedule_expiration(self, infraction: _utils.Infraction) -> None: + """ + Marks an infraction expired after the delay from time of scheduling to time of expiration. + + At the time of expiration, the infraction is marked as inactive on the website and the + expiration task is cancelled. + """ + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/_utils.py b/bot/cogs/moderation/infraction/_utils.py new file mode 100644 index 000000000..fb55287b6 --- /dev/null +++ b/bot/cogs/moderation/infraction/_utils.py @@ -0,0 +1,201 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext.commands import Context + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons + +log = logging.getLogger(__name__) + +# apply icon, pardon icon +INFRACTION_ICONS = { + "ban": (Icons.user_ban, Icons.user_unban), + "kick": (Icons.sign_out, None), + "mute": (Icons.user_mute, Icons.user_unmute), + "note": (Icons.user_warn, None), + "superstar": (Icons.superstarify, Icons.unsuperstarify), + "warning": (Icons.user_warn, None), +} +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEALABLE_INFRACTIONS = ("ban", "mute") + +# Type aliases +UserObject = t.Union[discord.Member, discord.User] +UserSnowflake = t.Union[UserObject, discord.Object] +Infraction = t.Dict[str, t.Union[str, int, bool]] + + +async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: + """ + Create a new user in the database. + + Used when an infraction needs to be applied on a user absent in the guild. + """ + log.trace(f"Attempting to add user {user.id} to the database.") + + if not isinstance(user, (discord.Member, discord.User)): + log.debug("The user being added to the DB is not a Member or User object.") + + payload = { + 'discriminator': int(getattr(user, 'discriminator', 0)), + 'id': user.id, + 'in_guild': False, + 'name': getattr(user, 'name', 'Name unknown'), + 'roles': [] + } + + try: + response = await ctx.bot.api_client.post('bot/users', json=payload) + log.info(f"User {user.id} added to the DB.") + return response + except ResponseCodeError as e: + log.error(f"Failed to add user {user.id} to the DB. {e}") + await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") + + +async def post_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + reason: str, + expires_at: datetime = None, + hidden: bool = False, + active: bool = True +) -> t.Optional[dict]: + """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + + payload = { + "actor": ctx.message.author.id, + "hidden": hidden, + "reason": reason, + "type": infr_type, + "user": user.id, + "active": active + } + if expires_at: + payload['expires_at'] = expires_at.isoformat() + + # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. + for should_post_user in (True, False): + try: + response = await ctx.bot.api_client.post('bot/infractions', json=payload) + return response + except ResponseCodeError as e: + if e.status == 400 and 'user' in e.response_json: + # Only one attempt to add the user to the database, not two: + if not should_post_user or await post_user(ctx, user) is None: + return + else: + log.exception(f"Unexpected error while adding an infraction for {user}:") + await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") + return + + +async def get_active_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + send_msg: bool = True +) -> t.Optional[dict]: + """ + Retrieves an active infraction of the given type for the user. + + If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, + then a message for the moderator will be sent to the context channel letting them know. + Otherwise, no message will be sent. + """ + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + + active_infractions = await ctx.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': str(user.id) + } + ) + if active_infractions: + # Checks to see if the moderator should be told there is an active infraction + if send_msg: + log.trace(f"{user} has active infractions of type {infr_type}.") + await ctx.send( + f":x: According to my records, this user already has a {infr_type} infraction. " + f"See infraction **#{active_infractions[0]['id']}**." + ) + return active_infractions[0] + else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") + + +async def notify_infraction( + user: UserObject, + infr_type: str, + expires_at: t.Optional[str] = None, + reason: t.Optional[str] = None, + icon_url: str = Icons.token_removed +) -> bool: + """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + + embed = discord.Embed( + description=textwrap.shorten(text, width=2048, placeholder="..."), + colour=Colours.soft_red + ) + + embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) + embed.title = f"Please review our rules over at {RULES_URL}" + embed.url = RULES_URL + + if infr_type in APPEALABLE_INFRACTIONS: + embed.set_footer( + text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + ) + + return await send_private_embed(user, embed) + + +async def notify_pardon( + user: UserObject, + title: str, + content: str, + icon_url: str = Icons.user_verified +) -> bool: + """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + + embed = discord.Embed( + description=content, + colour=Colours.soft_green + ) + + embed.set_author(name=title, icon_url=icon_url) + + return await send_private_embed(user, embed) + + +async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: + """ + A helper method for sending an embed to a user's DMs. + + Returns a boolean indicator of DM success. + """ + try: + await user.send(embed=embed) + return True + except (discord.HTTPException, discord.Forbidden, discord.NotFound): + log.debug( + f"Infraction-related information could not be sent to user {user} ({user.id}). " + "The user either could not be retrieved or probably disabled their DMs." + ) + return False diff --git a/bot/cogs/moderation/infraction/infractions.py b/bot/cogs/moderation/infraction/infractions.py index 8df642428..cb459b447 100644 --- a/bot/cogs/moderation/infraction/infractions.py +++ b/bot/cogs/moderation/infraction/infractions.py @@ -13,9 +13,9 @@ from bot.constants import Event from bot.converters import Expiry, FetchedMember from bot.decorators import respect_role_hierarchy from bot.utils.checks import with_role_check -from . import utils -from .scheduler import InfractionScheduler -from .utils import UserSnowflake +from . import _utils +from ._scheduler import InfractionScheduler +from ._utils import UserSnowflake log = logging.getLogger(__name__) @@ -55,7 +55,7 @@ class Infractions(InfractionScheduler, commands.Cog): @command() async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: """Warn a user for the given reason.""" - infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False) + infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) if infraction is None: return @@ -125,7 +125,7 @@ class Infractions(InfractionScheduler, commands.Cog): @command(hidden=True) async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: """Create a private note for a user with the given reason without notifying the user.""" - infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) + infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) if infraction is None: return @@ -213,10 +213,10 @@ class Infractions(InfractionScheduler, commands.Cog): async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await utils.get_active_infraction(ctx, user, "mute"): + if await _utils.get_active_infraction(ctx, user, "mute"): return - infraction = await utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) if infraction is None: return @@ -233,7 +233,7 @@ class Infractions(InfractionScheduler, commands.Cog): @respect_role_hierarchy() async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: """Apply a kick infraction with kwargs passed to `post_infraction`.""" - infraction = await utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) if infraction is None: return @@ -254,7 +254,7 @@ class Infractions(InfractionScheduler, commands.Cog): """ # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active is_temporary = kwargs.get("expires_at") is not None - active_infraction = await utils.get_active_infraction(ctx, user, "ban", is_temporary) + active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary) if active_infraction: if is_temporary: @@ -269,7 +269,7 @@ class Infractions(InfractionScheduler, commands.Cog): log.trace("Old tempban is being replaced by new permaban.") await self.pardon_infraction(ctx, "ban", user, is_temporary) - infraction = await utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) if infraction is None: return @@ -309,11 +309,11 @@ class Infractions(InfractionScheduler, commands.Cog): await user.remove_roles(self._muted_role, reason=reason) # DM the user about the expiration. - notified = await utils.notify_pardon( + notified = await _utils.notify_pardon( user=user, title="You have been unmuted", content="You may now send messages in the server.", - icon_url=utils.INFRACTION_ICONS["mute"][1] + icon_url=_utils.INFRACTION_ICONS["mute"][1] ) log_text["Member"] = f"{user.mention}(`{user.id}`)" @@ -339,7 +339,7 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: """ Execute deactivation steps specific to the infraction's type and return a log dict. @@ -368,3 +368,8 @@ class Infractions(InfractionScheduler, commands.Cog): if discord.User in error.converters or discord.Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Infractions cog.""" + bot.add_cog(Infractions(bot)) diff --git a/bot/cogs/moderation/infraction/management.py b/bot/cogs/moderation/infraction/management.py index 791585b6e..9e7ae8113 100644 --- a/bot/cogs/moderation/infraction/management.py +++ b/bot/cogs/moderation/infraction/management.py @@ -14,7 +14,7 @@ from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_whitelist_check, with_role_check -from . import utils +from . import _utils from .infractions import Infractions log = logging.getLogger(__name__) @@ -220,7 +220,7 @@ class ModManagement(commands.Cog): self, ctx: Context, embed: discord.Embed, - infractions: t.Iterable[utils.Infraction] + infractions: t.Iterable[_utils.Infraction] ) -> None: """Send a paginated embed of infractions for the specified user.""" if not infractions: @@ -241,7 +241,7 @@ class ModManagement(commands.Cog): max_size=1000 ) - def infraction_to_string(self, infraction: utils.Infraction) -> str: + def infraction_to_string(self, infraction: _utils.Infraction) -> str: """Convert the infraction object to a string representation.""" actor_id = infraction["actor"] guild = self.bot.get_guild(constants.Guild.id) @@ -303,3 +303,8 @@ class ModManagement(commands.Cog): if discord.User in error.converters: await ctx.send(str(error.errors[0])) error.handled = True + + +def setup(bot: Bot) -> None: + """Load the ModManagement cog.""" + bot.add_cog(ModManagement(bot)) diff --git a/bot/cogs/moderation/infraction/scheduler.py b/bot/cogs/moderation/infraction/scheduler.py deleted file mode 100644 index b3d27fe76..000000000 --- a/bot/cogs/moderation/infraction/scheduler.py +++ /dev/null @@ -1,463 +0,0 @@ -import logging -import textwrap -import typing as t -from abc import abstractmethod -from datetime import datetime -from gettext import ngettext - -import dateutil.parser -import discord -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Colours, STAFF_CHANNELS -from bot.utils import time -from bot.utils.scheduling import Scheduler -from . import utils -from .utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class InfractionScheduler: - """Handles the application, pardoning, and expiration of infractions.""" - - def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - @property - def mod_log(self) -> ModLog: - """Get the currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: - """Schedule expiration for previous infractions.""" - await self.bot.wait_until_guild_available() - - log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") - - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={'active': 'true'} - ) - for infraction in infractions: - if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_expiration(infraction) - - async def reapply_infraction( - self, - infraction: utils.Infraction, - apply_coro: t.Optional[t.Awaitable] - ) -> None: - """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" - # Calculate the time remaining, in seconds, for the mute. - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - delta = (expiry - datetime.utcnow()).total_seconds() - - # Mark as inactive if less than a minute remains. - if delta < 60: - log.info( - "Infraction will be deactivated instead of re-applied " - "because less than 1 minute remains." - ) - await self.deactivate_infraction(infraction) - return - - # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") - - async def apply_infraction( - self, - ctx: Context, - infraction: utils.Infraction, - user: UserSnowflake, - action_coro: t.Optional[t.Awaitable] = None - ) -> None: - """Apply an infraction to the user, log the infraction, and optionally notify the user.""" - infr_type = infraction["type"] - icon = utils.INFRACTION_ICONS[infr_type][0] - reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) - id_ = infraction['id'] - - log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - - # Default values for the confirmation message and mod log. - confirm_msg = ":ok_hand: applied" - - # Specifying an expiry for a note or warning makes no sense. - if infr_type in ("note", "warning"): - expiry_msg = "" - else: - expiry_msg = f" until {expiry}" if expiry else " permanently" - - dm_result = "" - dm_log_text = "" - expiry_log_text = f"\nExpires: {expiry}" if expiry else "" - log_title = "applied" - log_content = None - failed = False - - # DM the user about the infraction if it's not a shadow/hidden infraction. - # This needs to happen before we apply the infraction, as the bot cannot - # send DMs to user that it doesn't share a guild with. If we were to - # apply kick/ban infractions first, this would mean that we'd make it - # impossible for us to deliver a DM. See python-discord/bot#982. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" - - end_msg = "" - if infraction["actor"] == self.bot.user.id: - log.trace( - f"Infraction #{id_} actor is bot; including the reason in the confirmation message." - ) - if reason: - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" - elif ctx.channel.id not in STAFF_CHANNELS: - log.trace( - f"Infraction #{id_} context is not in a staff channel; omitting infraction count." - ) - else: - log.trace(f"Fetching total infraction count for {user}.") - - infractions = await self.bot.api_client.get( - "bot/infractions", - params={"user__id": str(user.id)} - ) - total = len(infractions) - end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" - - # Execute the necessary actions to apply the infraction on Discord. - if action_coro: - log.trace(f"Awaiting the infraction #{id_} application action coroutine.") - try: - await action_coro - if expiry: - # Schedule the expiration of the infraction. - self.schedule_expiration(infraction) - except discord.HTTPException as e: - # Accordingly display that applying the infraction failed. - confirm_msg = ":x: failed to apply" - expiry_msg = "" - log_content = ctx.author.mention - log_title = "failed to apply" - - log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): - log.warning(f"{log_msg}: bot lacks permissions.") - else: - log.exception(log_msg) - failed = True - - if failed: - log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") - try: - await self.bot.api_client.delete(f"bot/infractions/{id_}") - except ResponseCodeError as e: - confirm_msg += " and failed to delete" - log_title += " and failed to delete" - log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") - infr_message = "" - else: - infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" - - # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") - - # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=icon, - colour=Colours.soft_red, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} - Reason: {reason} - """), - content=log_content, - footer=f"ID {infraction['id']}" - ) - - log.info(f"Applied {infr_type} infraction #{id_} to {user}.") - - async def pardon_infraction( - self, - ctx: Context, - infr_type: str, - user: UserSnowflake, - send_msg: bool = True - ) -> None: - """ - Prematurely end an infraction for a user and log the action in the mod log. - - If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. - """ - log.trace(f"Pardoning {infr_type} infraction for {user}.") - - # Check the current active infraction - log.trace(f"Fetching active {infr_type} infractions for {user}.") - response = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': user.id - } - ) - - if not response: - log.debug(f"No active {infr_type} infraction found for {user}.") - await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") - return - - # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["Actor"] = str(ctx.message.author) - log_content = None - id_ = response[0]['id'] - footer = f"ID: {id_}" - - # If multiple active infractions were found, mark them as inactive in the database - # and cancel their expiration tasks. - if len(response) > 1: - log.info( - f"Found more than one active {infr_type} infraction for user {user.id}; " - "deactivating the extra active infractions too." - ) - - footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" - - log_note = f"Found multiple **active** {infr_type} infractions in the database." - if "Note" in log_text: - log_text["Note"] = f" {log_note}" - else: - log_text["Note"] = log_note - - # deactivate_infraction() is not called again because: - # 1. Discord cannot store multiple active bans or assign multiples of the same role - # 2. It would send a pardon DM for each active infraction, which is redundant - for infraction in response[1:]: - id_ = infraction['id'] - try: - # Mark infraction as inactive in the database. - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") - # This is simpler and cleaner than trying to concatenate all the errors. - log_text["Failure"] = "See bot's logs for details." - - # Cancel pending expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Accordingly display whether the user was successfully notified via DM. - dm_emoji = "" - if log_text.get("DM") == "Sent": - dm_emoji = ":incoming_envelope: " - elif "DM" in log_text: - dm_emoji = f"{constants.Emojis.failmail} " - - # Accordingly display whether the pardon failed. - if "Failure" in log_text: - confirm_msg = ":x: failed to pardon" - log_title = "pardon failed" - log_content = ctx.author.mention - - log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") - else: - confirm_msg = ":ok_hand: pardoned" - log_title = "pardoned" - - log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") - - # Send a confirmation message to the invoking context. - if send_msg: - log.trace(f"Sending infraction #{id_} pardon confirmation message.") - await ctx.send( - f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " - f"{log_text.get('Failure', '')}" - ) - - # Move reason to end of entry to avoid cutting out some keys - log_text["Reason"] = log_text.pop("Reason") - - # Send a log message to the mod log. - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[infr_type][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=footer, - content=log_content, - ) - - async def deactivate_infraction( - self, - infraction: utils.Infraction, - send_log: bool = True - ) -> t.Dict[str, str]: - """ - Deactivate an active infraction and return a dictionary of lines to send in a mod log. - - The infraction is removed from Discord, marked as inactive in the database, and has its - expiration task cancelled. If `send_log` is True, a mod log is sent for the - deactivation of the infraction. - - Infractions of unsupported types will raise a ValueError. - """ - guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderators) - user_id = infraction["user"] - actor = infraction["actor"] - type_ = infraction["type"] - id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] - - log.info(f"Marking infraction #{id_} as inactive (expired).") - - expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) - - log_content = None - log_text = { - "Member": f"<@{user_id}>", - "Actor": str(self.bot.get_user(actor) or actor), - "Reason": infraction["reason"], - "Created": created, - } - - try: - log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) - - if returned_log is not None: - log_text = {**log_text, **returned_log} # Merge the logs together - else: - raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" - ) - except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") - log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" - log_content = mod_role.mention - except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." - log_content = mod_role.mention - - # Check if the user is currently being watched by Big Brother. - try: - log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") - - active_watch = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "watch", - "user__id": user_id - } - ) - - log_text["Watching"] = "Yes" if active_watch else "No" - except ResponseCodeError: - log.exception(f"Failed to fetch watch status for user {user_id}") - log_text["Watching"] = "Unknown - failed to fetch watch status." - - try: - # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{id_} as inactive in the database.") - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_line = f"API request failed with code {e.status}." - log_content = mod_role.mention - - # Append to an existing failure message if possible - if "Failure" in log_text: - log_text["Failure"] += f" {log_line}" - else: - log_text["Failure"] = log_line - - # Cancel the expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Send a log message to the mod log. - if send_log: - log_title = "expiration failed" if "Failure" in log_text else "expired" - - user = self.bot.get_user(user_id) - avatar = user.avatar_url_as(static_format="png") if user else None - - # Move reason to end so when reason is too long, this is not gonna cut out required items. - log_text["Reason"] = log_text.pop("Reason") - - log.trace(f"Sending deactivation mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[type_][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {type_}", - thumbnail=avatar, - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {id_}", - content=log_content, - ) - - return log_text - - @abstractmethod - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - raise NotImplementedError - - def schedule_expiration(self, infraction: utils.Infraction) -> None: - """ - Marks an infraction expired after the delay from time of scheduling to time of expiration. - - At the time of expiration, the infraction is marked as inactive on the website and the - expiration task is cancelled. - """ - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/superstarify.py b/bot/cogs/moderation/infraction/superstarify.py index 867de815a..7dc5b4691 100644 --- a/bot/cogs/moderation/infraction/superstarify.py +++ b/bot/cogs/moderation/infraction/superstarify.py @@ -13,8 +13,8 @@ from bot.bot import Bot from bot.converters import Expiry from bot.utils.checks import with_role_check from bot.utils.time import format_infraction -from . import utils -from .scheduler import InfractionScheduler +from . import _utils +from ._scheduler import InfractionScheduler log = logging.getLogger(__name__) NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" @@ -67,7 +67,7 @@ class Superstarify(InfractionScheduler, Cog): reason=f"Superstarified member tried to escape the prison: {infraction['id']}" ) - notified = await utils.notify_infraction( + notified = await _utils.notify_infraction( user=after, infr_type="Superstarify", expires_at=format_infraction(infraction["expires_at"]), @@ -76,7 +76,7 @@ class Superstarify(InfractionScheduler, Cog): f"from **{before.display_name}** to **{after.display_name}**, but as you " "are currently in superstar-prison, you do not have permission to do so." ), - icon_url=utils.INFRACTION_ICONS["superstar"][0] + icon_url=_utils.INFRACTION_ICONS["superstar"][0] ) if not notified: @@ -130,12 +130,12 @@ class Superstarify(InfractionScheduler, Cog): An optional reason can be provided. If no reason is given, the original name will be shown in a generated reason. """ - if await utils.get_active_infraction(ctx, member, "superstar"): + if await _utils.get_active_infraction(ctx, member, "superstar"): return # Post the infraction to the API reason = reason or f"old nick: {member.display_name}" - infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) + infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) id_ = infraction["id"] old_nick = member.display_name @@ -149,11 +149,11 @@ class Superstarify(InfractionScheduler, Cog): self.schedule_expiration(infraction) # Send a DM to the user to notify them of their new infraction. - await utils.notify_infraction( + await _utils.notify_infraction( user=member, infr_type="Superstarify", expires_at=expiry_str, - icon_url=utils.INFRACTION_ICONS["superstar"][0], + icon_url=_utils.INFRACTION_ICONS["superstar"][0], reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." ) @@ -176,7 +176,7 @@ class Superstarify(InfractionScheduler, Cog): # Log to the mod log channel. log.trace(f"Sending apply mod log for superstar #{id_}.") await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS["superstar"][0], + icon_url=_utils.INFRACTION_ICONS["superstar"][0], colour=Colour.gold(), title="Member achieved superstardom", thumbnail=member.avatar_url_as(static_format="png"), @@ -196,7 +196,7 @@ class Superstarify(InfractionScheduler, Cog): """Remove the superstarify infraction and allow the user to change their nickname.""" await self.pardon_infraction(ctx, "superstar", member) - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: """Pardon a superstar infraction and return a log dict.""" if infraction["type"] != "superstar": return @@ -213,11 +213,11 @@ class Superstarify(InfractionScheduler, Cog): return {} # DM the user about the expiration. - notified = await utils.notify_pardon( + notified = await _utils.notify_pardon( user=user, title="You are no longer superstarified", content="You may now change your nickname on the server.", - icon_url=utils.INFRACTION_ICONS["superstar"][1] + icon_url=_utils.INFRACTION_ICONS["superstar"][1] ) return { @@ -237,3 +237,8 @@ class Superstarify(InfractionScheduler, Cog): def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Superstarify cog.""" + bot.add_cog(Superstarify(bot)) diff --git a/bot/cogs/moderation/infraction/utils.py b/bot/cogs/moderation/infraction/utils.py deleted file mode 100644 index fb55287b6..000000000 --- a/bot/cogs/moderation/infraction/utils.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext.commands import Context - -from bot.api import ResponseCodeError -from bot.constants import Colours, Icons - -log = logging.getLogger(__name__) - -# apply icon, pardon icon -INFRACTION_ICONS = { - "ban": (Icons.user_ban, Icons.user_unban), - "kick": (Icons.sign_out, None), - "mute": (Icons.user_mute, Icons.user_unmute), - "note": (Icons.user_warn, None), - "superstar": (Icons.superstarify, Icons.unsuperstarify), - "warning": (Icons.user_warn, None), -} -RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") - -# Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object] -Infraction = t.Dict[str, t.Union[str, int, bool]] - - -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: - """ - Create a new user in the database. - - Used when an infraction needs to be applied on a user absent in the guild. - """ - log.trace(f"Attempting to add user {user.id} to the database.") - - if not isinstance(user, (discord.Member, discord.User)): - log.debug("The user being added to the DB is not a Member or User object.") - - payload = { - 'discriminator': int(getattr(user, 'discriminator', 0)), - 'id': user.id, - 'in_guild': False, - 'name': getattr(user, 'name', 'Name unknown'), - 'roles': [] - } - - try: - response = await ctx.bot.api_client.post('bot/users', json=payload) - log.info(f"User {user.id} added to the DB.") - return response - except ResponseCodeError as e: - log.error(f"Failed to add user {user.id} to the DB. {e}") - await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") - - -async def post_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - reason: str, - expires_at: datetime = None, - hidden: bool = False, - active: bool = True -) -> t.Optional[dict]: - """Posts an infraction to the API.""" - log.trace(f"Posting {infr_type} infraction for {user} to the API.") - - payload = { - "actor": ctx.message.author.id, - "hidden": hidden, - "reason": reason, - "type": infr_type, - "user": user.id, - "active": active - } - if expires_at: - payload['expires_at'] = expires_at.isoformat() - - # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. - for should_post_user in (True, False): - try: - response = await ctx.bot.api_client.post('bot/infractions', json=payload) - return response - except ResponseCodeError as e: - if e.status == 400 and 'user' in e.response_json: - # Only one attempt to add the user to the database, not two: - if not should_post_user or await post_user(ctx, user) is None: - return - else: - log.exception(f"Unexpected error while adding an infraction for {user}:") - await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") - return - - -async def get_active_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - send_msg: bool = True -) -> t.Optional[dict]: - """ - Retrieves an active infraction of the given type for the user. - - If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, - then a message for the moderator will be sent to the context channel letting them know. - Otherwise, no message will be sent. - """ - log.trace(f"Checking if {user} has active infractions of type {infr_type}.") - - active_infractions = await ctx.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': str(user.id) - } - ) - if active_infractions: - # Checks to see if the moderator should be told there is an active infraction - if send_msg: - log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) - return active_infractions[0] - else: - log.trace(f"{user} does not have active infractions of type {infr_type}.") - - -async def notify_infraction( - user: UserObject, - infr_type: str, - expires_at: t.Optional[str] = None, - reason: t.Optional[str] = None, - icon_url: str = Icons.token_removed -) -> bool: - """DM a user about their new infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) - - embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), - colour=Colours.soft_red - ) - - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" - embed.url = RULES_URL - - if infr_type in APPEALABLE_INFRACTIONS: - embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" - ) - - return await send_private_embed(user, embed) - - -async def notify_pardon( - user: UserObject, - title: str, - content: str, - icon_url: str = Icons.user_verified -) -> bool: - """DM a user about their pardoned infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their pardoned infraction.") - - embed = discord.Embed( - description=content, - colour=Colours.soft_green - ) - - embed.set_author(name=title, icon_url=icon_url) - - return await send_private_embed(user, embed) - - -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: - """ - A helper method for sending an embed to a user's DMs. - - Returns a boolean indicator of DM success. - """ - try: - await user.send(embed=embed) - return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): - log.debug( - f"Infraction-related information could not be sent to user {user} ({user.id}). " - "The user either could not be retrieved or probably disabled their DMs." - ) - return False diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 0a63f57b8..c86f04b9d 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -830,3 +830,8 @@ class ModLog(Cog, name="ModLog"): thumbnail=member.avatar_url_as(static_format="png"), channel_id=Channels.voice_log ) + + +def setup(bot: Bot) -> None: + """Load the ModLog cog.""" + bot.add_cog(ModLog(bot)) diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index f8a6592bc..4af87c724 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -163,3 +163,8 @@ class Silence(commands.Cog): def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" return with_role_check(ctx, *MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Silence cog.""" + bot.add_cog(Silence(bot)) diff --git a/bot/cogs/moderation/watchchannels/__init__.py b/bot/cogs/moderation/watchchannels/__init__.py index 69d118df6..e69de29bb 100644 --- a/bot/cogs/moderation/watchchannels/__init__.py +++ b/bot/cogs/moderation/watchchannels/__init__.py @@ -1,9 +0,0 @@ -from bot.bot import Bot -from .bigbrother import BigBrother -from .talentpool import TalentPool - - -def setup(bot: Bot) -> None: - """Load the BigBrother and TalentPool cogs.""" - bot.add_cog(BigBrother(bot)) - bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/moderation/watchchannels/_watchchannel.py b/bot/cogs/moderation/watchchannels/_watchchannel.py new file mode 100644 index 000000000..044077350 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/_watchchannel.py @@ -0,0 +1,348 @@ +import asyncio +import logging +import re +import textwrap +from abc import abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Optional + +import dateutil.parser +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons +from bot.pagination import LinePaginator +from bot.utils import CogABCMeta, messages +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + +URL_RE = re.compile(r"(https?://[^\s]+)") + + +@dataclass +class MessageHistory: + """Represents a watch channel's message history.""" + + last_author: Optional[int] = None + last_channel: Optional[int] = None + message_count: int = 0 + + +class WatchChannel(metaclass=CogABCMeta): + """ABC with functionality for relaying users' messages to a certain channel.""" + + @abstractmethod + def __init__( + self, + bot: Bot, + destination: int, + webhook_id: int, + api_endpoint: str, + api_default_params: dict, + logger: logging.Logger + ) -> None: + self.bot = bot + + self.destination = destination # E.g., Channels.big_brother_logs + self.webhook_id = webhook_id # E.g., Webhooks.big_brother + self.api_endpoint = api_endpoint # E.g., 'bot/infractions' + self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} + self.log = logger # Logger of the child cog for a correct name in the logs + + self._consume_task = None + self.watched_users = defaultdict(dict) + self.message_queue = defaultdict(lambda: defaultdict(deque)) + self.consumption_queue = {} + self.retries = 5 + self.retry_delay = 10 + self.channel = None + self.webhook = None + self.message_history = MessageHistory() + + self._start = self.bot.loop.create_task(self.start_watchchannel()) + + @property + def modlog(self) -> ModLog: + """Provides access to the ModLog cog for alert purposes.""" + return self.bot.get_cog("ModLog") + + @property + def consuming_messages(self) -> bool: + """Checks if a consumption task is currently running.""" + if self._consume_task is None: + return False + + if self._consume_task.done(): + exc = self._consume_task.exception() + if exc: + self.log.exception( + "The message queue consume task has failed with:", + exc_info=exc + ) + return False + + return True + + async def start_watchchannel(self) -> None: + """Starts the watch channel by getting the channel, webhook, and user cache ready.""" + await self.bot.wait_until_guild_available() + + try: + self.channel = await self.bot.fetch_channel(self.destination) + except HTTPException: + self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + if self.channel is None or self.webhook is None: + self.log.error("Failed to start the watch channel; unloading the cog.") + + message = textwrap.dedent( + f""" + An error occurred while loading the text channel or webhook. + + TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} + Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} + + The Cog has been unloaded. + """ + ) + + await self.modlog.send_log_message( + title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", + text=message, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + if not await self.fetch_user_cache(): + await self.modlog.send_log_message( + title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", + text="Could not retrieve the list of watched users from the API and messages will not be relayed.", + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + async def fetch_user_cache(self) -> bool: + """ + Fetches watched users from the API and updates the watched user cache accordingly. + + This function returns `True` if the update succeeded. + """ + try: + data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) + except ResponseCodeError as err: + self.log.exception("Failed to fetch the watched users from the API", exc_info=err) + return False + + self.watched_users = defaultdict(dict) + + for entry in data: + user_id = entry.pop('user') + self.watched_users[user_id] = entry + + return True + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Queues up messages sent by watched users.""" + if msg.author.id in self.watched_users: + if not self.consuming_messages: + self._consume_task = self.bot.loop.create_task(self.consume_messages()) + + self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") + self.message_queue[msg.author.id][msg.channel.id].append(msg) + + async def consume_messages(self, delay_consumption: bool = True) -> None: + """Consumes the message queues to log watched users' messages.""" + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) + + self.log.trace("Started consuming the message queue") + + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() + + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() + + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) + + self.consumption_queue.clear() + + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + + async def webhook_send( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Sends a message to the webhook with the specified kwargs.""" + username = messages.sub_clyde(username) + try: + await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send a message to the webhook", + exc_info=exc + ) + + async def relay_message(self, msg: Message) -> None: + """Relays the message to the relevant watch channel.""" + limit = BigBrotherConfig.header_message_limit + + if ( + msg.author.id != self.message_history.last_author + or msg.channel.id != self.message_history.last_channel + or self.message_history.message_count >= limit + ): + self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) + + await self.send_header(msg) + + cleaned_content = msg.clean_content + + if cleaned_content: + # Put all non-media URLs in a code block to prevent embeds + media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} + for url in URL_RE.findall(cleaned_content): + if url not in media_urls: + cleaned_content = cleaned_content.replace(url, f"`{url}`") + await self.webhook_send( + cleaned_content, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + + if msg.attachments: + try: + await messages.send_attachments(msg, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.webhook_send( + embed=e, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send an attachment to the webhook", + exc_info=exc + ) + + self.message_history.message_count += 1 + + async def send_header(self, msg: Message) -> None: + """Sends a header embed with information about the relayed messages to the watch channel.""" + user_id = msg.author.id + + guild = self.bot.get_guild(GuildConfig.id) + actor = guild.get_member(self.watched_users[user_id]['actor']) + actor = actor.display_name if actor else self.watched_users[user_id]['actor'] + + inserted_at = self.watched_users[user_id]['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + + reason = self.watched_users[user_id]['reason'] + + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + footer = f"Added {time_delta} by {actor} | Reason: {reason}" + embed = Embed(description=f"{msg.author.mention} {message_jump}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) + + await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) + + async def list_watched_users( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Gives an overview of the watched user list for this channel. + + The optional kwarg `oldest_first` orders the list by oldest entry. + + The optional kwarg `update_cache` specifies whether the cache should + be refreshed by polling the API. + """ + if update_cache: + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") + update_cache = False + + lines = [] + for user_id, user_data in self.watched_users.items(): + inserted_at = user_data['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + lines.append(f"• <@{user_id}> (added {time_delta})") + + if oldest_first: + lines.reverse() + + lines = lines or ("There's nothing here yet.",) + + embed = Embed( + title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", + color=Color.blue() + ) + await LinePaginator.paginate(lines, ctx, embed, empty=False) + + @staticmethod + def _get_time_delta(time_string: str) -> str: + """Returns the time in human-readable time delta format.""" + date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) + time_delta = time_since(date_time, precision="minutes", max_units=1) + + return time_delta + + def _remove_user(self, user_id: int) -> None: + """Removes a user from a watch channel.""" + self.watched_users.pop(user_id, None) + self.message_queue.pop(user_id, None) + self.consumption_queue.pop(user_id, None) + + def cog_unload(self) -> None: + """Takes care of unloading the cog and canceling the consumption task.""" + self.log.trace("Unloading the cog") + if self._consume_task and not self._consume_task.done(): + self._consume_task.cancel() + try: + self._consume_task.result() + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) diff --git a/bot/cogs/moderation/watchchannels/bigbrother.py b/bot/cogs/moderation/watchchannels/bigbrother.py index 0c72e88f7..7db34bcf2 100644 --- a/bot/cogs/moderation/watchchannels/bigbrother.py +++ b/bot/cogs/moderation/watchchannels/bigbrother.py @@ -5,11 +5,11 @@ from collections import ChainMap from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.cogs.moderation.infraction.utils import post_infraction +from bot.cogs.moderation.infraction._utils import post_infraction from bot.constants import Channels, MODERATION_ROLES, Webhooks from bot.converters import FetchedMember from bot.decorators import with_role -from .watchchannel import WatchChannel +from ._watchchannel import WatchChannel log = logging.getLogger(__name__) @@ -163,3 +163,8 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): message = ":x: The specified user is currently not being watched." await ctx.send(message) + + +def setup(bot: Bot) -> None: + """Load the BigBrother cog.""" + bot.add_cog(BigBrother(bot)) diff --git a/bot/cogs/moderation/watchchannels/talentpool.py b/bot/cogs/moderation/watchchannels/talentpool.py index 89256e92e..2972f56e1 100644 --- a/bot/cogs/moderation/watchchannels/talentpool.py +++ b/bot/cogs/moderation/watchchannels/talentpool.py @@ -12,7 +12,7 @@ from bot.converters import FetchedMember from bot.decorators import with_role from bot.pagination import LinePaginator from bot.utils import time -from .watchchannel import WatchChannel +from ._watchchannel import WatchChannel log = logging.getLogger(__name__) @@ -262,3 +262,8 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): ) return lines.strip() + + +def setup(bot: Bot) -> None: + """Load the TalentPool cog.""" + bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/moderation/watchchannels/watchchannel.py b/bot/cogs/moderation/watchchannels/watchchannel.py deleted file mode 100644 index 044077350..000000000 --- a/bot/cogs/moderation/watchchannels/watchchannel.py +++ /dev/null @@ -1,348 +0,0 @@ -import asyncio -import logging -import re -import textwrap -from abc import abstractmethod -from collections import defaultdict, deque -from dataclasses import dataclass -from typing import Optional - -import dateutil.parser -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - -URL_RE = re.compile(r"(https?://[^\s]+)") - - -@dataclass -class MessageHistory: - """Represents a watch channel's message history.""" - - last_author: Optional[int] = None - last_channel: Optional[int] = None - message_count: int = 0 - - -class WatchChannel(metaclass=CogABCMeta): - """ABC with functionality for relaying users' messages to a certain channel.""" - - @abstractmethod - def __init__( - self, - bot: Bot, - destination: int, - webhook_id: int, - api_endpoint: str, - api_default_params: dict, - logger: logging.Logger - ) -> None: - self.bot = bot - - self.destination = destination # E.g., Channels.big_brother_logs - self.webhook_id = webhook_id # E.g., Webhooks.big_brother - self.api_endpoint = api_endpoint # E.g., 'bot/infractions' - self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} - self.log = logger # Logger of the child cog for a correct name in the logs - - self._consume_task = None - self.watched_users = defaultdict(dict) - self.message_queue = defaultdict(lambda: defaultdict(deque)) - self.consumption_queue = {} - self.retries = 5 - self.retry_delay = 10 - self.channel = None - self.webhook = None - self.message_history = MessageHistory() - - self._start = self.bot.loop.create_task(self.start_watchchannel()) - - @property - def modlog(self) -> ModLog: - """Provides access to the ModLog cog for alert purposes.""" - return self.bot.get_cog("ModLog") - - @property - def consuming_messages(self) -> bool: - """Checks if a consumption task is currently running.""" - if self._consume_task is None: - return False - - if self._consume_task.done(): - exc = self._consume_task.exception() - if exc: - self.log.exception( - "The message queue consume task has failed with:", - exc_info=exc - ) - return False - - return True - - async def start_watchchannel(self) -> None: - """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_guild_available() - - try: - self.channel = await self.bot.fetch_channel(self.destination) - except HTTPException: - self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - if self.channel is None or self.webhook is None: - self.log.error("Failed to start the watch channel; unloading the cog.") - - message = textwrap.dedent( - f""" - An error occurred while loading the text channel or webhook. - - TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} - Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} - - The Cog has been unloaded. - """ - ) - - await self.modlog.send_log_message( - title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", - text=message, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - if not await self.fetch_user_cache(): - await self.modlog.send_log_message( - title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", - text="Could not retrieve the list of watched users from the API and messages will not be relayed.", - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - async def fetch_user_cache(self) -> bool: - """ - Fetches watched users from the API and updates the watched user cache accordingly. - - This function returns `True` if the update succeeded. - """ - try: - data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except ResponseCodeError as err: - self.log.exception("Failed to fetch the watched users from the API", exc_info=err) - return False - - self.watched_users = defaultdict(dict) - - for entry in data: - user_id = entry.pop('user') - self.watched_users[user_id] = entry - - return True - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Queues up messages sent by watched users.""" - if msg.author.id in self.watched_users: - if not self.consuming_messages: - self._consume_task = self.bot.loop.create_task(self.consume_messages()) - - self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") - self.message_queue[msg.author.id][msg.channel.id].append(msg) - - async def consume_messages(self, delay_consumption: bool = True) -> None: - """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) - - self.log.trace("Started consuming the message queue") - - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() - - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() - - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) - - self.consumption_queue.clear() - - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") - - async def webhook_send( - self, - content: Optional[str] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - embed: Optional[Embed] = None, - ) -> None: - """Sends a message to the webhook with the specified kwargs.""" - username = messages.sub_clyde(username) - try: - await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send a message to the webhook", - exc_info=exc - ) - - async def relay_message(self, msg: Message) -> None: - """Relays the message to the relevant watch channel.""" - limit = BigBrotherConfig.header_message_limit - - if ( - msg.author.id != self.message_history.last_author - or msg.channel.id != self.message_history.last_channel - or self.message_history.message_count >= limit - ): - self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) - - await self.send_header(msg) - - cleaned_content = msg.clean_content - - if cleaned_content: - # Put all non-media URLs in a code block to prevent embeds - media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} - for url in URL_RE.findall(cleaned_content): - if url not in media_urls: - cleaned_content = cleaned_content.replace(url, f"`{url}`") - await self.webhook_send( - cleaned_content, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - - if msg.attachments: - try: - await messages.send_attachments(msg, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.webhook_send( - embed=e, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send an attachment to the webhook", - exc_info=exc - ) - - self.message_history.message_count += 1 - - async def send_header(self, msg: Message) -> None: - """Sends a header embed with information about the relayed messages to the watch channel.""" - user_id = msg.author.id - - guild = self.bot.get_guild(GuildConfig.id) - actor = guild.get_member(self.watched_users[user_id]['actor']) - actor = actor.display_name if actor else self.watched_users[user_id]['actor'] - - inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - - reason = self.watched_users[user_id]['reason'] - - if isinstance(msg.channel, DMChannel): - # If a watched user DMs the bot there won't be a channel name or jump URL - # This could technically include a GroupChannel but bot's can't be in those - message_jump = "via DM" - else: - message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" - - footer = f"Added {time_delta} by {actor} | Reason: {reason}" - embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) - - await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) - - async def list_watched_users( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Gives an overview of the watched user list for this channel. - - The optional kwarg `oldest_first` orders the list by oldest entry. - - The optional kwarg `update_cache` specifies whether the cache should - be refreshed by polling the API. - """ - if update_cache: - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") - update_cache = False - - lines = [] - for user_id, user_data in self.watched_users.items(): - inserted_at = user_data['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - lines.append(f"• <@{user_id}> (added {time_delta})") - - if oldest_first: - lines.reverse() - - lines = lines or ("There's nothing here yet.",) - - embed = Embed( - title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", - color=Color.blue() - ) - await LinePaginator.paginate(lines, ctx, embed, empty=False) - - @staticmethod - def _get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) - time_delta = time_since(date_time, precision="minutes", max_units=1) - - return time_delta - - def _remove_user(self, user_id: int) -> None: - """Removes a user from a watch channel.""" - self.watched_users.pop(user_id, None) - self.message_queue.pop(user_id, None) - self.consumption_queue.pop(user_id, None) - - def cog_unload(self) -> None: - """Takes care of unloading the cog and canceling the consumption task.""" - self.log.trace("Unloading the cog") - if self._consume_task and not self._consume_task.done(): - self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) diff --git a/tests/bot/cogs/backend/sync/test_base.py b/tests/bot/cogs/backend/sync/test_base.py index 0d0a8299d..3009aacb6 100644 --- a/tests/bot/cogs/backend/sync/test_base.py +++ b/tests/bot/cogs/backend/sync/test_base.py @@ -6,7 +6,7 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs.backend.sync.syncers import Syncer, _Diff +from bot.cogs.backend.sync._syncers import Syncer, _Diff from tests import helpers diff --git a/tests/bot/cogs/backend/sync/test_cog.py b/tests/bot/cogs/backend/sync/test_cog.py index 199747051..e40552817 100644 --- a/tests/bot/cogs/backend/sync/test_cog.py +++ b/tests/bot/cogs/backend/sync/test_cog.py @@ -6,7 +6,8 @@ import discord from bot import constants from bot.api import ResponseCodeError from bot.cogs.backend import sync -from bot.cogs.backend.sync.syncers import Syncer +from bot.cogs.backend.sync._cog import Sync +from bot.cogs.backend.sync._syncers import Syncer from tests import helpers from tests.base import CommandTestCase @@ -29,19 +30,19 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): self.bot = helpers.MockBot() self.role_syncer_patcher = mock.patch( - "bot.cogs.backend.sync.syncers.RoleSyncer", + "bot.cogs.backend.sync._syncers.RoleSyncer", autospec=Syncer, spec_set=True ) self.user_syncer_patcher = mock.patch( - "bot.cogs.backend.sync.syncers.UserSyncer", + "bot.cogs.backend.sync._syncers.UserSyncer", autospec=Syncer, spec_set=True ) self.RoleSyncer = self.role_syncer_patcher.start() self.UserSyncer = self.user_syncer_patcher.start() - self.cog = sync.Sync(self.bot) + self.cog = Sync(self.bot) def tearDown(self): self.role_syncer_patcher.stop() @@ -59,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) + @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild): """Should instantiate syncers and run a sync for the guild.""" # Reset because a Sync cog was already instantiated in setUp. @@ -70,7 +71,7 @@ class SyncCogTests(SyncCogTestCase): mock_sync_guild_coro = mock.MagicMock() sync_guild.return_value = mock_sync_guild_coro - sync.Sync(self.bot) + Sync(self.bot) self.RoleSyncer.assert_called_once_with(self.bot) self.UserSyncer.assert_called_once_with(self.bot) @@ -131,7 +132,7 @@ class SyncCogListenerTests(SyncCogTestCase): super().setUp() self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - self.guild_id_patcher = mock.patch("bot.cogs.backend.sync.cog.constants.Guild.id", 5) + self.guild_id_patcher = mock.patch("bot.cogs.backend.sync._cog.constants.Guild.id", 5) self.guild_id = self.guild_id_patcher.start() self.guild = helpers.MockGuild(id=self.guild_id) diff --git a/tests/bot/cogs/backend/sync/test_roles.py b/tests/bot/cogs/backend/sync/test_roles.py index cc2e51c7f..99d682ede 100644 --- a/tests/bot/cogs/backend/sync/test_roles.py +++ b/tests/bot/cogs/backend/sync/test_roles.py @@ -3,7 +3,7 @@ from unittest import mock import discord -from bot.cogs.backend.sync.syncers import RoleSyncer, _Diff, _Role +from bot.cogs.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers diff --git a/tests/bot/cogs/backend/sync/test_users.py b/tests/bot/cogs/backend/sync/test_users.py index 490ea9e06..51dcbe48a 100644 --- a/tests/bot/cogs/backend/sync/test_users.py +++ b/tests/bot/cogs/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from bot.cogs.backend.sync.syncers import UserSyncer, _Diff, _User +from bot.cogs.backend.sync._syncers import UserSyncer, _Diff, _User from tests import helpers diff --git a/tests/bot/cogs/moderation/infraction/test_infractions.py b/tests/bot/cogs/moderation/infraction/test_infractions.py index a79042557..2df61d431 100644 --- a/tests/bot/cogs/moderation/infraction/test_infractions.py +++ b/tests/bot/cogs/moderation/infraction/test_infractions.py @@ -17,8 +17,8 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.guild = MockGuild(id=4567) self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - @patch("bot.cogs.moderation.infraction.utils.get_active_infraction") - @patch("bot.cogs.moderation.infraction.utils.post_infraction") + @patch("bot.cogs.moderation.infraction._utils.get_active_infraction") + @patch("bot.cogs.moderation.infraction._utils.post_infraction") async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" get_active_mock.return_value = None @@ -39,7 +39,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value ) - @patch("bot.cogs.moderation.infraction.utils.post_infraction") + @patch("bot.cogs.moderation.infraction._utils.post_infraction") async def test_apply_kick_reason_truncation(self, post_infraction_mock): """Should truncate reason for `Member.kick`.""" post_infraction_mock.return_value = {"foo": "bar"} -- cgit v1.2.3 From aaee0f86e99f8dfdc454c52516fbdf7f0030168a Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 12 Aug 2020 23:07:30 -0700 Subject: Fix ModLog imports Bunch of modules still rely on importing the cog directly from the moderation package. --- bot/cogs/filters/antispam.py | 2 +- bot/cogs/filters/filtering.py | 2 +- bot/cogs/filters/token_remover.py | 2 +- bot/cogs/moderation/defcon.py | 2 +- bot/cogs/moderation/verification.py | 2 +- bot/cogs/moderation/watchchannels/_watchchannel.py | 2 +- bot/cogs/utils/clean.py | 2 +- tests/bot/cogs/filters/test_token_remover.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/bot/cogs/filters/antispam.py b/bot/cogs/filters/antispam.py index 0bcca578d..d2dccea06 100644 --- a/bot/cogs/filters/antispam.py +++ b/bot/cogs/filters/antispam.py @@ -11,7 +11,7 @@ from discord.ext.commands import Cog from bot import rules from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import ( AntiSpam as AntiSpamConfig, Channels, Colours, DEBUG_MODE, Event, Filter, diff --git a/bot/cogs/filters/filtering.py b/bot/cogs/filters/filtering.py index 93cc1c655..556b466ef 100644 --- a/bot/cogs/filters/filtering.py +++ b/bot/cogs/filters/filtering.py @@ -12,7 +12,7 @@ from discord.ext.commands import Cog from discord.utils import escape_markdown from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import ( Channels, Colours, Filter, Icons, URLs diff --git a/bot/cogs/filters/token_remover.py b/bot/cogs/filters/token_remover.py index ef979f222..8eace07b6 100644 --- a/bot/cogs/filters/token_remover.py +++ b/bot/cogs/filters/token_remover.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog from bot import utils from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import Channels, Colours, Event, Icons log = logging.getLogger(__name__) diff --git a/bot/cogs/moderation/defcon.py b/bot/cogs/moderation/defcon.py index 4c0ad5914..e78435a7d 100644 --- a/bot/cogs/moderation/defcon.py +++ b/bot/cogs/moderation/defcon.py @@ -9,7 +9,7 @@ from discord import Colour, Embed, Member from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles from bot.decorators import with_role diff --git a/bot/cogs/moderation/verification.py b/bot/cogs/moderation/verification.py index ae156cf70..ba95ab5e4 100644 --- a/bot/cogs/moderation/verification.py +++ b/bot/cogs/moderation/verification.py @@ -6,7 +6,7 @@ from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.decorators import in_whitelist, without_role from bot.utils.checks import InWhitelistCheckFailure, without_role_check diff --git a/bot/cogs/moderation/watchchannels/_watchchannel.py b/bot/cogs/moderation/watchchannels/_watchchannel.py index 044077350..488ae704d 100644 --- a/bot/cogs/moderation/watchchannels/_watchchannel.py +++ b/bot/cogs/moderation/watchchannels/_watchchannel.py @@ -14,7 +14,7 @@ from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.pagination import LinePaginator from bot.utils import CogABCMeta, messages diff --git a/bot/cogs/utils/clean.py b/bot/cogs/utils/clean.py index f436e531a..c156ff02e 100644 --- a/bot/cogs/utils/clean.py +++ b/bot/cogs/utils/clean.py @@ -8,7 +8,7 @@ from discord.ext import commands from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import ( Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES ) diff --git a/tests/bot/cogs/filters/test_token_remover.py b/tests/bot/cogs/filters/test_token_remover.py index 5c527ed94..55b284ef9 100644 --- a/tests/bot/cogs/filters/test_token_remover.py +++ b/tests/bot/cogs/filters/test_token_remover.py @@ -8,7 +8,7 @@ from discord import Colour, NotFound from bot import constants from bot.cogs.filters import token_remover from bot.cogs.filters.token_remover import Token, TokenRemover -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from tests.helpers import MockBot, MockMessage, autospec -- cgit v1.2.3 From 1c2b384915f4a7ba070c95c86126746bae2f7279 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 14 Aug 2020 09:59:56 -0700 Subject: Rename "cogs" directory to "exts" The directory contains modules, which are extensions. It only indirectly contains cogs through the extensions. Therefore, a technically more accurate name is "extensions", or "exts" when abbreviated. Furthermore, "exts" is consistent with SeasonalBot. --- bot/__main__.py | 90 +- bot/cogs/__init__.py | 0 bot/cogs/alias.py | 153 ---- bot/cogs/backend/__init__.py | 0 bot/cogs/backend/config_verifier.py | 40 - bot/cogs/backend/error_handler.py | 287 ------- bot/cogs/backend/logging.py | 42 - bot/cogs/backend/sync/__init__.py | 7 - bot/cogs/backend/sync/_cog.py | 180 ---- bot/cogs/backend/sync/_syncers.py | 347 -------- bot/cogs/dm_relay.py | 124 --- bot/cogs/duck_pond.py | 166 ---- bot/cogs/filters/__init__.py | 0 bot/cogs/filters/antimalware.py | 98 --- bot/cogs/filters/antispam.py | 288 ------- bot/cogs/filters/filter_lists.py | 273 ------ bot/cogs/filters/filtering.py | 575 ------------- bot/cogs/filters/security.py | 31 - bot/cogs/filters/token_remover.py | 182 ---- bot/cogs/filters/webhook_remover.py | 84 -- bot/cogs/help_channels.py | 944 --------------------- bot/cogs/info/__init__.py | 0 bot/cogs/info/doc.py | 511 ----------- bot/cogs/info/help.py | 375 -------- bot/cogs/info/information.py | 422 --------- bot/cogs/info/python_news.py | 232 ----- bot/cogs/info/reddit.py | 304 ------- bot/cogs/info/site.py | 146 ---- bot/cogs/info/source.py | 141 --- bot/cogs/info/stats.py | 129 --- bot/cogs/info/tags.py | 277 ------ bot/cogs/info/wolfram.py | 280 ------ bot/cogs/moderation/__init__.py | 0 bot/cogs/moderation/defcon.py | 258 ------ bot/cogs/moderation/incidents.py | 412 --------- bot/cogs/moderation/infraction/__init__.py | 0 bot/cogs/moderation/infraction/_scheduler.py | 463 ---------- bot/cogs/moderation/infraction/_utils.py | 201 ----- bot/cogs/moderation/infraction/infractions.py | 375 -------- bot/cogs/moderation/infraction/management.py | 310 ------- bot/cogs/moderation/infraction/superstarify.py | 244 ------ bot/cogs/moderation/modlog.py | 837 ------------------ bot/cogs/moderation/silence.py | 170 ---- bot/cogs/moderation/slowmode.py | 97 --- bot/cogs/moderation/verification.py | 191 ----- bot/cogs/moderation/watchchannels/__init__.py | 0 bot/cogs/moderation/watchchannels/_watchchannel.py | 348 -------- bot/cogs/moderation/watchchannels/bigbrother.py | 170 ---- bot/cogs/moderation/watchchannels/talentpool.py | 269 ------ bot/cogs/off_topic_names.py | 162 ---- bot/cogs/utils/__init__.py | 0 bot/cogs/utils/bot.py | 385 --------- bot/cogs/utils/clean.py | 272 ------ bot/cogs/utils/eval.py | 202 ----- bot/cogs/utils/extensions.py | 289 ------- bot/cogs/utils/jams.py | 150 ---- bot/cogs/utils/reminders.py | 427 ---------- bot/cogs/utils/snekbox.py | 349 -------- bot/cogs/utils/utils.py | 265 ------ bot/exts/__init__.py | 0 bot/exts/alias.py | 153 ++++ bot/exts/backend/__init__.py | 0 bot/exts/backend/config_verifier.py | 40 + bot/exts/backend/error_handler.py | 287 +++++++ bot/exts/backend/logging.py | 42 + bot/exts/backend/sync/__init__.py | 7 + bot/exts/backend/sync/_cog.py | 180 ++++ bot/exts/backend/sync/_syncers.py | 347 ++++++++ bot/exts/dm_relay.py | 124 +++ bot/exts/duck_pond.py | 166 ++++ bot/exts/filters/__init__.py | 0 bot/exts/filters/antimalware.py | 98 +++ bot/exts/filters/antispam.py | 288 +++++++ bot/exts/filters/filter_lists.py | 273 ++++++ bot/exts/filters/filtering.py | 575 +++++++++++++ bot/exts/filters/security.py | 31 + bot/exts/filters/token_remover.py | 182 ++++ bot/exts/filters/webhook_remover.py | 84 ++ bot/exts/help_channels.py | 944 +++++++++++++++++++++ bot/exts/info/__init__.py | 0 bot/exts/info/doc.py | 511 +++++++++++ bot/exts/info/help.py | 375 ++++++++ bot/exts/info/information.py | 422 +++++++++ bot/exts/info/python_news.py | 232 +++++ bot/exts/info/reddit.py | 304 +++++++ bot/exts/info/site.py | 146 ++++ bot/exts/info/source.py | 141 +++ bot/exts/info/stats.py | 129 +++ bot/exts/info/tags.py | 277 ++++++ bot/exts/info/wolfram.py | 280 ++++++ bot/exts/moderation/__init__.py | 0 bot/exts/moderation/defcon.py | 258 ++++++ bot/exts/moderation/incidents.py | 412 +++++++++ bot/exts/moderation/infraction/__init__.py | 0 bot/exts/moderation/infraction/_scheduler.py | 463 ++++++++++ bot/exts/moderation/infraction/_utils.py | 201 +++++ bot/exts/moderation/infraction/infractions.py | 375 ++++++++ bot/exts/moderation/infraction/management.py | 310 +++++++ bot/exts/moderation/infraction/superstarify.py | 244 ++++++ bot/exts/moderation/modlog.py | 837 ++++++++++++++++++ bot/exts/moderation/silence.py | 170 ++++ bot/exts/moderation/slowmode.py | 97 +++ bot/exts/moderation/verification.py | 191 +++++ bot/exts/moderation/watchchannels/__init__.py | 0 bot/exts/moderation/watchchannels/_watchchannel.py | 348 ++++++++ bot/exts/moderation/watchchannels/bigbrother.py | 170 ++++ bot/exts/moderation/watchchannels/talentpool.py | 269 ++++++ bot/exts/off_topic_names.py | 162 ++++ bot/exts/utils/__init__.py | 0 bot/exts/utils/bot.py | 385 +++++++++ bot/exts/utils/clean.py | 272 ++++++ bot/exts/utils/eval.py | 202 +++++ bot/exts/utils/extensions.py | 289 +++++++ bot/exts/utils/jams.py | 150 ++++ bot/exts/utils/reminders.py | 427 ++++++++++ bot/exts/utils/snekbox.py | 349 ++++++++ bot/exts/utils/utils.py | 265 ++++++ tests/bot/cogs/__init__.py | 0 tests/bot/cogs/backend/__init__.py | 0 tests/bot/cogs/backend/sync/__init__.py | 0 tests/bot/cogs/backend/sync/test_base.py | 404 --------- tests/bot/cogs/backend/sync/test_cog.py | 416 --------- tests/bot/cogs/backend/sync/test_roles.py | 157 ---- tests/bot/cogs/backend/sync/test_users.py | 158 ---- tests/bot/cogs/backend/test_logging.py | 32 - tests/bot/cogs/filters/__init__.py | 0 tests/bot/cogs/filters/test_antimalware.py | 165 ---- tests/bot/cogs/filters/test_antispam.py | 35 - tests/bot/cogs/filters/test_security.py | 54 -- tests/bot/cogs/filters/test_token_remover.py | 310 ------- tests/bot/cogs/info/__init__.py | 0 tests/bot/cogs/info/test_information.py | 584 ------------- tests/bot/cogs/moderation/__init__.py | 0 tests/bot/cogs/moderation/infraction/__init__.py | 0 .../cogs/moderation/infraction/test_infractions.py | 55 -- tests/bot/cogs/moderation/test_incidents.py | 770 ----------------- tests/bot/cogs/moderation/test_modlog.py | 29 - tests/bot/cogs/moderation/test_silence.py | 261 ------ tests/bot/cogs/moderation/test_slowmode.py | 111 --- tests/bot/cogs/test_cogs.py | 80 -- tests/bot/cogs/test_duck_pond.py | 548 ------------ tests/bot/cogs/utils/__init__.py | 0 tests/bot/cogs/utils/test_jams.py | 173 ---- tests/bot/cogs/utils/test_snekbox.py | 409 --------- tests/bot/exts/__init__.py | 0 tests/bot/exts/backend/__init__.py | 0 tests/bot/exts/backend/sync/__init__.py | 0 tests/bot/exts/backend/sync/test_base.py | 404 +++++++++ tests/bot/exts/backend/sync/test_cog.py | 416 +++++++++ tests/bot/exts/backend/sync/test_roles.py | 157 ++++ tests/bot/exts/backend/sync/test_users.py | 158 ++++ tests/bot/exts/backend/test_logging.py | 32 + tests/bot/exts/filters/__init__.py | 0 tests/bot/exts/filters/test_antimalware.py | 165 ++++ tests/bot/exts/filters/test_antispam.py | 35 + tests/bot/exts/filters/test_security.py | 54 ++ tests/bot/exts/filters/test_token_remover.py | 310 +++++++ tests/bot/exts/info/__init__.py | 0 tests/bot/exts/info/test_information.py | 584 +++++++++++++ tests/bot/exts/moderation/__init__.py | 0 tests/bot/exts/moderation/infraction/__init__.py | 0 .../exts/moderation/infraction/test_infractions.py | 55 ++ tests/bot/exts/moderation/test_incidents.py | 770 +++++++++++++++++ tests/bot/exts/moderation/test_modlog.py | 29 + tests/bot/exts/moderation/test_silence.py | 261 ++++++ tests/bot/exts/moderation/test_slowmode.py | 111 +++ tests/bot/exts/test_cogs.py | 81 ++ tests/bot/exts/test_duck_pond.py | 548 ++++++++++++ tests/bot/exts/utils/__init__.py | 0 tests/bot/exts/utils/test_jams.py | 173 ++++ tests/bot/exts/utils/test_snekbox.py | 409 +++++++++ 171 files changed, 18281 insertions(+), 18280 deletions(-) delete mode 100644 bot/cogs/__init__.py delete mode 100644 bot/cogs/alias.py delete mode 100644 bot/cogs/backend/__init__.py delete mode 100644 bot/cogs/backend/config_verifier.py delete mode 100644 bot/cogs/backend/error_handler.py delete mode 100644 bot/cogs/backend/logging.py delete mode 100644 bot/cogs/backend/sync/__init__.py delete mode 100644 bot/cogs/backend/sync/_cog.py delete mode 100644 bot/cogs/backend/sync/_syncers.py delete mode 100644 bot/cogs/dm_relay.py delete mode 100644 bot/cogs/duck_pond.py delete mode 100644 bot/cogs/filters/__init__.py delete mode 100644 bot/cogs/filters/antimalware.py delete mode 100644 bot/cogs/filters/antispam.py delete mode 100644 bot/cogs/filters/filter_lists.py delete mode 100644 bot/cogs/filters/filtering.py delete mode 100644 bot/cogs/filters/security.py delete mode 100644 bot/cogs/filters/token_remover.py delete mode 100644 bot/cogs/filters/webhook_remover.py delete mode 100644 bot/cogs/help_channels.py delete mode 100644 bot/cogs/info/__init__.py delete mode 100644 bot/cogs/info/doc.py delete mode 100644 bot/cogs/info/help.py delete mode 100644 bot/cogs/info/information.py delete mode 100644 bot/cogs/info/python_news.py delete mode 100644 bot/cogs/info/reddit.py delete mode 100644 bot/cogs/info/site.py delete mode 100644 bot/cogs/info/source.py delete mode 100644 bot/cogs/info/stats.py delete mode 100644 bot/cogs/info/tags.py delete mode 100644 bot/cogs/info/wolfram.py delete mode 100644 bot/cogs/moderation/__init__.py delete mode 100644 bot/cogs/moderation/defcon.py delete mode 100644 bot/cogs/moderation/incidents.py delete mode 100644 bot/cogs/moderation/infraction/__init__.py delete mode 100644 bot/cogs/moderation/infraction/_scheduler.py delete mode 100644 bot/cogs/moderation/infraction/_utils.py delete mode 100644 bot/cogs/moderation/infraction/infractions.py delete mode 100644 bot/cogs/moderation/infraction/management.py delete mode 100644 bot/cogs/moderation/infraction/superstarify.py delete mode 100644 bot/cogs/moderation/modlog.py delete mode 100644 bot/cogs/moderation/silence.py delete mode 100644 bot/cogs/moderation/slowmode.py delete mode 100644 bot/cogs/moderation/verification.py delete mode 100644 bot/cogs/moderation/watchchannels/__init__.py delete mode 100644 bot/cogs/moderation/watchchannels/_watchchannel.py delete mode 100644 bot/cogs/moderation/watchchannels/bigbrother.py delete mode 100644 bot/cogs/moderation/watchchannels/talentpool.py delete mode 100644 bot/cogs/off_topic_names.py delete mode 100644 bot/cogs/utils/__init__.py delete mode 100644 bot/cogs/utils/bot.py delete mode 100644 bot/cogs/utils/clean.py delete mode 100644 bot/cogs/utils/eval.py delete mode 100644 bot/cogs/utils/extensions.py delete mode 100644 bot/cogs/utils/jams.py delete mode 100644 bot/cogs/utils/reminders.py delete mode 100644 bot/cogs/utils/snekbox.py delete mode 100644 bot/cogs/utils/utils.py create mode 100644 bot/exts/__init__.py create mode 100644 bot/exts/alias.py create mode 100644 bot/exts/backend/__init__.py create mode 100644 bot/exts/backend/config_verifier.py create mode 100644 bot/exts/backend/error_handler.py create mode 100644 bot/exts/backend/logging.py create mode 100644 bot/exts/backend/sync/__init__.py create mode 100644 bot/exts/backend/sync/_cog.py create mode 100644 bot/exts/backend/sync/_syncers.py create mode 100644 bot/exts/dm_relay.py create mode 100644 bot/exts/duck_pond.py create mode 100644 bot/exts/filters/__init__.py create mode 100644 bot/exts/filters/antimalware.py create mode 100644 bot/exts/filters/antispam.py create mode 100644 bot/exts/filters/filter_lists.py create mode 100644 bot/exts/filters/filtering.py create mode 100644 bot/exts/filters/security.py create mode 100644 bot/exts/filters/token_remover.py create mode 100644 bot/exts/filters/webhook_remover.py create mode 100644 bot/exts/help_channels.py create mode 100644 bot/exts/info/__init__.py create mode 100644 bot/exts/info/doc.py create mode 100644 bot/exts/info/help.py create mode 100644 bot/exts/info/information.py create mode 100644 bot/exts/info/python_news.py create mode 100644 bot/exts/info/reddit.py create mode 100644 bot/exts/info/site.py create mode 100644 bot/exts/info/source.py create mode 100644 bot/exts/info/stats.py create mode 100644 bot/exts/info/tags.py create mode 100644 bot/exts/info/wolfram.py create mode 100644 bot/exts/moderation/__init__.py create mode 100644 bot/exts/moderation/defcon.py create mode 100644 bot/exts/moderation/incidents.py create mode 100644 bot/exts/moderation/infraction/__init__.py create mode 100644 bot/exts/moderation/infraction/_scheduler.py create mode 100644 bot/exts/moderation/infraction/_utils.py create mode 100644 bot/exts/moderation/infraction/infractions.py create mode 100644 bot/exts/moderation/infraction/management.py create mode 100644 bot/exts/moderation/infraction/superstarify.py create mode 100644 bot/exts/moderation/modlog.py create mode 100644 bot/exts/moderation/silence.py create mode 100644 bot/exts/moderation/slowmode.py create mode 100644 bot/exts/moderation/verification.py create mode 100644 bot/exts/moderation/watchchannels/__init__.py create mode 100644 bot/exts/moderation/watchchannels/_watchchannel.py create mode 100644 bot/exts/moderation/watchchannels/bigbrother.py create mode 100644 bot/exts/moderation/watchchannels/talentpool.py create mode 100644 bot/exts/off_topic_names.py create mode 100644 bot/exts/utils/__init__.py create mode 100644 bot/exts/utils/bot.py create mode 100644 bot/exts/utils/clean.py create mode 100644 bot/exts/utils/eval.py create mode 100644 bot/exts/utils/extensions.py create mode 100644 bot/exts/utils/jams.py create mode 100644 bot/exts/utils/reminders.py create mode 100644 bot/exts/utils/snekbox.py create mode 100644 bot/exts/utils/utils.py delete mode 100644 tests/bot/cogs/__init__.py delete mode 100644 tests/bot/cogs/backend/__init__.py delete mode 100644 tests/bot/cogs/backend/sync/__init__.py delete mode 100644 tests/bot/cogs/backend/sync/test_base.py delete mode 100644 tests/bot/cogs/backend/sync/test_cog.py delete mode 100644 tests/bot/cogs/backend/sync/test_roles.py delete mode 100644 tests/bot/cogs/backend/sync/test_users.py delete mode 100644 tests/bot/cogs/backend/test_logging.py delete mode 100644 tests/bot/cogs/filters/__init__.py delete mode 100644 tests/bot/cogs/filters/test_antimalware.py delete mode 100644 tests/bot/cogs/filters/test_antispam.py delete mode 100644 tests/bot/cogs/filters/test_security.py delete mode 100644 tests/bot/cogs/filters/test_token_remover.py delete mode 100644 tests/bot/cogs/info/__init__.py delete mode 100644 tests/bot/cogs/info/test_information.py delete mode 100644 tests/bot/cogs/moderation/__init__.py delete mode 100644 tests/bot/cogs/moderation/infraction/__init__.py delete mode 100644 tests/bot/cogs/moderation/infraction/test_infractions.py delete mode 100644 tests/bot/cogs/moderation/test_incidents.py delete mode 100644 tests/bot/cogs/moderation/test_modlog.py delete mode 100644 tests/bot/cogs/moderation/test_silence.py delete mode 100644 tests/bot/cogs/moderation/test_slowmode.py delete mode 100644 tests/bot/cogs/test_cogs.py delete mode 100644 tests/bot/cogs/test_duck_pond.py delete mode 100644 tests/bot/cogs/utils/__init__.py delete mode 100644 tests/bot/cogs/utils/test_jams.py delete mode 100644 tests/bot/cogs/utils/test_snekbox.py create mode 100644 tests/bot/exts/__init__.py create mode 100644 tests/bot/exts/backend/__init__.py create mode 100644 tests/bot/exts/backend/sync/__init__.py create mode 100644 tests/bot/exts/backend/sync/test_base.py create mode 100644 tests/bot/exts/backend/sync/test_cog.py create mode 100644 tests/bot/exts/backend/sync/test_roles.py create mode 100644 tests/bot/exts/backend/sync/test_users.py create mode 100644 tests/bot/exts/backend/test_logging.py create mode 100644 tests/bot/exts/filters/__init__.py create mode 100644 tests/bot/exts/filters/test_antimalware.py create mode 100644 tests/bot/exts/filters/test_antispam.py create mode 100644 tests/bot/exts/filters/test_security.py create mode 100644 tests/bot/exts/filters/test_token_remover.py create mode 100644 tests/bot/exts/info/__init__.py create mode 100644 tests/bot/exts/info/test_information.py create mode 100644 tests/bot/exts/moderation/__init__.py create mode 100644 tests/bot/exts/moderation/infraction/__init__.py create mode 100644 tests/bot/exts/moderation/infraction/test_infractions.py create mode 100644 tests/bot/exts/moderation/test_incidents.py create mode 100644 tests/bot/exts/moderation/test_modlog.py create mode 100644 tests/bot/exts/moderation/test_silence.py create mode 100644 tests/bot/exts/moderation/test_slowmode.py create mode 100644 tests/bot/exts/test_cogs.py create mode 100644 tests/bot/exts/test_duck_pond.py create mode 100644 tests/bot/exts/utils/__init__.py create mode 100644 tests/bot/exts/utils/test_jams.py create mode 100644 tests/bot/exts/utils/test_snekbox.py (limited to 'tests') diff --git a/bot/__main__.py b/bot/__main__.py index 4b0f6dfe4..555847357 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -34,67 +34,67 @@ bot = Bot( ) # Backend -bot.load_extension("bot.cogs.backend.config_verifier") -bot.load_extension("bot.cogs.backend.error_handler") -bot.load_extension("bot.cogs.backend.logging") -bot.load_extension("bot.cogs.backend.sync") +bot.load_extension("bot.exts.backend.config_verifier") +bot.load_extension("bot.exts.backend.error_handler") +bot.load_extension("bot.exts.backend.logging") +bot.load_extension("bot.exts.backend.sync") # Filters -bot.load_extension("bot.cogs.filters.antimalware") -bot.load_extension("bot.cogs.filters.antispam") -bot.load_extension("bot.cogs.filters.filter_lists") -bot.load_extension("bot.cogs.filters.filtering") -bot.load_extension("bot.cogs.filters.security") -bot.load_extension("bot.cogs.filters.token_remover") -bot.load_extension("bot.cogs.filters.webhook_remover") +bot.load_extension("bot.exts.filters.antimalware") +bot.load_extension("bot.exts.filters.antispam") +bot.load_extension("bot.exts.filters.filter_lists") +bot.load_extension("bot.exts.filters.filtering") +bot.load_extension("bot.exts.filters.security") +bot.load_extension("bot.exts.filters.token_remover") +bot.load_extension("bot.exts.filters.webhook_remover") # Info -bot.load_extension("bot.cogs.info.doc") -bot.load_extension("bot.cogs.info.help") -bot.load_extension("bot.cogs.info.information") -bot.load_extension("bot.cogs.info.python_news") -bot.load_extension("bot.cogs.info.reddit") -bot.load_extension("bot.cogs.info.site") -bot.load_extension("bot.cogs.info.source") -bot.load_extension("bot.cogs.info.stats") -bot.load_extension("bot.cogs.info.tags") -bot.load_extension("bot.cogs.info.wolfram") +bot.load_extension("bot.exts.info.doc") +bot.load_extension("bot.exts.info.help") +bot.load_extension("bot.exts.info.information") +bot.load_extension("bot.exts.info.python_news") +bot.load_extension("bot.exts.info.reddit") +bot.load_extension("bot.exts.info.site") +bot.load_extension("bot.exts.info.source") +bot.load_extension("bot.exts.info.stats") +bot.load_extension("bot.exts.info.tags") +bot.load_extension("bot.exts.info.wolfram") # Moderation -bot.load_extension("bot.cogs.moderation.defcon") -bot.load_extension("bot.cogs.moderation.incidents") -bot.load_extension("bot.cogs.moderation.modlog") -bot.load_extension("bot.cogs.moderation.silence") -bot.load_extension("bot.cogs.moderation.slowmode") -bot.load_extension("bot.cogs.moderation.verification") +bot.load_extension("bot.exts.moderation.defcon") +bot.load_extension("bot.exts.moderation.incidents") +bot.load_extension("bot.exts.moderation.modlog") +bot.load_extension("bot.exts.moderation.silence") +bot.load_extension("bot.exts.moderation.slowmode") +bot.load_extension("bot.exts.moderation.verification") # Moderation - Infraction -bot.load_extension("bot.cogs.moderation.infraction.infractions") -bot.load_extension("bot.cogs.moderation.infraction.management") -bot.load_extension("bot.cogs.moderation.infraction.superstarify") +bot.load_extension("bot.exts.moderation.infraction.infractions") +bot.load_extension("bot.exts.moderation.infraction.management") +bot.load_extension("bot.exts.moderation.infraction.superstarify") # Moderation - Watchchannels -bot.load_extension("bot.cogs.moderation.watchchannels.bigbrother") -bot.load_extension("bot.cogs.moderation.watchchannels.talentpool") +bot.load_extension("bot.exts.moderation.watchchannels.bigbrother") +bot.load_extension("bot.exts.moderation.watchchannels.talentpool") # Utils -bot.load_extension("bot.cogs.utils.bot") -bot.load_extension("bot.cogs.utils.clean") -bot.load_extension("bot.cogs.utils.eval") -bot.load_extension("bot.cogs.utils.extensions") -bot.load_extension("bot.cogs.utils.jams") -bot.load_extension("bot.cogs.utils.reminders") -bot.load_extension("bot.cogs.utils.snekbox") -bot.load_extension("bot.cogs.utils.utils") +bot.load_extension("bot.exts.utils.bot") +bot.load_extension("bot.exts.utils.clean") +bot.load_extension("bot.exts.utils.eval") +bot.load_extension("bot.exts.utils.extensions") +bot.load_extension("bot.exts.utils.jams") +bot.load_extension("bot.exts.utils.reminders") +bot.load_extension("bot.exts.utils.snekbox") +bot.load_extension("bot.exts.utils.utils") # Misc -bot.load_extension("bot.cogs.alias") -bot.load_extension("bot.cogs.dm_relay") -bot.load_extension("bot.cogs.duck_pond") -bot.load_extension("bot.cogs.off_topic_names") +bot.load_extension("bot.exts.alias") +bot.load_extension("bot.exts.dm_relay") +bot.load_extension("bot.exts.duck_pond") +bot.load_extension("bot.exts.off_topic_names") if constants.HelpChannels.enable: - bot.load_extension("bot.cogs.help_channels") + bot.load_extension("bot.exts.help_channels") # Apply `message_edited_at` patch if discord.py did not yet release a bug fix. if not hasattr(discord.message.Message, '_handle_edited_timestamp'): diff --git a/bot/cogs/__init__.py b/bot/cogs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py deleted file mode 100644 index 3c5a35c24..000000000 --- a/bot/cogs/alias.py +++ /dev/null @@ -1,153 +0,0 @@ -import inspect -import logging - -from discord import Colour, Embed -from discord.ext.commands import ( - Cog, Command, Context, Greedy, - clean_content, command, group, -) - -from bot.bot import Bot -from bot.cogs.utils.extensions import Extension -from bot.converters import FetchedMember, TagNameConverter -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - - -class Alias (Cog): - """Aliases for commonly used commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: - """Invokes a command with args and kwargs.""" - log.debug(f"{cmd_name} was invoked through an alias") - cmd = self.bot.get_command(cmd_name) - if not cmd: - return log.info(f'Did not find command "{cmd_name}" to invoke.') - elif not await cmd.can_run(ctx): - return log.info( - f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' - ) - - await ctx.invoke(cmd, *args, **kwargs) - - @command(name='aliases') - async def aliases_command(self, ctx: Context) -> None: - """Show configured aliases on the bot.""" - embed = Embed( - title='Configured aliases', - colour=Colour.blue() - ) - await LinePaginator.paginate( - ( - f"• `{ctx.prefix}{value.name}` " - f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" - for name, value in inspect.getmembers(self) - if isinstance(value, Command) and name.endswith('_alias') - ), - ctx, embed, empty=False, max_lines=20 - ) - - @command(name="resources", aliases=("resource",), hidden=True) - async def site_resources_alias(self, ctx: Context) -> None: - """Alias for invoking site resources.""" - await self.invoke(ctx, "site resources") - - @command(name="tools", hidden=True) - async def site_tools_alias(self, ctx: Context) -> None: - """Alias for invoking site tools.""" - await self.invoke(ctx, "site tools") - - @command(name="watch", hidden=True) - async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother watch [user] [reason].""" - await self.invoke(ctx, "bigbrother watch", user, reason=reason) - - @command(name="unwatch", hidden=True) - async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother unwatch [user] [reason].""" - await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) - - @command(name="home", hidden=True) - async def site_home_alias(self, ctx: Context) -> None: - """Alias for invoking site home.""" - await self.invoke(ctx, "site home") - - @command(name="faq", hidden=True) - async def site_faq_alias(self, ctx: Context) -> None: - """Alias for invoking site faq.""" - await self.invoke(ctx, "site faq") - - @command(name="rules", aliases=("rule",), hidden=True) - async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: - """Alias for invoking site rules.""" - await self.invoke(ctx, "site rules", *rules) - - @command(name="reload", hidden=True) - async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: - """Alias for invoking extensions reload [extensions...].""" - await self.invoke(ctx, "extensions reload", *extensions) - - @command(name="defon", hidden=True) - async def defcon_enable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon enable.""" - await self.invoke(ctx, "defcon enable") - - @command(name="defoff", hidden=True) - async def defcon_disable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon disable.""" - await self.invoke(ctx, "defcon disable") - - @command(name="exception", hidden=True) - async def tags_get_traceback_alias(self, ctx: Context) -> None: - """Alias for invoking tags get traceback.""" - await self.invoke(ctx, "tags get", tag_name="traceback") - - @group(name="get", - aliases=("show", "g"), - hidden=True, - invoke_without_command=True) - async def get_group_alias(self, ctx: Context) -> None: - """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" - pass - - @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) - async def tags_get_alias( - self, ctx: Context, *, tag_name: TagNameConverter = None - ) -> None: - """ - Alias for invoking tags get [tag_name]. - - tag_name: str - tag to be viewed. - """ - await self.invoke(ctx, "tags get", tag_name=tag_name) - - @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) - async def docs_get_alias( - self, ctx: Context, symbol: clean_content = None - ) -> None: - """Alias for invoking docs get [symbol].""" - await self.invoke(ctx, "docs get", symbol) - - @command(name="nominate", hidden=True) - async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking talentpool add [user] [reason].""" - await self.invoke(ctx, "talentpool add", user, reason=reason) - - @command(name="unnominate", hidden=True) - async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking nomination end [user] [reason].""" - await self.invoke(ctx, "nomination end", user, reason=reason) - - @command(name="nominees", hidden=True) - async def nominees_alias(self, ctx: Context) -> None: - """Alias for invoking tp watched.""" - await self.invoke(ctx, "talentpool watched") - - -def setup(bot: Bot) -> None: - """Load the Alias cog.""" - bot.add_cog(Alias(bot)) diff --git a/bot/cogs/backend/__init__.py b/bot/cogs/backend/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/backend/config_verifier.py b/bot/cogs/backend/config_verifier.py deleted file mode 100644 index d72c6c22e..000000000 --- a/bot/cogs/backend/config_verifier.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging - -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot - - -log = logging.getLogger(__name__) - - -class ConfigVerifier(Cog): - """Verify config on startup.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) - - async def verify_channels(self) -> None: - """ - Verify channels. - - If any channels in config aren't present in server, log them in a warning. - """ - await self.bot.wait_until_guild_available() - server = self.bot.get_guild(constants.Guild.id) - - server_channel_ids = {channel.id for channel in server.channels} - invalid_channels = [ - channel_name for channel_name, channel_id in constants.Channels - if channel_id not in server_channel_ids - ] - - if invalid_channels: - log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") - - -def setup(bot: Bot) -> None: - """Load the ConfigVerifier cog.""" - bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/backend/error_handler.py b/bot/cogs/backend/error_handler.py deleted file mode 100644 index f9d4de638..000000000 --- a/bot/cogs/backend/error_handler.py +++ /dev/null @@ -1,287 +0,0 @@ -import contextlib -import logging -import typing as t - -from discord import Embed -from discord.ext.commands import Cog, Context, errors -from sentry_sdk import push_scope - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Colours -from bot.converters import TagNameConverter -from bot.utils.checks import InWhitelistCheckFailure - -log = logging.getLogger(__name__) - - -class ErrorHandler(Cog): - """Handles errors emitted from commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_error_embed(self, title: str, body: str) -> Embed: - """Return an embed that contains the exception.""" - return Embed( - title=title, - colour=Colours.soft_red, - description=body - ) - - @Cog.listener() - async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: - """ - Provide generic command error handling. - - Error handling is deferred to any local error handler, if present. This is done by - checking for the presence of a `handled` attribute on the error. - - Error handling emits a single error message in the invoking context `ctx` and a log message, - prioritised as follows: - - 1. If the name fails to match a command: - * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. - Otherwise if it matches a tag, the tag is invoked - * If CommandNotFound is raised when invoking the tag (determined by the presence of the - `invoked_from_error_handler` attribute), this error is treated as being unexpected - and therefore sends an error message - * Commands in the verification channel are ignored - 2. UserInputError: see `handle_user_input_error` - 3. CheckFailure: see `handle_check_failure` - 4. CommandOnCooldown: send an error message in the invoking context - 5. ResponseCodeError: see `handle_api_error` - 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` - """ - command = ctx.command - - if hasattr(e, "handled"): - log.trace(f"Command {command} had its error already handled locally; ignoring.") - return - - if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): - if await self.try_silence(ctx): - return - if ctx.channel.id != Channels.verification: - # Try to look for a tag with the command's name - await self.try_get_tag(ctx) - return # Exit early to avoid logging. - elif isinstance(e, errors.UserInputError): - await self.handle_user_input_error(ctx, e) - elif isinstance(e, errors.CheckFailure): - await self.handle_check_failure(ctx, e) - elif isinstance(e, errors.CommandOnCooldown): - await ctx.send(e) - elif isinstance(e, errors.CommandInvokeError): - if isinstance(e.original, ResponseCodeError): - await self.handle_api_error(ctx, e.original) - else: - await self.handle_unexpected_error(ctx, e.original) - return # Exit early to avoid logging. - elif not isinstance(e, errors.DisabledCommand): - # ConversionError, MaxConcurrencyReached, ExtensionError - await self.handle_unexpected_error(ctx, e) - return # Exit early to avoid logging. - - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - - @staticmethod - def get_help_command(ctx: Context) -> t.Coroutine: - """Return a prepared `help` command invocation coroutine.""" - if ctx.command: - return ctx.send_help(ctx.command) - - return ctx.send_help() - - async def try_silence(self, ctx: Context) -> bool: - """ - Attempt to invoke the silence or unsilence command if invoke with matches a pattern. - - Respecting the checks if: - * invoked with `shh+` silence channel for amount of h's*2 with max of 15. - * invoked with `unshh+` unsilence channel - Return bool depending on success of command. - """ - command = ctx.invoked_with.lower() - silence_command = self.bot.get_command("silence") - ctx.invoked_from_error_handler = True - try: - if not await silence_command.can_run(ctx): - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - except errors.CommandError: - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - if command.startswith("shh"): - await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) - return True - elif command.startswith("unshh"): - await ctx.invoke(self.bot.get_command("unsilence")) - return True - return False - - async def try_get_tag(self, ctx: Context) -> None: - """ - Attempt to display a tag by interpreting the command name as a tag name. - - The invocation of tags get respects its checks. Any CommandErrors raised will be handled - by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to - the context to prevent infinite recursion in the case of a CommandNotFound exception. - """ - tags_get_command = self.bot.get_command("tags get") - ctx.invoked_from_error_handler = True - - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." - try: - if not await tags_get_command.can_run(ctx): - log.debug(log_msg) - return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - - try: - tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) - except errors.BadArgument: - log.debug( - f"{ctx.author} tried to use an invalid command " - f"and the fallback tag failed validation in TagNameConverter." - ) - else: - with contextlib.suppress(ResponseCodeError): - await ctx.invoke(tags_get_command, tag_name=tag_name) - # Return to not raise the exception - return - - async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: - """ - Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. - - * MissingRequiredArgument: send an error message with arg name and the help command - * TooManyArguments: send an error message and the help command - * BadArgument: send an error message and the help command - * BadUnionArgument: send an error message including the error produced by the last converter - * ArgumentParsingError: send an error message - * Other: send an error message and the help command - """ - prepared_help_command = self.get_help_command(ctx) - - if isinstance(e, errors.MissingRequiredArgument): - embed = self._get_error_embed("Missing required argument", e.param.name) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.missing_required_argument") - elif isinstance(e, errors.TooManyArguments): - embed = self._get_error_embed("Too many arguments", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.too_many_arguments") - elif isinstance(e, errors.BadArgument): - embed = self._get_error_embed("Bad argument", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.bad_argument") - elif isinstance(e, errors.BadUnionArgument): - embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") - await ctx.send(embed=embed) - self.bot.stats.incr("errors.bad_union_argument") - elif isinstance(e, errors.ArgumentParsingError): - embed = self._get_error_embed("Argument parsing error", str(e)) - await ctx.send(embed=embed) - self.bot.stats.incr("errors.argument_parsing_error") - else: - embed = self._get_error_embed( - "Input error", - "Something about your input seems off. Check the arguments and try again." - ) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.other_user_input_error") - - @staticmethod - async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: - """ - Send an error message in `ctx` for certain types of CheckFailure. - - The following types are handled: - - * BotMissingPermissions - * BotMissingRole - * BotMissingAnyRole - * NoPrivateMessage - * InWhitelistCheckFailure - """ - bot_missing_errors = ( - errors.BotMissingPermissions, - errors.BotMissingRole, - errors.BotMissingAnyRole - ) - - if isinstance(e, bot_missing_errors): - ctx.bot.stats.incr("errors.bot_permission_error") - await ctx.send( - "Sorry, it looks like I don't have the permissions or roles I need to do that." - ) - elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): - ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") - await ctx.send(e) - - @staticmethod - async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: - """Send an error message in `ctx` for ResponseCodeError and log it.""" - if e.status == 404: - await ctx.send("There does not seem to be anything matching your query.") - log.debug(f"API responded with 404 for command {ctx.command}") - ctx.bot.stats.incr("errors.api_error_404") - elif e.status == 400: - content = await e.response.json() - log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) - await ctx.send("According to the API, your request is malformed.") - ctx.bot.stats.incr("errors.api_error_400") - elif 500 <= e.status < 600: - await ctx.send("Sorry, there seems to be an internal issue with the API.") - log.warning(f"API responded with {e.status} for command {ctx.command}") - ctx.bot.stats.incr("errors.api_internal_server_error") - else: - await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") - log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") - ctx.bot.stats.incr(f"errors.api_error_{e.status}") - - @staticmethod - async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: - """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" - await ctx.send( - f"Sorry, an unexpected error occurred. Please let us know!\n\n" - f"```{e.__class__.__name__}: {e}```" - ) - - ctx.bot.stats.incr("errors.unexpected") - - with push_scope() as scope: - scope.user = { - "id": ctx.author.id, - "username": str(ctx.author) - } - - scope.set_tag("command", ctx.command.qualified_name) - scope.set_tag("message_id", ctx.message.id) - scope.set_tag("channel_id", ctx.channel.id) - - scope.set_extra("full_message", ctx.message.content) - - if ctx.guild is not None: - scope.set_extra( - "jump_to", - f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" - ) - - log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) - - -def setup(bot: Bot) -> None: - """Load the ErrorHandler cog.""" - bot.add_cog(ErrorHandler(bot)) diff --git a/bot/cogs/backend/logging.py b/bot/cogs/backend/logging.py deleted file mode 100644 index 94fa2b139..000000000 --- a/bot/cogs/backend/logging.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging - -from discord import Embed -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, DEBUG_MODE - - -log = logging.getLogger(__name__) - - -class Logging(Cog): - """Debug logging module.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.bot.loop.create_task(self.startup_greeting()) - - async def startup_greeting(self) -> None: - """Announce our presence to the configured devlog channel.""" - await self.bot.wait_until_guild_available() - log.info("Bot connected!") - - embed = Embed(description="Connected!") - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=( - "https://raw.githubusercontent.com/" - "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" - ) - ) - - if not DEBUG_MODE: - await self.bot.get_channel(Channels.dev_log).send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the Logging cog.""" - bot.add_cog(Logging(bot)) diff --git a/bot/cogs/backend/sync/__init__.py b/bot/cogs/backend/sync/__init__.py deleted file mode 100644 index 2541beaa8..000000000 --- a/bot/cogs/backend/sync/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from bot.bot import Bot - - -def setup(bot: Bot) -> None: - """Load the Sync cog.""" - from ._cog import Sync - bot.add_cog(Sync(bot)) diff --git a/bot/cogs/backend/sync/_cog.py b/bot/cogs/backend/sync/_cog.py deleted file mode 100644 index b6068f328..000000000 --- a/bot/cogs/backend/sync/_cog.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from typing import Any, Dict - -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from . import _syncers - -log = logging.getLogger(__name__) - - -class Sync(Cog): - """Captures relevant events and sends them to the site.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.role_syncer = _syncers.RoleSyncer(self.bot) - self.user_syncer = _syncers.UserSyncer(self.bot) - - self.bot.loop.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(constants.Guild.id) - if guild is None: - return - - for syncer in (self.role_syncer, self.user_syncer): - await syncer.sync(guild) - - async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: - """Send a PATCH request to partially update a user in the database.""" - try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) - except ResponseCodeError as e: - if e.response.status != 404: - raise - if not ignore_404: - log.warning("Unable to update user, got 404. Assuming race condition from join event.") - - @Cog.listener() - async def on_guild_role_create(self, role: Role) -> None: - """Adds newly create role to the database table over the API.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.post( - 'bot/roles', - json={ - 'colour': role.colour.value, - 'id': role.id, - 'name': role.name, - 'permissions': role.permissions.value, - 'position': role.position, - } - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: Role) -> None: - """Deletes role from the database when it's deleted from the guild.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - @Cog.listener() - async def on_guild_role_update(self, before: Role, after: Role) -> None: - """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild.id != constants.Guild.id: - return - - was_updated = ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ) - - if was_updated: - await self.bot.api_client.put( - f'bot/roles/{after.id}', - json={ - 'colour': after.colour.value, - 'id': after.id, - 'name': after.name, - 'permissions': after.permissions.value, - 'position': after.position, - } - ) - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """ - Adds a new user or updates existing user to the database when a member joins the guild. - - If the joining member is a user that is already known to the database (i.e., a user that - previously left), it will update the user's information. If the user is not yet known by - the database, the user is added. - """ - if member.guild.id != constants.Guild.id: - return - - packed = { - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': True, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - - got_error = False - - try: - # First try an update of the user to set the `in_guild` field and other - # fields that may have changed since the last time we've seen them. - await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) - - except ResponseCodeError as e: - # If we didn't get 404, something else broke - propagate it up. - if e.response.status != 404: - raise - - got_error = True # yikes - - if got_error: - # If we got `404`, the user is new. Create them. - await self.bot.api_client.post('bot/users', json=packed) - - @Cog.listener() - async def on_member_remove(self, member: Member) -> None: - """Set the in_guild field to False when a member leaves the guild.""" - if member.guild.id != constants.Guild.id: - return - - await self.patch_user(member.id, json={"in_guild": False}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Update the roles of the member in the database if a change is detected.""" - if after.guild.id != constants.Guild.id: - return - - if before.roles != after.roles: - updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, json=updated_information) - - @Cog.listener() - async def on_user_update(self, before: User, after: User) -> None: - """Update the user information in the database if a relevant change is detected.""" - attrs = ("name", "discriminator") - if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): - updated_information = { - "name": after.name, - "discriminator": int(after.discriminator), - } - # A 404 likely means the user is in another guild. - await self.patch_user(after.id, json=updated_information, ignore_404=True) - - @commands.group(name='sync') - @commands.has_permissions(administrator=True) - async def sync_group(self, ctx: Context) -> None: - """Run synchronizations between the bot and site manually.""" - - @sync_group.command(name='roles') - @commands.has_permissions(administrator=True) - async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) - - @sync_group.command(name='users') - @commands.has_permissions(administrator=True) - async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/_syncers.py b/bot/cogs/backend/sync/_syncers.py deleted file mode 100644 index f7ba811bc..000000000 --- a/bot/cogs/backend/sync/_syncers.py +++ /dev/null @@ -1,347 +0,0 @@ -import abc -import asyncio -import logging -import typing as t -from collections import namedtuple -from functools import partial - -import discord -from discord import Guild, HTTPException, Member, Message, Reaction, User -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# These objects are declared as namedtuples because tuples are hashable, -# something that we make use of when diffing site roles against guild roles. -_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) -_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - - -class Syncer(abc.ABC): - """Base class for synchronising the database with objects in the Discord cache.""" - - _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " - _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @property - @abc.abstractmethod - def name(self) -> str: - """The name of the syncer; used in output messages and logging.""" - raise NotImplementedError # pragma: no cover - - async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: - """ - Send a prompt to confirm or abort a sync using reactions and return the sent message. - - If a message is given, it is edited to display the prompt and reactions. Otherwise, a new - message is sent to the dev-core channel and mentions the core developers role. If the - channel cannot be retrieved, return None. - """ - log.trace(f"Sending {self.name} sync confirmation prompt.") - - msg_content = ( - f'Possible cache issue while syncing {self.name}s. ' - f'More than {constants.Sync.max_diff} {self.name}s were changed. ' - f'React to confirm or abort the sync.' - ) - - # Send to core developers if it's an automatic sync. - if not message: - log.trace("Message not provided for confirmation; creating a new one in dev-core.") - channel = self.bot.get_channel(constants.Channels.dev_core) - - if not channel: - log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") - try: - channel = await self.bot.fetch_channel(constants.Channels.dev_core) - except HTTPException: - log.exception( - f"Failed to fetch channel for sending sync confirmation prompt; " - f"aborting {self.name} sync." - ) - return None - - allowed_roles = [discord.Object(constants.Roles.core_developers)] - message = await channel.send( - f"{self._CORE_DEV_MENTION}{msg_content}", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - else: - await message.edit(content=msg_content) - - # Add the initial reactions. - log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") - for emoji in self._REACTION_EMOJIS: - await message.add_reaction(emoji) - - return message - - def _reaction_check( - self, - author: Member, - message: Message, - reaction: Reaction, - user: t.Union[Member, User] - ) -> bool: - """ - Return True if the `reaction` is a valid confirmation or abort reaction on `message`. - - If the `author` of the prompt is a bot, then a reaction by any core developer will be - considered valid. Otherwise, the author of the reaction (`user`) will have to be the - `author` of the prompt. - """ - # For automatic syncs, check for the core dev role instead of an exact author - has_role = any(constants.Roles.core_developers == role.id for role in user.roles) - return ( - reaction.message.id == message.id - and not user.bot - and (has_role if author.bot else user == author) - and str(reaction.emoji) in self._REACTION_EMOJIS - ) - - async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: - """ - Wait for a confirmation reaction by `author` on `message` and return True if confirmed. - - Uses the `_reaction_check` function to determine if a reaction is valid. - - If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. - To acknowledge the reaction (or lack thereof), `message` will be edited. - """ - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - reaction = None - try: - log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") - reaction, _ = await self.bot.wait_for( - 'reaction_add', - check=partial(self._reaction_check, author, message), - timeout=constants.Sync.confirm_timeout - ) - except asyncio.TimeoutError: - # reaction will remain none thus sync will be aborted in the finally block below. - log.debug(f"The {self.name} syncer confirmation prompt timed out.") - - if str(reaction) == constants.Emojis.check_mark: - log.trace(f"The {self.name} syncer was confirmed.") - await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') - return True - else: - log.info(f"The {self.name} syncer was aborted or timed out!") - await message.edit( - content=f':warning: {mention}{self.name} sync aborted or timed out!' - ) - return False - - @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference between the cache of `guild` and the database.""" - raise NotImplementedError # pragma: no cover - - @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: - """Perform the API calls for synchronisation.""" - raise NotImplementedError # pragma: no cover - - async def _get_confirmation_result( - self, - diff_size: int, - author: Member, - message: t.Optional[Message] = None - ) -> t.Tuple[bool, t.Optional[Message]]: - """ - Prompt for confirmation and return a tuple of the result and the prompt message. - - `diff_size` is the size of the diff of the sync. If it is greater than - `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the - sync and the `message` is an extant message to edit to display the prompt. - - If confirmed or no confirmation was needed, the result is True. The returned message will - either be the given `message` or a new one which was created when sending the prompt. - """ - log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") - if diff_size > constants.Sync.max_diff: - message = await self._send_prompt(message) - if not message: - return False, None # Couldn't get channel. - - confirmed = await self._wait_for_confirmation(author, message) - if not confirmed: - return False, message # Sync aborted. - - return True, message - - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If the differences between the cache and the database are greater than - `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core - channel. The confirmation can be optionally redirect to `ctx` instead. - """ - log.info(f"Starting {self.name} syncer.") - - message = None - author = self.bot.user - if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") - author = ctx.author - - diff = await self._get_diff(guild) - diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict - totals = {k: len(v) for k, v in diff_dict.items() if v is not None} - diff_size = sum(totals.values()) - - confirmed, message = await self._get_confirmation_result(diff_size, author, message) - if not confirmed: - return - - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - try: - await self._sync(diff) - except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") - - # Don't show response text because it's probably some really long HTML. - results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" - else: - results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" - - if message: - await message.edit(content=content) - - -class RoleSyncer(Syncer): - """Synchronise the database with roles in the cache.""" - - name = "role" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of roles between the cache of `guild` and the database.""" - log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_roles = {_Role(**role_dict) for role_dict in roles} - guild_roles = { - _Role( - id=role.id, - name=role.name, - colour=role.colour.value, - permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in db_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # DB guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - db_roles - roles_to_create - roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} - - return _Diff(roles_to_create, roles_to_update, roles_to_delete) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the role cache of `guild`.""" - log.trace("Syncing created roles...") - for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) - - log.trace("Syncing updated roles...") - for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) - - log.trace("Syncing deleted roles...") - for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - -class UserSyncer(Syncer): - """Synchronise the database with users in the cache.""" - - name = "user" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of users between the cache of `guild` and the database.""" - log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } - - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - if guild_user is not None: - if db_user != guild_user: - users_to_update.add(guild_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return _Diff(users_to_create, users_to_update, None) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the user cache of `guild`.""" - log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/dm_relay.py b/bot/cogs/dm_relay.py deleted file mode 100644 index 0d8f340b4..000000000 --- a/bot/cogs/dm_relay.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -from typing import Optional - -import discord -from discord import Color -from discord.ext import commands -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.converters import UserMentionOrID -from bot.utils import RedisCache -from bot.utils.checks import in_whitelist_check, with_role_check -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DMRelay(Cog): - """Relay direct messages to and from the bot.""" - - # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] - dm_cache = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.dm_log - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - @commands.command(aliases=("reply",)) - async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: - """ - Allows you to send a DM to a user from the bot. - - If `member` is not provided, it will send to the last user who DM'd the bot. - - This feature should be used extremely sparingly. Use ModMail if you need to have a serious - conversation with a user. This is just for responding to extraordinary DMs, having a little - fun with users, and telling people they are DMing the wrong bot. - - NOTE: This feature will be removed if it is overused. - """ - if not member: - user_id = await self.dm_cache.get("last_user") - member = ctx.guild.get_member(user_id) if user_id else None - - # If we still don't have a Member at this point, give up - if not member: - log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") - await ctx.message.add_reaction("❌") - return - - try: - await member.send(message) - except discord.errors.Forbidden: - log.debug("User has disabled DMs.") - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("✅") - self.bot.stats.incr("dm_relay.dm_sent") - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Relays the message's content and attachments to the dm_log channel.""" - # Only relay DMs from humans - if message.author.bot or message.guild or self.webhook is None: - return - - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - await self.dm_cache.set("last_user", message.author.id) - self.bot.stats.incr("dm_relay.dm_received") - - # Handle any attachments - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (discord.errors.Forbidden, discord.errors.NotFound): - e = discord.Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - def cog_check(self, ctx: commands.Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=[constants.Channels.dm_log], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - -def setup(bot: Bot) -> None: - """Load the DMRelay cog.""" - bot.add_cog(DMRelay(bot)) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py deleted file mode 100644 index 7021069fa..000000000 --- a/bot/cogs/duck_pond.py +++ /dev/null @@ -1,166 +0,0 @@ -import logging -from typing import Union - -import discord -from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DuckPond(Cog): - """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.duck_pond - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @staticmethod - def is_staff(member: Union[User, Member]) -> bool: - """Check if a specific member or user is staff.""" - if hasattr(member, "roles"): - for role in member.roles: - if role.id in constants.STAFF_ROLES: - return True - return False - - async def has_green_checkmark(self, message: Message) -> bool: - """Check if the message has a green checkmark reaction.""" - for reaction in message.reactions: - if reaction.emoji == "✅": - async for user in reaction.users(): - if user == self.bot.user: - return True - return False - - async def count_ducks(self, message: Message) -> int: - """ - Count the number of ducks in the reactions of a specific message. - - Only counts ducks added by staff members. - """ - duck_count = 0 - duck_reactors = [] - - for reaction in message.reactions: - async for user in reaction.users(): - - # Is the user a staff member and not already counted as reactor? - if not self.is_staff(user) or user.id in duck_reactors: - continue - - # Is the emoji a duck? - if hasattr(reaction.emoji, "id"): - if reaction.emoji.id in constants.DuckPond.custom_emojis: - duck_count += 1 - duck_reactors.append(user.id) - elif isinstance(reaction.emoji, str): - if reaction.emoji == "🦆": - duck_count += 1 - duck_reactors.append(user.id) - return duck_count - - async def relay_message(self, message: Message) -> None: - """Relays the message's content and attachments to the duck pond channel.""" - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - await message.add_reaction("✅") - - @staticmethod - def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: - """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" - if payload.emoji.is_custom_emoji(): - if payload.emoji.id in constants.DuckPond.custom_emojis: - return True - elif payload.emoji.name == "🦆": - return True - - return False - - @Cog.listener() - async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: - """ - Determine if a message should be sent to the duck pond. - - This will count the number of duck reactions on the message, and if this amount meets the - amount of ducks specified in the config under duck_pond/threshold, it will - send the message off to the duck pond. - """ - # Is the emoji in the reaction a duck? - if not self._payload_has_duckpond_emoji(payload): - return - - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - message = await channel.fetch_message(payload.message_id) - member = discord.utils.get(message.guild.members, id=payload.user_id) - - # Is the member a human and a staff member? - if not self.is_staff(member) or member.bot: - return - - # Does the message already have a green checkmark? - if await self.has_green_checkmark(message): - return - - # Time to count our ducks! - duck_count = await self.count_ducks(message) - - # If we've got more than the required amount of ducks, send the message to the duck_pond. - if duck_count >= constants.DuckPond.threshold: - await self.relay_message(message) - - @Cog.listener() - async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: - """Ensure that people don't remove the green checkmark from duck ponded messages.""" - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - - # Prevent the green checkmark from being removed - if payload.emoji.name == "✅": - message = await channel.fetch_message(payload.message_id) - duck_count = await self.count_ducks(message) - if duck_count >= constants.DuckPond.threshold: - await message.add_reaction("✅") - - -def setup(bot: Bot) -> None: - """Load the DuckPond cog.""" - bot.add_cog(DuckPond(bot)) diff --git a/bot/cogs/filters/__init__.py b/bot/cogs/filters/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/filters/antimalware.py b/bot/cogs/filters/antimalware.py deleted file mode 100644 index c76bd2c60..000000000 --- a/bot/cogs/filters/antimalware.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -import typing as t -from os.path import splitext - -from discord import Embed, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, STAFF_ROLES, URLs - -log = logging.getLogger(__name__) - -PY_EMBED_DESCRIPTION = ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" -) - -TXT_EMBED_DESCRIPTION = ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - "{cmd_channel_mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" -) - -DISALLOWED_EMBED_DESCRIPTION = ( - "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - "We currently allow the following file types: **{joined_whitelist}**.\n\n" - "Feel free to ask in {meta_channel_mention} if you think this is a mistake." -) - - -class AntiMalware(Cog): - """Delete messages which contain attachments with non-whitelisted file extensions.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_whitelisted_file_formats(self) -> list: - """Get the file formats currently on the whitelist.""" - return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() - - def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: - """Get an iterable containing all the disallowed extensions of attachments.""" - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) - return extensions_blocked - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Identify messages with prohibited attachments.""" - # Return when message don't have attachment and don't moderate DMs - if not message.attachments or not message.guild: - return - - # Check if user is staff, if is, return - # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance - if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): - return - - embed = Embed() - extensions_blocked = self._get_disallowed_extensions(message) - blocked_extensions_str = ', '.join(extensions_blocked) - if ".py" in extensions_blocked: - # Short-circuit on *.py files to provide a pastebin link - embed.description = PY_EMBED_DESCRIPTION - elif ".txt" in extensions_blocked: - # Work around Discord AutoConversion of messages longer than 2000 chars to .txt - cmd_channel = self.bot.get_channel(Channels.bot_commands) - embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) - elif extensions_blocked: - meta_channel = self.bot.get_channel(Channels.meta) - embed.description = DISALLOWED_EMBED_DESCRIPTION.format( - joined_whitelist=', '.join(self._get_whitelisted_file_formats()), - blocked_extensions_str=blocked_extensions_str, - meta_channel_mention=meta_channel.mention, - ) - - if embed.description: - log.info( - f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", - extra={"attachment_list": [attachment.filename for attachment in message.attachments]} - ) - - await message.channel.send(f"Hey {message.author.mention}!", embed=embed) - - # Delete the offending message: - try: - await message.delete() - except NotFound: - log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - - -def setup(bot: Bot) -> None: - """Load the AntiMalware cog.""" - bot.add_cog(AntiMalware(bot)) diff --git a/bot/cogs/filters/antispam.py b/bot/cogs/filters/antispam.py deleted file mode 100644 index d2dccea06..000000000 --- a/bot/cogs/filters/antispam.py +++ /dev/null @@ -1,288 +0,0 @@ -import asyncio -import logging -from collections.abc import Mapping -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from operator import itemgetter -from typing import Dict, Iterable, List, Set - -from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Cog - -from bot import rules -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import ( - AntiSpam as AntiSpamConfig, Channels, - Colours, DEBUG_MODE, Event, Filter, - Guild as GuildConfig, Icons, - STAFF_ROLES, -) -from bot.converters import Duration -from bot.utils.messages import send_attachments - - -log = logging.getLogger(__name__) - -RULE_FUNCTION_MAPPING = { - 'attachments': rules.apply_attachments, - 'burst': rules.apply_burst, - 'burst_shared': rules.apply_burst_shared, - 'chars': rules.apply_chars, - 'discord_emojis': rules.apply_discord_emojis, - 'duplicates': rules.apply_duplicates, - 'links': rules.apply_links, - 'mentions': rules.apply_mentions, - 'newlines': rules.apply_newlines, - 'role_mentions': rules.apply_role_mentions -} - - -@dataclass -class DeletionContext: - """Represents a Deletion Context for a single spam event.""" - - channel: TextChannel - members: Dict[int, Member] = field(default_factory=dict) - rules: Set[str] = field(default_factory=set) - messages: Dict[int, Message] = field(default_factory=dict) - attachments: List[List[str]] = field(default_factory=list) - - async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: - """Adds new rule violation events to the deletion context.""" - self.rules.add(rule_name) - - for member in members: - if member.id not in self.members: - self.members[member.id] = member - - for message in messages: - if message.id not in self.messages: - self.messages[message.id] = message - - # Re-upload attachments - destination = message.guild.get_channel(Channels.attachment_log) - urls = await send_attachments(message, destination, link_large=False) - self.attachments.append(urls) - - async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: - """Method that takes care of uploading the queue and posting modlog alert.""" - triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) - - mod_alert_message = ( - f"**Triggered by:** {triggered_by_users}\n" - f"**Channel:** {self.channel.mention}\n" - f"**Rules:** {', '.join(rule for rule in self.rules)}\n" - ) - - # For multiple messages or those with excessive newlines, use the logs API - if len(self.messages) > 1 or 'newlines' in self.rules: - url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) - mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" - else: - mod_alert_message += "Message:\n" - [message] = self.messages.values() - content = message.clean_content - remaining_chars = 2040 - len(mod_alert_message) - - if len(content) > remaining_chars: - content = content[:remaining_chars] + "..." - - mod_alert_message += f"{content}" - - *_, last_message = self.messages.values() - await modlog.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title="Spam detected!", - text=mod_alert_message, - thumbnail=last_message.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=AntiSpamConfig.ping_everyone - ) - - -class AntiSpam(Cog): - """Cog that controls our anti-spam measures.""" - - def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: - self.bot = bot - self.validation_errors = validation_errors - role_id = AntiSpamConfig.punishment['role_id'] - self.muted_role = Object(role_id) - self.expiration_date_converter = Duration() - - self.message_deletion_queue = dict() - - self.bot.loop.create_task(self.alert_on_validation_error()) - - @property - def mod_log(self) -> ModLog: - """Allows for easy access of the ModLog cog.""" - return self.bot.get_cog("ModLog") - - async def alert_on_validation_error(self) -> None: - """Unloads the cog and alerts admins if configuration validation failed.""" - await self.bot.wait_until_guild_available() - if self.validation_errors: - body = "**The following errors were encountered:**\n" - body += "\n".join(f"- {error}" for error in self.validation_errors.values()) - body += "\n\n**The cog has been unloaded.**" - - await self.mod_log.send_log_message( - title="Error: AntiSpam configuration validation failed!", - text=body, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Colour.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Applies the antispam rules to each received message.""" - if ( - not message.guild - or message.guild.id != GuildConfig.id - or message.author.bot - or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) - or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) - ): - return - - # Fetch the rule configuration with the highest rule interval. - max_interval_config = max( - AntiSpamConfig.rules.values(), - key=itemgetter('interval') - ) - max_interval = max_interval_config['interval'] - - # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. - earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] - - for rule_name in AntiSpamConfig.rules: - rule_config = AntiSpamConfig.rules[rule_name] - rule_function = RULE_FUNCTION_MAPPING[rule_name] - - # Create a list of messages that were sent in the interval that the rule cares about. - latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] - result = await rule_function(message, messages_for_rule, rule_config) - - # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` - # which contains the reason for why the message violated the rule and - # an iterable of all members that violated the rule. - if result is not None: - self.bot.stats.incr(f"mod_alerts.{rule_name}") - reason, members, relevant_messages = result - full_reason = f"`{rule_name}` rule: {reason}" - - # If there's no spam event going on for this channel, start a new Message Deletion Context - channel = message.channel - if channel.id not in self.message_deletion_queue: - log.trace(f"Creating queue for channel `{channel.id}`") - self.message_deletion_queue[message.channel.id] = DeletionContext(channel) - self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) - - # Add the relevant of this trigger to the Deletion Context - await self.message_deletion_queue[message.channel.id].add( - rule_name=rule_name, - members=members, - messages=relevant_messages - ) - - for member in members: - - # Fire it off as a background task to ensure - # that the sleep doesn't block further tasks - self.bot.loop.create_task( - self.punish(message, member, full_reason) - ) - - await self.maybe_delete_messages(channel, relevant_messages) - break - - async def punish(self, msg: Message, member: Member, reason: str) -> None: - """Punishes the given member for triggering an antispam rule.""" - if not any(role.id == self.muted_role.id for role in member.roles): - remove_role_after = AntiSpamConfig.punishment['remove_after'] - - # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes - context = await self.bot.get_context(msg) - context.author = self.bot.user - context.message.author = self.bot.user - - # Since we're going to invoke the tempmute command directly, we need to manually call the converter. - dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") - await context.invoke( - self.bot.get_command('tempmute'), - member, - dt_remove_role_after, - reason=reason - ) - - async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: - """Cleans the messages if cleaning is configured.""" - if AntiSpamConfig.clean_offending: - # If we have more than one message, we can use bulk delete. - if len(messages) > 1: - message_ids = [message.id for message in messages] - self.mod_log.ignore(Event.message_delete, *message_ids) - await channel.delete_messages(messages) - - # Otherwise, the bulk delete endpoint will throw up. - # Delete the message directly instead. - else: - self.mod_log.ignore(Event.message_delete, messages[0].id) - try: - await messages[0].delete() - except NotFound: - log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - - async def _process_deletion_context(self, context_id: int) -> None: - """Processes the Deletion Context queue.""" - log.trace("Sleeping before processing message deletion queue.") - await asyncio.sleep(10) - - if context_id not in self.message_deletion_queue: - log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") - return - - deletion_context = self.message_deletion_queue.pop(context_id) - await deletion_context.upload_messages(self.bot.user.id, self.mod_log) - - -def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: - """Validates the antispam configs.""" - validation_errors = {} - for name, config in rules_.items(): - if name not in RULE_FUNCTION_MAPPING: - log.error( - f"Unrecognized antispam rule `{name}`. " - f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" - ) - validation_errors[name] = f"`{name}` is not recognized as an antispam rule." - continue - for required_key in ('interval', 'max'): - if required_key not in config: - log.error( - f"`{required_key}` is required but was not " - f"set in rule `{name}`'s configuration." - ) - validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" - return validation_errors - - -def setup(bot: Bot) -> None: - """Validate the AntiSpam configs and load the AntiSpam cog.""" - validation_errors = validate_config() - bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/cogs/filters/filter_lists.py b/bot/cogs/filters/filter_lists.py deleted file mode 100644 index c15adc461..000000000 --- a/bot/cogs/filters/filter_lists.py +++ /dev/null @@ -1,273 +0,0 @@ -import logging -from typing import Optional - -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.converters import ValidDiscordServerInvite, ValidFilterListType -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -class FilterLists(Cog): - """Commands for blacklisting and whitelisting things.""" - - methods_with_filterlist_types = [ - "allow_add", - "allow_delete", - "allow_get", - "deny_add", - "deny_delete", - "deny_get", - ] - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.bot.loop.create_task(self._amend_docstrings()) - - async def _amend_docstrings(self) -> None: - """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" - await self.bot.wait_until_guild_available() - - # Add valid filterlist types to the docstrings - valid_types = await ValidFilterListType.get_valid_types(self.bot) - valid_types = [f"`{type_.lower()}`" for type_ in valid_types] - - for method_name in self.methods_with_filterlist_types: - command = getattr(self, method_name) - command.help = ( - f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." - ) - - async def _add_data( - self, - ctx: Context, - allowed: bool, - list_type: ValidFilterListType, - content: str, - comment: Optional[str] = None, - ) -> None: - """Add an item to a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we gotta validate it. - if list_type == "GUILD_INVITE": - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # Unless the user has specified another comment, let's - # use the server name as the comment so that the list - # of guild IDs will be more easily readable when we - # display it. - if not comment: - comment = guild_data.get("name") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Try to add the item to the database - log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") - payload = { - "allowed": allowed, - "type": list_type, - "content": content, - "comment": comment, - } - - try: - item = await self.bot.api_client.post( - "bot/filter-lists", - json=payload - ) - except ResponseCodeError as e: - if e.status == 400: - await ctx.message.add_reaction("❌") - log.debug( - f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " - "probably because the request violated the UniqueConstraint." - ) - raise BadArgument( - f"Unable to add the item to the {allow_type}. " - "The item probably already exists. Keep in mind that a " - "blacklist and a whitelist for the same item cannot co-exist, " - "and we do not permit any duplicates." - ) - raise - - # Insert the item into the cache - self.bot.insert_item_into_filter_list_cache(item) - await ctx.message.add_reaction("✅") - - async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we need to convert it. - if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Find the content and delete it. - log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) - - if item is not None: - try: - await self.bot.api_client.delete( - f"bot/filter-lists/{item['id']}" - ) - del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to delete an item with the id {item['id']}, but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("❌") - - async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: - """Paginate and display all items in a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] - - # Build a list of lines we want to show in the paginator - lines = [] - for content, metadata in result.items(): - line = f"• `{content}`" - - if comment := metadata.get("comment"): - line += f" - {comment}" - - lines.append(line) - lines = sorted(lines) - - # Build the embed - list_type_plural = list_type.lower().replace("_", " ").title() + "s" - embed = Embed( - title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", - colour=Colour.blue() - ) - log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - - if result: - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - await ctx.message.add_reaction("❌") - - async def _sync_data(self, ctx: Context) -> None: - """Syncs the filterlists with the API.""" - try: - log.trace("Attempting to sync FilterList cache with data from the API.") - await self.bot.cache_filter_list_data() - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to sync FilterList cache data but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - - @staticmethod - async def _validate_guild_invite(ctx: Context, invite: str) -> dict: - """ - Validates a guild invite, and returns the guild info as a dict. - - Will raise a BadArgument if the guild invite is invalid. - """ - log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, invite) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's return a dict of guild information. - log.trace(f"{invite} validated as server invite. Converting to ID.") - return guild_data - - @group(aliases=("allowlist", "allow", "al", "wl")) - async def whitelist(self, ctx: Context) -> None: - """Group for whitelisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @group(aliases=("denylist", "deny", "bl", "dl")) - async def blacklist(self, ctx: Context) -> None: - """Group for blacklisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @whitelist.command(name="add", aliases=("a", "set")) - async def allow_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified allowlist.""" - await self._add_data(ctx, True, list_type, content, comment) - - @blacklist.command(name="add", aliases=("a", "set")) - async def deny_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified denylist.""" - await self._add_data(ctx, False, list_type, content, comment) - - @whitelist.command(name="remove", aliases=("delete", "rm",)) - async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified allowlist.""" - await self._delete_data(ctx, True, list_type, content) - - @blacklist.command(name="remove", aliases=("delete", "rm",)) - async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified denylist.""" - await self._delete_data(ctx, False, list_type, content) - - @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified allowlist.""" - await self._list_all_data(ctx, True, list_type) - - @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified denylist.""" - await self._list_all_data(ctx, False, list_type) - - @whitelist.command(name="sync", aliases=("s",)) - async def allow_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - @blacklist.command(name="sync", aliases=("s",)) - async def deny_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the FilterLists cog.""" - bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filters/filtering.py b/bot/cogs/filters/filtering.py deleted file mode 100644 index 556b466ef..000000000 --- a/bot/cogs/filters/filtering.py +++ /dev/null @@ -1,575 +0,0 @@ -import asyncio -import logging -import re -from datetime import datetime, timedelta -from typing import List, Mapping, Optional, Tuple, Union - -import dateutil -import discord.errors -from dateutil.relativedelta import relativedelta -from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel -from discord.ext.commands import Cog -from discord.utils import escape_markdown - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import ( - Channels, Colours, - Filter, Icons, URLs -) -from bot.utils.redis_cache import RedisCache -from bot.utils.regex import INVITE_RE -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - -# Regular expressions -SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) -URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) -ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") - -# Other constants. -DAYS_BETWEEN_ALERTS = 3 -OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) - - -class Filtering(Cog): - """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" - - # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent - name_alerts = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - self.name_lock = asyncio.Lock() - - staff_mistake_str = "If you believe this was a mistake, please let staff know!" - self.filters = { - "filter_zalgo": { - "enabled": Filter.filter_zalgo, - "function": self._has_zalgo, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_zalgo, - "notification_msg": ( - "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " - f"{staff_mistake_str}" - ), - "schedule_deletion": False - }, - "filter_invites": { - "enabled": Filter.filter_invites, - "function": self._has_invites, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_invites, - "notification_msg": ( - f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" - r"Our server rules can be found here: " - ), - "schedule_deletion": False - }, - "filter_domains": { - "enabled": Filter.filter_domains, - "function": self._has_urls, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_domains, - "notification_msg": ( - f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" - ), - "schedule_deletion": False - }, - "watch_regex": { - "enabled": Filter.watch_regex, - "function": self._has_watch_regex_match, - "type": "watchlist", - "content_only": True, - "schedule_deletion": True - }, - "watch_rich_embeds": { - "enabled": Filter.watch_rich_embeds, - "function": self._has_rich_embed, - "type": "watchlist", - "content_only": False, - "schedule_deletion": False - } - } - - self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: - """Fetch items from the filter_list_cache.""" - return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() - - @staticmethod - def _expand_spoilers(text: str) -> str: - """Return a string containing all interpretations of a spoilered message.""" - split_text = SPOILER_RE.split(text) - return ''.join( - split_text[0::2] + split_text[1::2] + split_text - ) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Invoke message filter for new messages.""" - await self._filter_message(msg) - - # Ignore webhook messages. - if msg.webhook_id is None: - await self.check_bad_words_in_name(msg.author) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Invoke message filter for message edits. - - If there have been multiple edits, calculate the time delta from the previous edit. - """ - if not before.edited_at: - delta = relativedelta(after.edited_at, before.created_at).microseconds - else: - delta = relativedelta(after.edited_at, before.edited_at).microseconds - await self._filter_message(after, delta) - - def get_name_matches(self, name: str) -> List[re.Match]: - """Check bad words from passed string (name). Return list of matches.""" - matches = [] - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - if match := re.search(pattern, name, flags=re.IGNORECASE): - matches.append(match) - return matches - - async def check_send_alert(self, member: Member) -> bool: - """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" - if last_alert := await self.name_alerts.get(member.id): - last_alert = datetime.utcfromtimestamp(last_alert) - if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: - log.trace(f"Last alert was too recent for {member}'s nickname.") - return False - - return True - - async def check_bad_words_in_name(self, member: Member) -> None: - """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" - # Use lock to avoid race conditions - async with self.name_lock: - # Check whether the users display name contains any words in our blacklist - matches = self.get_name_matches(member.display_name) - - if not matches or not await self.check_send_alert(member): - return - - log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") - - log_string = ( - f"**User:** {member.mention} (`{member.id}`)\n" - f"**Display Name:** {member.display_name}\n" - f"**Bad Matches:** {', '.join(match.group() for match in matches)}" - ) - - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colours.soft_red, - title="Username filtering alert", - text=log_string, - channel_id=Channels.mod_alerts, - thumbnail=member.avatar_url - ) - - # Update time when alert sent - await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) - - async def filter_eval(self, result: str, msg: Message) -> bool: - """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. - - Also requires the original message, to check whether to filter and for mod logs. - Returns whether a filter was triggered or not. - """ - filter_triggered = False - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - # We also do not need to worry about filters that take the full message, - # since all we have is an arbitrary string. - if _filter["enabled"] and _filter["content_only"]: - match = await _filter["function"](result) - - if match: - # If this is a filter (not a watchlist), we set the variable so we know - # that it has been triggered - if _filter["type"] == "filter": - filter_triggered = True - - # We do not have to check against DM channels since !eval cannot be used there. - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, result - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} using !eval with " - f"[the following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - return filter_triggered - - async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: - """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - if _filter["enabled"]: - # Double trigger check for the embeds filter - if filter_name == "watch_rich_embeds": - # If the edit delta is less than 0.001 seconds, then we're probably dealing - # with a double filter trigger. - if delta is not None and delta < 100: - continue - - # Does the filter only need the message content or the full message? - if _filter["content_only"]: - match = await _filter["function"](msg.content) - else: - match = await _filter["function"](msg) - - if match: - is_private = msg.channel.type is discord.ChannelType.private - - # If this is a filter (not a watchlist) and not in a DM, delete the message. - if _filter["type"] == "filter" and not is_private: - try: - # Embeds (can?) trigger both the `on_message` and `on_message_edit` - # event handlers, triggering filtering twice for the same message. - # - # If `on_message`-triggered filtering already deleted the message - # then `on_message_edit`-triggered filtering will raise exception - # since the message no longer exists. - # - # In addition, to avoid sending two notifications to the user, the - # logs, and mod_alert, we return if the message no longer exists. - await msg.delete() - except discord.errors.NotFound: - return - - # Notify the user if the filter specifies - if _filter["user_notification"]: - await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) - - # If the message is classed as offensive, we store it in the site db and - # it will be deleted it after one week. - if _filter["schedule_deletion"] and not is_private: - delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() - data = { - 'id': msg.id, - 'channel_id': msg.channel.id, - 'delete_date': delete_date - } - - await self.bot.api_client.post('bot/offensive-messages', json=data) - self.schedule_msg_delete(data) - log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") - - if is_private: - channel_str = "via DM" - else: - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, msg.content - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} with [the " - f"following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone if not is_private else False, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ - str, Optional[List[discord.Embed]], Optional[str] - ]: - """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" - # Word and match stats for watch_regex - if name == "watch_regex": - surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] - message_content = ( - f"**Match:** '{match[0]}'\n" - f"**Location:** '...{escape_markdown(surroundings)}...'\n" - f"\n**Original Message:**\n{escape_markdown(content)}" - ) - else: # Use original content - message_content = content - - additional_embeds = None - additional_embeds_msg = None - - self.bot.stats.incr(f"filters.{name}") - - # The function returns True for invalid invites. - # They have no data so additional embeds can't be created for them. - if name == "filter_invites" and match is not True: - additional_embeds = [] - for _, data in match.items(): - embed = discord.Embed(description=( - f"**Members:**\n{data['members']}\n" - f"**Active:**\n{data['active']}" - )) - embed.set_author(name=data["name"]) - embed.set_thumbnail(url=data["icon"]) - embed.set_footer(text=f"Guild ID: {data['id']}") - additional_embeds.append(embed) - additional_embeds_msg = "For the following guild(s):" - - elif name == "watch_rich_embeds": - additional_embeds = match - additional_embeds_msg = "With the following embed(s):" - - return message_content, additional_embeds, additional_embeds_msg - - @staticmethod - def _check_filter(msg: Message) -> bool: - """Check whitelists to see if we should filter this message.""" - role_whitelisted = False - - if type(msg.author) is Member: # Only Member has roles, not User. - for role in msg.author.roles: - if role.id in Filter.role_whitelist: - role_whitelisted = True - - return ( - msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist - and not role_whitelisted # Role not in whitelist - and not msg.author.bot # Author not a bot - ) - - async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: - """ - Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. - - `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is - matched as-is. Spoilers are expanded, if any, and URLs are ignored. - """ - if SPOILER_RE.search(text): - text = self._expand_spoilers(text) - - # Make sure it's not a URL - if URL_RE.search(text): - return False - - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - match = re.search(pattern, text, flags=re.IGNORECASE) - if match: - return match - - async def _has_urls(self, text: str) -> bool: - """Returns True if the text contains one of the blacklisted URLs from the config file.""" - if not URL_RE.search(text): - return False - - text = text.lower() - domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) - - for url in domain_blacklist: - if url.lower() in text: - return True - - return False - - @staticmethod - async def _has_zalgo(text: str) -> bool: - """ - Returns True if the text contains zalgo characters. - - Zalgo range is \u0300 – \u036F and \u0489. - """ - return bool(ZALGO_RE.search(text)) - - async def _has_invites(self, text: str) -> Union[dict, bool]: - """ - Checks if there's any invites in the text content that aren't in the guild whitelist. - - If any are detected, a dictionary of invite data is returned, with a key per invite. - If none are detected, False is returned. - - Attempts to catch some of common ways to try to cheat the system. - """ - # Remove backslashes to prevent escape character aroundfuckery like - # discord\.gg/gdudes-pony-farm - text = text.replace("\\", "") - - invites = INVITE_RE.findall(text) - invite_data = dict() - for invite in invites: - if invite in invite_data: - continue - - response = await self.bot.http_session.get( - f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} - ) - response = await response.json() - guild = response.get("guild") - if guild is None: - # Lack of a "guild" key in the JSON response indicates either an group DM invite, an - # expired invite, or an invalid invite. The API does not currently differentiate - # between invalid and expired invites - return True - - guild_id = guild.get("id") - guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) - guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) - - # Is this invite allowed? - guild_partnered_or_verified = ( - 'PARTNERED' in guild.get("features", []) - or 'VERIFIED' in guild.get("features", []) - ) - invite_not_allowed = ( - guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. - or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. - and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. - ) - - if invite_not_allowed: - guild_icon_hash = guild["icon"] - guild_icon = ( - "https://cdn.discordapp.com/icons/" - f"{guild_id}/{guild_icon_hash}.png?size=512" - ) - - invite_data[invite] = { - "name": guild["name"], - "id": guild['id'], - "icon": guild_icon, - "members": response["approximate_member_count"], - "active": response["approximate_presence_count"] - } - - return invite_data if invite_data else False - - @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: - """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" - if msg.embeds: - for embed in msg.embeds: - if embed.type == "rich": - urls = URL_RE.findall(msg.content) - if not embed.url or embed.url not in urls: - # If `embed.url` does not exist or if `embed.url` is not part of the content - # of the message, it's unlikely to be an auto-generated embed by Discord. - return msg.embeds - else: - log.trace( - "Found a rich embed sent by a regular user account, " - "but it was likely just an automatic URL embed." - ) - return False - return False - - async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: - """ - Notify filtered_member about a moderation action with the reason str. - - First attempts to DM the user, fall back to in-channel notification if user has DMs disabled - """ - try: - await filtered_member.send(reason) - except discord.errors.Forbidden: - await channel.send(f"{filtered_member.mention} {reason}") - - def schedule_msg_delete(self, msg: dict) -> None: - """Delete an offensive message once its deletion date is reached.""" - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) - - async def reschedule_offensive_msg_deletion(self) -> None: - """Get all the pending message deletion from the API and reschedule them.""" - await self.bot.wait_until_ready() - response = await self.bot.api_client.get('bot/offensive-messages',) - - now = datetime.utcnow() - - for msg in response: - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - - if delete_at < now: - await self.delete_offensive_msg(msg) - else: - self.schedule_msg_delete(msg) - - async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: - """Delete an offensive message, and then delete it from the db.""" - try: - channel = self.bot.get_channel(msg['channel_id']) - if channel: - msg_obj = await channel.fetch_message(msg['id']) - await msg_obj.delete() - except NotFound: - log.info( - f"Tried to delete message {msg['id']}, but the message can't be found " - f"(it has been probably already deleted)." - ) - except HTTPException as e: - log.warning(f"Failed to delete message {msg['id']}: status {e.status}") - - await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') - log.info(f"Deleted the offensive message with id {msg['id']}.") - - -def setup(bot: Bot) -> None: - """Load the Filtering cog.""" - bot.add_cog(Filtering(bot)) diff --git a/bot/cogs/filters/security.py b/bot/cogs/filters/security.py deleted file mode 100644 index c680c5e27..000000000 --- a/bot/cogs/filters/security.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging - -from discord.ext.commands import Cog, Context, NoPrivateMessage - -from bot.bot import Bot - -log = logging.getLogger(__name__) - - -class Security(Cog): - """Security-related helpers.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all - self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM - - def check_not_bot(self, ctx: Context) -> bool: - """Check if the context is a bot user.""" - return not ctx.author.bot - - def check_on_guild(self, ctx: Context) -> bool: - """Check if the context is in a guild.""" - if ctx.guild is None: - raise NoPrivateMessage("This command cannot be used in private messages.") - return True - - -def setup(bot: Bot) -> None: - """Load the Security cog.""" - bot.add_cog(Security(bot)) diff --git a/bot/cogs/filters/token_remover.py b/bot/cogs/filters/token_remover.py deleted file mode 100644 index 8eace07b6..000000000 --- a/bot/cogs/filters/token_remover.py +++ /dev/null @@ -1,182 +0,0 @@ -import base64 -import binascii -import logging -import re -import typing as t - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot import utils -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Event, Icons - -log = logging.getLogger(__name__) - -LOG_MESSAGE = ( - "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " - "token was `{user_id}.{timestamp}.{hmac}`" -) -DELETION_MESSAGE_TEMPLATE = ( - "Hey {mention}! I noticed you posted a seemingly valid Discord API " - "token in your message and have removed your message. " - "This means that your token has been **compromised**. " - "Please change your token **immediately** at: " - "\n\n" - "Feel free to re-post it with the token removed. " - "If you believe this was a mistake, please let us know!" -) -DISCORD_EPOCH = 1_420_070_400 -TOKEN_EPOCH = 1_293_840_000 - -# Three parts delimited by dots: user ID, creation timestamp, HMAC. -# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. -# Each part only matches base64 URL-safe characters. -# Padding has never been observed, but the padding character '=' is matched just in case. -TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) - - -class Token(t.NamedTuple): - """A Discord Bot token.""" - - user_id: str - timestamp: str - hmac: str - - -class TokenRemover(Cog): - """Scans messages for potential discord.py bot tokens and removes them.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Check each message for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - found_token = self.find_token_in_message(msg) - if found_token: - await self.take_action(msg, found_token) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Check each edit for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - await self.on_message(after) - - async def take_action(self, msg: Message, found_token: Token) -> None: - """Remove the `msg` containing the `found_token` and send a mod log message.""" - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") - return - - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - log_message = self.format_log_message(msg, found_token) - log.debug(log_message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=log_message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) - - self.bot.stats.incr("tokens.removed_tokens") - - @staticmethod - def format_log_message(msg: Message, token: Token) -> str: - """Return the log message to send for `token` being censored in `msg`.""" - return LOG_MESSAGE.format( - author=msg.author, - author_id=msg.author.id, - channel=msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac='x' * len(token.hmac), - ) - - @classmethod - def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: - """Return a seemingly valid token found in `msg` or `None` if no token is found.""" - # Use finditer rather than search to guard against method calls prematurely returning the - # token check (e.g. `message.channel.send` also matches our token pattern) - for match in TOKEN_RE.finditer(msg.content): - token = Token(*match.groups()) - if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): - # Short-circuit on first match - return token - - # No matching substring - return - - @staticmethod - def is_valid_user_id(b64_content: str) -> bool: - """ - Check potential token to see if it contains a valid Discord user ID. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - string = decoded_bytes.decode('utf-8') - - # isdigit on its own would match a lot of other Unicode characters, hence the isascii. - return string.isascii() and string.isdigit() - except (binascii.Error, ValueError): - return False - - @staticmethod - def is_valid_timestamp(b64_content: str) -> bool: - """ - Return True if `b64_content` decodes to a valid timestamp. - - If the timestamp is greater than the Discord epoch, it's probably valid. - See: https://i.imgur.com/7WdehGn.png - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: - log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") - return False - - # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound - # is not checked. - if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: - return True - else: - log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") - return False - - -def setup(bot: Bot) -> None: - """Load the TokenRemover cog.""" - bot.add_cog(TokenRemover(bot)) diff --git a/bot/cogs/filters/webhook_remover.py b/bot/cogs/filters/webhook_remover.py deleted file mode 100644 index 5812da87c..000000000 --- a/bot/cogs/filters/webhook_remover.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import re - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Event, Icons - -WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) - -ALERT_MESSAGE_TEMPLATE = ( - "{user}, looks like you posted a Discord webhook URL. Therefore, your " - "message has been removed. Your webhook may have been **compromised** so " - "please re-create the webhook **immediately**. If you believe this was " - "mistake, please let us know." -) - -log = logging.getLogger(__name__) - - -class WebhookRemover(Cog): - """Scan messages to detect Discord webhooks links.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get current instance of `ModLog`.""" - return self.bot.get_cog("ModLog") - - async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: - """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" - # Don't log this, due internal delete, not by user. Will make different entry. - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") - return - - await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) - - message = ( - f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " - f"to #{msg.channel}. Webhook URL was `{redacted_url}`" - ) - log.debug(message) - - # Send entry to moderation alerts. - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Discord webhook URL removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts - ) - - self.bot.stats.incr("tokens.removed_webhooks") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Check if a Discord webhook URL is in `message`.""" - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - matches = WEBHOOK_URL_RE.search(msg.content) - if matches: - await self.delete_and_respond(msg, matches[1] + "xxx") - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """Check if a Discord webhook URL is in the edited message `after`.""" - await self.on_message(after) - - -def setup(bot: Bot) -> None: - """Load `WebhookRemover` cog.""" - bot.add_cog(WebhookRemover(bot)) diff --git a/bot/cogs/help_channels.py b/bot/cogs/help_channels.py deleted file mode 100644 index 57094751e..000000000 --- a/bot/cogs/help_channels.py +++ /dev/null @@ -1,944 +0,0 @@ -import asyncio -import json -import logging -import random -import typing as t -from collections import deque -from datetime import datetime, timedelta, timezone -from pathlib import Path - -import discord -import discord.abc -from discord.ext import commands - -from bot import constants -from bot.bot import Bot -from bot.utils import RedisCache -from bot.utils.checks import with_role_check -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - -ASKING_GUIDE_URL = "https://pythondiscord.com/pages/asking-good-questions/" -MAX_CHANNELS_PER_CATEGORY = 50 -EXCLUDED_CHANNELS = (constants.Channels.how_to_get_help, constants.Channels.cooldown) - -HELP_CHANNEL_TOPIC = """ -This is a Python help channel. You can claim your own help channel in the Python Help: Available category. -""" - -AVAILABLE_MSG = f""" -This help channel is now **available**, which means that you can claim it by simply typing your \ -question into it. Once claimed, the channel will move into the **Python Help: Occupied** category, \ -and will be yours until it has been inactive for {constants.HelpChannels.idle_minutes} minutes or \ -is closed manually with `!close`. When that happens, it will be set to **dormant** and moved into \ -the **Help: Dormant** category. - -Try to write the best question you can by providing a detailed description and telling us what \ -you've tried already. For more information on asking a good question, \ -check out our guide on [asking good questions]({ASKING_GUIDE_URL}). -""" - -DORMANT_MSG = f""" -This help channel has been marked as **dormant**, and has been moved into the **Help: Dormant** \ -category at the bottom of the channel list. It is no longer possible to send messages in this \ -channel until it becomes available again. - -If your question wasn't answered yet, you can claim a new help channel from the \ -**Help: Available** category by simply asking your question again. Consider rephrasing the \ -question to maximize your chance of getting a good answer. If you're not sure how, have a look \ -through our guide for [asking a good question]({ASKING_GUIDE_URL}). -""" - -CoroutineFunc = t.Callable[..., t.Coroutine] - - -class HelpChannels(commands.Cog): - """ - Manage the help channel system of the guild. - - The system is based on a 3-category system: - - Available Category - - * Contains channels which are ready to be occupied by someone who needs help - * Will always contain `constants.HelpChannels.max_available` channels; refilled automatically - from the pool of dormant channels - * Prioritise using the channels which have been dormant for the longest amount of time - * If there are no more dormant channels, the bot will automatically create a new one - * If there are no dormant channels to move, helpers will be notified (see `notify()`) - * When a channel becomes available, the dormant embed will be edited to show `AVAILABLE_MSG` - * User can only claim a channel at an interval `constants.HelpChannels.claim_minutes` - * To keep track of cooldowns, user which claimed a channel will have a temporary role - - In Use Category - - * Contains all channels which are occupied by someone needing help - * Channel moves to dormant category after `constants.HelpChannels.idle_minutes` of being idle - * Command can prematurely mark a channel as dormant - * Channel claimant is allowed to use the command - * Allowed roles for the command are configurable with `constants.HelpChannels.cmd_whitelist` - * When a channel becomes dormant, an embed with `DORMANT_MSG` will be sent - - Dormant Category - - * Contains channels which aren't in use - * Channels are used to refill the Available category - - Help channels are named after the chemical elements in `bot/resources/elements.json`. - """ - - # This cache tracks which channels are claimed by which members. - # RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] - help_channel_claimants = RedisCache() - - # This cache maps a help channel to whether it has had any - # activity other than the original claimant. True being no other - # activity and False being other activity. - # RedisCache[discord.TextChannel.id, bool] - unanswered = RedisCache() - - # This dictionary maps a help channel to the time it was claimed - # RedisCache[discord.TextChannel.id, UtcPosixTimestamp] - claim_times = RedisCache() - - # This cache maps a help channel to original question message in same channel. - # RedisCache[discord.TextChannel.id, discord.Message.id] - question_messages = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - # Categories - self.available_category: discord.CategoryChannel = None - self.in_use_category: discord.CategoryChannel = None - self.dormant_category: discord.CategoryChannel = None - - # Queues - self.channel_queue: asyncio.Queue[discord.TextChannel] = None - self.name_queue: t.Deque[str] = None - - self.name_positions = self.get_names() - self.last_notification: t.Optional[datetime] = None - - # Asyncio stuff - self.queue_tasks: t.List[asyncio.Task] = [] - self.ready = asyncio.Event() - self.on_message_lock = asyncio.Lock() - self.init_task = self.bot.loop.create_task(self.init_cog()) - - def cog_unload(self) -> None: - """Cancel the init task and scheduled tasks when the cog unloads.""" - log.trace("Cog unload: cancelling the init_cog task") - self.init_task.cancel() - - log.trace("Cog unload: cancelling the channel queue tasks") - for task in self.queue_tasks: - task.cancel() - - self.scheduler.cancel_all() - - def create_channel_queue(self) -> asyncio.Queue: - """ - Return a queue of dormant channels to use for getting the next available channel. - - The channels are added to the queue in a random order. - """ - log.trace("Creating the channel queue.") - - channels = list(self.get_category_channels(self.dormant_category)) - random.shuffle(channels) - - log.trace("Populating the channel queue with channels.") - queue = asyncio.Queue() - for channel in channels: - queue.put_nowait(channel) - - return queue - - async def create_dormant(self) -> t.Optional[discord.TextChannel]: - """ - Create and return a new channel in the Dormant category. - - The new channel will sync its permission overwrites with the category. - - Return None if no more channel names are available. - """ - log.trace("Getting a name for a new dormant channel.") - - try: - name = self.name_queue.popleft() - except IndexError: - log.debug("No more names available for new dormant channels.") - return None - - log.debug(f"Creating a new dormant channel named {name}.") - return await self.dormant_category.create_text_channel(name, topic=HELP_CHANNEL_TOPIC) - - def create_name_queue(self) -> deque: - """Return a queue of element names to use for creating new channels.""" - log.trace("Creating the chemical element name queue.") - - used_names = self.get_used_names() - - log.trace("Determining the available names.") - available_names = (name for name in self.name_positions if name not in used_names) - - log.trace("Populating the name queue with names.") - return deque(available_names) - - async def dormant_check(self, ctx: commands.Context) -> bool: - """Return True if the user is the help channel claimant or passes the role check.""" - if await self.help_channel_claimants.get(ctx.channel.id) == ctx.author.id: - log.trace(f"{ctx.author} is the help channel claimant, passing the check for dormant.") - self.bot.stats.incr("help.dormant_invoke.claimant") - return True - - log.trace(f"{ctx.author} is not the help channel claimant, checking roles.") - role_check = with_role_check(ctx, *constants.HelpChannels.cmd_whitelist) - - if role_check: - self.bot.stats.incr("help.dormant_invoke.staff") - - return role_check - - @commands.command(name="close", aliases=["dormant", "solved"], enabled=False) - async def close_command(self, ctx: commands.Context) -> None: - """ - Make the current in-use help channel dormant. - - Make the channel dormant if the user passes the `dormant_check`, - delete the message that invoked this, - and reset the send permissions cooldown for the user who started the session. - """ - log.trace("close command invoked; checking if the channel is in-use.") - if ctx.channel.category == self.in_use_category: - if await self.dormant_check(ctx): - await self.remove_cooldown_role(ctx.author) - - # Ignore missing task when cooldown has passed but the channel still isn't dormant. - if ctx.author.id in self.scheduler: - self.scheduler.cancel(ctx.author.id) - - await self.move_to_dormant(ctx.channel, "command") - self.scheduler.cancel(ctx.channel.id) - else: - log.debug(f"{ctx.author} invoked command 'dormant' outside an in-use help channel") - - async def get_available_candidate(self) -> discord.TextChannel: - """ - Return a dormant channel to turn into an available channel. - - If no channel is available, wait indefinitely until one becomes available. - """ - log.trace("Getting an available channel candidate.") - - try: - channel = self.channel_queue.get_nowait() - except asyncio.QueueEmpty: - log.info("No candidate channels in the queue; creating a new channel.") - channel = await self.create_dormant() - - if not channel: - log.info("Couldn't create a candidate channel; waiting to get one from the queue.") - await self.notify() - channel = await self.wait_for_dormant_channel() - - return channel - - @staticmethod - def get_clean_channel_name(channel: discord.TextChannel) -> str: - """Return a clean channel name without status emojis prefix.""" - prefix = constants.HelpChannels.name_prefix - try: - # Try to remove the status prefix using the index of the channel prefix - name = channel.name[channel.name.index(prefix):] - log.trace(f"The clean name for `{channel}` is `{name}`") - except ValueError: - # If, for some reason, the channel name does not contain "help-" fall back gracefully - log.info(f"Can't get clean name because `{channel}` isn't prefixed by `{prefix}`.") - name = channel.name - - return name - - @staticmethod - def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: - """Check if a channel should be excluded from the help channel system.""" - return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS - - def get_category_channels(self, category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: - """Yield the text channels of the `category` in an unsorted manner.""" - log.trace(f"Getting text channels in the category '{category}' ({category.id}).") - - # This is faster than using category.channels because the latter sorts them. - for channel in self.bot.get_guild(constants.Guild.id).channels: - if channel.category_id == category.id and not self.is_excluded_channel(channel): - yield channel - - async def get_in_use_time(self, channel_id: int) -> t.Optional[timedelta]: - """Return the duration `channel_id` has been in use. Return None if it's not in use.""" - log.trace(f"Calculating in use time for channel {channel_id}.") - - claimed_timestamp = await self.claim_times.get(channel_id) - if claimed_timestamp: - claimed = datetime.utcfromtimestamp(claimed_timestamp) - return datetime.utcnow() - claimed - - @staticmethod - def get_names() -> t.List[str]: - """ - Return a truncated list of prefixed element names. - - The amount of names is configured with `HelpChannels.max_total_channels`. - The prefix is configured with `HelpChannels.name_prefix`. - """ - count = constants.HelpChannels.max_total_channels - prefix = constants.HelpChannels.name_prefix - - log.trace(f"Getting the first {count} element names from JSON.") - - with Path("bot/resources/elements.json").open(encoding="utf-8") as elements_file: - all_names = json.load(elements_file) - - if prefix: - return [prefix + name for name in all_names[:count]] - else: - return all_names[:count] - - def get_used_names(self) -> t.Set[str]: - """Return channel names which are already being used.""" - log.trace("Getting channel names which are already being used.") - - names = set() - for cat in (self.available_category, self.in_use_category, self.dormant_category): - for channel in self.get_category_channels(cat): - names.add(self.get_clean_channel_name(channel)) - - if len(names) > MAX_CHANNELS_PER_CATEGORY: - log.warning( - f"Too many help channels ({len(names)}) already exist! " - f"Discord only supports {MAX_CHANNELS_PER_CATEGORY} in a category." - ) - - log.trace(f"Got {len(names)} used names: {names}") - return names - - @classmethod - async def get_idle_time(cls, channel: discord.TextChannel) -> t.Optional[int]: - """ - Return the time elapsed, in seconds, since the last message sent in the `channel`. - - Return None if the channel has no messages. - """ - log.trace(f"Getting the idle time for #{channel} ({channel.id}).") - - msg = await cls.get_last_message(channel) - if not msg: - log.debug(f"No idle time available; #{channel} ({channel.id}) has no messages.") - return None - - idle_time = (datetime.utcnow() - msg.created_at).seconds - - log.trace(f"#{channel} ({channel.id}) has been idle for {idle_time} seconds.") - return idle_time - - @staticmethod - async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: - """Return the last message sent in the channel or None if no messages exist.""" - log.trace(f"Getting the last message in #{channel} ({channel.id}).") - - try: - return await channel.history(limit=1).next() # noqa: B305 - except discord.NoMoreItems: - log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") - return None - - async def init_available(self) -> None: - """Initialise the Available category with channels.""" - log.trace("Initialising the Available category with channels.") - - channels = list(self.get_category_channels(self.available_category)) - missing = constants.HelpChannels.max_available - len(channels) - - # If we've got less than `max_available` channel available, we should add some. - if missing > 0: - log.trace(f"Moving {missing} missing channels to the Available category.") - for _ in range(missing): - await self.move_to_available() - - # If for some reason we have more than `max_available` channels available, - # we should move the superfluous ones over to dormant. - elif missing < 0: - log.trace(f"Moving {abs(missing)} superfluous available channels over to the Dormant category.") - for channel in channels[:abs(missing)]: - await self.move_to_dormant(channel, "auto") - - async def init_categories(self) -> None: - """Get the help category objects. Remove the cog if retrieval fails.""" - log.trace("Getting the CategoryChannel objects for the help categories.") - - try: - self.available_category = await self.try_get_channel( - constants.Categories.help_available - ) - self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use) - self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant) - except discord.HTTPException: - log.exception("Failed to get a category; cog will be removed") - self.bot.remove_cog(self.qualified_name) - - async def init_cog(self) -> None: - """Initialise the help channel system.""" - log.trace("Waiting for the guild to be available before initialisation.") - await self.bot.wait_until_guild_available() - - log.trace("Initialising the cog.") - await self.init_categories() - await self.check_cooldowns() - - self.channel_queue = self.create_channel_queue() - self.name_queue = self.create_name_queue() - - log.trace("Moving or rescheduling in-use channels.") - for channel in self.get_category_channels(self.in_use_category): - await self.move_idle_channel(channel, has_task=False) - - # Prevent the command from being used until ready. - # The ready event wasn't used because channels could change categories between the time - # the command is invoked and the cog is ready (e.g. if move_idle_channel wasn't called yet). - # This may confuse users. So would potentially long delays for the cog to become ready. - self.close_command.enabled = True - - await self.init_available() - - log.info("Cog is ready!") - self.ready.set() - - self.report_stats() - - def report_stats(self) -> None: - """Report the channel count stats.""" - total_in_use = sum(1 for _ in self.get_category_channels(self.in_use_category)) - total_available = sum(1 for _ in self.get_category_channels(self.available_category)) - total_dormant = sum(1 for _ in self.get_category_channels(self.dormant_category)) - - self.bot.stats.gauge("help.total.in_use", total_in_use) - self.bot.stats.gauge("help.total.available", total_available) - self.bot.stats.gauge("help.total.dormant", total_dormant) - - @staticmethod - def is_claimant(member: discord.Member) -> bool: - """Return True if `member` has the 'Help Cooldown' role.""" - return any(constants.Roles.help_cooldown == role.id for role in member.roles) - - def match_bot_embed(self, message: t.Optional[discord.Message], description: str) -> bool: - """Return `True` if the bot's `message`'s embed description matches `description`.""" - if not message or not message.embeds: - return False - - bot_msg_desc = message.embeds[0].description - if bot_msg_desc is discord.Embed.Empty: - log.trace("Last message was a bot embed but it was empty.") - return False - return message.author == self.bot.user and bot_msg_desc.strip() == description.strip() - - @staticmethod - def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: - """Return True if `channel` is within a category with `category_id`.""" - actual_category = getattr(channel, "category", None) - return actual_category is not None and actual_category.id == category_id - - async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: - """ - Make the `channel` dormant if idle or schedule the move if still active. - - If `has_task` is True and rescheduling is required, the extant task to make the channel - dormant will first be cancelled. - """ - log.trace(f"Handling in-use channel #{channel} ({channel.id}).") - - if not await self.is_empty(channel): - idle_seconds = constants.HelpChannels.idle_minutes * 60 - else: - idle_seconds = constants.HelpChannels.deleted_idle_minutes * 60 - - time_elapsed = await self.get_idle_time(channel) - - if time_elapsed is None or time_elapsed >= idle_seconds: - log.info( - f"#{channel} ({channel.id}) is idle longer than {idle_seconds} seconds " - f"and will be made dormant." - ) - - await self.move_to_dormant(channel, "auto") - else: - # Cancel the existing task, if any. - if has_task: - self.scheduler.cancel(channel.id) - - delay = idle_seconds - time_elapsed - log.info( - f"#{channel} ({channel.id}) is still active; " - f"scheduling it to be moved after {delay} seconds." - ) - - self.scheduler.schedule_later(delay, channel.id, self.move_idle_channel(channel)) - - async def move_to_bottom_position(self, channel: discord.TextChannel, category_id: int, **options) -> None: - """ - Move the `channel` to the bottom position of `category` and edit channel attributes. - - To ensure "stable sorting", we use the `bulk_channel_update` endpoint and provide the current - positions of the other channels in the category as-is. This should make sure that the channel - really ends up at the bottom of the category. - - If `options` are provided, the channel will be edited after the move is completed. This is the - same order of operations that `discord.TextChannel.edit` uses. For information on available - options, see the documention on `discord.TextChannel.edit`. While possible, position-related - options should be avoided, as it may interfere with the category move we perform. - """ - # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. - category = await self.try_get_channel(category_id) - - payload = [{"id": c.id, "position": c.position} for c in category.channels] - - # Calculate the bottom position based on the current highest position in the category. If the - # category is currently empty, we simply use the current position of the channel to avoid making - # unnecessary changes to positions in the guild. - bottom_position = payload[-1]["position"] + 1 if payload else channel.position - - payload.append( - { - "id": channel.id, - "position": bottom_position, - "parent_id": category.id, - "lock_permissions": True, - } - ) - - # We use d.py's method to ensure our request is processed by d.py's rate limit manager - await self.bot.http.bulk_channel_update(category.guild.id, payload) - - # Now that the channel is moved, we can edit the other attributes - if options: - await channel.edit(**options) - - async def move_to_available(self) -> None: - """Make a channel available.""" - log.trace("Making a channel available.") - - channel = await self.get_available_candidate() - log.info(f"Making #{channel} ({channel.id}) available.") - - await self.send_available_message(channel) - - log.trace(f"Moving #{channel} ({channel.id}) to the Available category.") - - await self.move_to_bottom_position( - channel=channel, - category_id=constants.Categories.help_available, - ) - - self.report_stats() - - async def move_to_dormant(self, channel: discord.TextChannel, caller: str) -> None: - """ - Make the `channel` dormant. - - A caller argument is provided for metrics. - """ - log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") - - await self.help_channel_claimants.delete(channel.id) - await self.move_to_bottom_position( - channel=channel, - category_id=constants.Categories.help_dormant, - ) - - self.bot.stats.incr(f"help.dormant_calls.{caller}") - - in_use_time = await self.get_in_use_time(channel.id) - if in_use_time: - self.bot.stats.timing("help.in_use_time", in_use_time) - - unanswered = await self.unanswered.get(channel.id) - if unanswered: - self.bot.stats.incr("help.sessions.unanswered") - elif unanswered is not None: - self.bot.stats.incr("help.sessions.answered") - - log.trace(f"Position of #{channel} ({channel.id}) is actually {channel.position}.") - log.trace(f"Sending dormant message for #{channel} ({channel.id}).") - embed = discord.Embed(description=DORMANT_MSG) - await channel.send(embed=embed) - - await self.unpin(channel) - - log.trace(f"Pushing #{channel} ({channel.id}) into the channel queue.") - self.channel_queue.put_nowait(channel) - self.report_stats() - - async def move_to_in_use(self, channel: discord.TextChannel) -> None: - """Make a channel in-use and schedule it to be made dormant.""" - log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") - - await self.move_to_bottom_position( - channel=channel, - category_id=constants.Categories.help_in_use, - ) - - timeout = constants.HelpChannels.idle_minutes * 60 - - log.trace(f"Scheduling #{channel} ({channel.id}) to become dormant in {timeout} sec.") - self.scheduler.schedule_later(timeout, channel.id, self.move_idle_channel(channel)) - self.report_stats() - - async def notify(self) -> None: - """ - Send a message notifying about a lack of available help channels. - - Configuration: - - * `HelpChannels.notify` - toggle notifications - * `HelpChannels.notify_channel` - destination channel for notifications - * `HelpChannels.notify_minutes` - minimum interval between notifications - * `HelpChannels.notify_roles` - roles mentioned in notifications - """ - if not constants.HelpChannels.notify: - return - - log.trace("Notifying about lack of channels.") - - if self.last_notification: - elapsed = (datetime.utcnow() - self.last_notification).seconds - minimum_interval = constants.HelpChannels.notify_minutes * 60 - should_send = elapsed >= minimum_interval - else: - should_send = True - - if not should_send: - log.trace("Notification not sent because it's too recent since the previous one.") - return - - try: - log.trace("Sending notification message.") - - channel = self.bot.get_channel(constants.HelpChannels.notify_channel) - mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_roles) - allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_roles] - - message = await channel.send( - f"{mentions} A new available help channel is needed but there " - f"are no more dormant ones. Consider freeing up some in-use channels manually by " - f"using the `{constants.Bot.prefix}dormant` command within the channels.", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - - self.bot.stats.incr("help.out_of_channel_alerts") - - self.last_notification = message.created_at - except Exception: - # Handle it here cause this feature isn't critical for the functionality of the system. - log.exception("Failed to send notification about lack of dormant channels!") - - async def check_for_answer(self, message: discord.Message) -> None: - """Checks for whether new content in a help channel comes from non-claimants.""" - channel = message.channel - - # Confirm the channel is an in use help channel - if self.is_in_category(channel, constants.Categories.help_in_use): - log.trace(f"Checking if #{channel} ({channel.id}) has been answered.") - - # Check if there is an entry in unanswered - if await self.unanswered.contains(channel.id): - claimant_id = await self.help_channel_claimants.get(channel.id) - if not claimant_id: - # The mapping for this channel doesn't exist, we can't do anything. - return - - # Check the message did not come from the claimant - if claimant_id != message.author.id: - # Mark the channel as answered - await self.unanswered.set(channel.id, False) - - @commands.Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Move an available channel to the In Use category and replace it with a dormant one.""" - if message.author.bot: - return # Ignore messages sent by bots. - - channel = message.channel - - await self.check_for_answer(message) - - if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel): - return # Ignore messages outside the Available category or in excluded channels. - - log.trace("Waiting for the cog to be ready before processing messages.") - await self.ready.wait() - - log.trace("Acquiring lock to prevent a channel from being processed twice...") - async with self.on_message_lock: - log.trace(f"on_message lock acquired for {message.id}.") - - if not self.is_in_category(channel, constants.Categories.help_available): - log.debug( - f"Message {message.id} will not make #{channel} ({channel.id}) in-use " - f"because another message in the channel already triggered that." - ) - return - - log.info(f"Channel #{channel} was claimed by `{message.author.id}`.") - await self.move_to_in_use(channel) - await self.revoke_send_permissions(message.author) - - await self.pin(message) - - # Add user with channel for dormant check. - await self.help_channel_claimants.set(channel.id, message.author.id) - - self.bot.stats.incr("help.claimed") - - # Must use a timezone-aware datetime to ensure a correct POSIX timestamp. - timestamp = datetime.now(timezone.utc).timestamp() - await self.claim_times.set(channel.id, timestamp) - - await self.unanswered.set(channel.id, True) - - log.trace(f"Releasing on_message lock for {message.id}.") - - # Move a dormant channel to the Available category to fill in the gap. - # This is done last and outside the lock because it may wait indefinitely for a channel to - # be put in the queue. - await self.move_to_available() - - @commands.Cog.listener() - async def on_message_delete(self, msg: discord.Message) -> None: - """ - Reschedule an in-use channel to become dormant sooner if the channel is empty. - - The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`. - """ - if not self.is_in_category(msg.channel, constants.Categories.help_in_use): - return - - if not await self.is_empty(msg.channel): - return - - log.info(f"Claimant of #{msg.channel} ({msg.author}) deleted message, channel is empty now. Rescheduling task.") - - # Cancel existing dormant task before scheduling new. - self.scheduler.cancel(msg.channel.id) - - delay = constants.HelpChannels.deleted_idle_minutes * 60 - self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) - - async def is_empty(self, channel: discord.TextChannel) -> bool: - """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" - log.trace(f"Checking if #{channel} ({channel.id}) is empty.") - - # A limit of 100 results in a single API call. - # If AVAILABLE_MSG isn't found within 100 messages, then assume the channel is not empty. - # Not gonna do an extensive search for it cause it's too expensive. - async for msg in channel.history(limit=100): - if not msg.author.bot: - log.trace(f"#{channel} ({channel.id}) has a non-bot message.") - return False - - if self.match_bot_embed(msg, AVAILABLE_MSG): - log.trace(f"#{channel} ({channel.id}) has the available message embed.") - return True - - return False - - async def check_cooldowns(self) -> None: - """Remove expired cooldowns and re-schedule active ones.""" - log.trace("Checking all cooldowns to remove or re-schedule them.") - guild = self.bot.get_guild(constants.Guild.id) - cooldown = constants.HelpChannels.claim_minutes * 60 - - for channel_id, member_id in await self.help_channel_claimants.items(): - member = guild.get_member(member_id) - if not member: - continue # Member probably left the guild. - - in_use_time = await self.get_in_use_time(channel_id) - - if not in_use_time or in_use_time.seconds > cooldown: - # Remove the role if no claim time could be retrieved or if the cooldown expired. - # Since the channel is in the claimants cache, it is definitely strange for a time - # to not exist. However, it isn't a reason to keep the user stuck with a cooldown. - await self.remove_cooldown_role(member) - else: - # The member is still on a cooldown; re-schedule it for the remaining time. - delay = cooldown - in_use_time.seconds - self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) - - async def add_cooldown_role(self, member: discord.Member) -> None: - """Add the help cooldown role to `member`.""" - log.trace(f"Adding cooldown role for {member} ({member.id}).") - await self._change_cooldown_role(member, member.add_roles) - - async def remove_cooldown_role(self, member: discord.Member) -> None: - """Remove the help cooldown role from `member`.""" - log.trace(f"Removing cooldown role for {member} ({member.id}).") - await self._change_cooldown_role(member, member.remove_roles) - - async def _change_cooldown_role(self, member: discord.Member, coro_func: CoroutineFunc) -> None: - """ - Change `member`'s cooldown role via awaiting `coro_func` and handle errors. - - `coro_func` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. - """ - guild = self.bot.get_guild(constants.Guild.id) - role = guild.get_role(constants.Roles.help_cooldown) - if role is None: - log.warning(f"Help cooldown role ({constants.Roles.help_cooldown}) could not be found!") - return - - try: - await coro_func(role) - except discord.NotFound: - log.debug(f"Failed to change role for {member} ({member.id}): member not found") - except discord.Forbidden: - log.debug( - f"Forbidden to change role for {member} ({member.id}); " - f"possibly due to role hierarchy" - ) - except discord.HTTPException as e: - log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") - - async def revoke_send_permissions(self, member: discord.Member) -> None: - """ - Disallow `member` to send messages in the Available category for a certain time. - - The time until permissions are reinstated can be configured with - `HelpChannels.claim_minutes`. - """ - log.trace( - f"Revoking {member}'s ({member.id}) send message permissions in the Available category." - ) - - await self.add_cooldown_role(member) - - # Cancel the existing task, if any. - # Would mean the user somehow bypassed the lack of permissions (e.g. user is guild owner). - if member.id in self.scheduler: - self.scheduler.cancel(member.id) - - delay = constants.HelpChannels.claim_minutes * 60 - self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) - - async def send_available_message(self, channel: discord.TextChannel) -> None: - """Send the available message by editing a dormant message or sending a new message.""" - channel_info = f"#{channel} ({channel.id})" - log.trace(f"Sending available message in {channel_info}.") - - embed = discord.Embed(description=AVAILABLE_MSG) - - msg = await self.get_last_message(channel) - if self.match_bot_embed(msg, DORMANT_MSG): - log.trace(f"Found dormant message {msg.id} in {channel_info}; editing it.") - await msg.edit(embed=embed) - else: - log.trace(f"Dormant message not found in {channel_info}; sending a new message.") - await channel.send(embed=embed) - - async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel: - """Attempt to get or fetch a channel and return it.""" - log.trace(f"Getting the channel {channel_id}.") - - channel = self.bot.get_channel(channel_id) - if not channel: - log.debug(f"Channel {channel_id} is not in cache; fetching from API.") - channel = await self.bot.fetch_channel(channel_id) - - log.trace(f"Channel #{channel} ({channel_id}) retrieved.") - return channel - - async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: - """ - Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. - - Return True if successful and False otherwise. - """ - channel_str = f"#{channel} ({channel.id})" - if pin: - func = self.bot.http.pin_message - verb = "pin" - else: - func = self.bot.http.unpin_message - verb = "unpin" - - try: - await func(channel.id, msg_id) - except discord.HTTPException as e: - if e.code == 10008: - log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") - else: - log.exception( - f"Error {verb}ning message {msg_id} in {channel_str}: {e.status} ({e.code})" - ) - return False - else: - log.trace(f"{verb.capitalize()}ned message {msg_id} in {channel_str}.") - return True - - async def pin(self, message: discord.Message) -> None: - """Pin an initial question `message` and store it in a cache.""" - if await self.pin_wrapper(message.id, message.channel, pin=True): - await self.question_messages.set(message.channel.id, message.id) - - async def unpin(self, channel: discord.TextChannel) -> None: - """Unpin the initial question message sent in `channel`.""" - msg_id = await self.question_messages.pop(channel.id) - if msg_id is None: - log.debug(f"#{channel} ({channel.id}) doesn't have a message pinned.") - else: - await self.pin_wrapper(msg_id, channel, pin=False) - - async def wait_for_dormant_channel(self) -> discord.TextChannel: - """Wait for a dormant channel to become available in the queue and return it.""" - log.trace("Waiting for a dormant channel.") - - task = asyncio.create_task(self.channel_queue.get()) - self.queue_tasks.append(task) - channel = await task - - log.trace(f"Channel #{channel} ({channel.id}) finally retrieved from the queue.") - self.queue_tasks.remove(task) - - return channel - - -def validate_config() -> None: - """Raise a ValueError if the cog's config is invalid.""" - log.trace("Validating config.") - total = constants.HelpChannels.max_total_channels - available = constants.HelpChannels.max_available - - if total == 0 or available == 0: - raise ValueError("max_total_channels and max_available and must be greater than 0.") - - if total < available: - raise ValueError( - f"max_total_channels ({total}) must be greater than or equal to max_available " - f"({available})." - ) - - if total > MAX_CHANNELS_PER_CATEGORY: - raise ValueError( - f"max_total_channels ({total}) must be less than or equal to " - f"{MAX_CHANNELS_PER_CATEGORY} due to Discord's limit on channels per category." - ) - - -def setup(bot: Bot) -> None: - """Load the HelpChannels cog.""" - try: - validate_config() - except ValueError as e: - log.error(f"HelpChannels cog will not be loaded due to misconfiguration: {e}") - else: - bot.add_cog(HelpChannels(bot)) diff --git a/bot/cogs/info/__init__.py b/bot/cogs/info/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/info/doc.py b/bot/cogs/info/doc.py deleted file mode 100644 index 204cffb37..000000000 --- a/bot/cogs/info/doc.py +++ /dev/null @@ -1,511 +0,0 @@ -import asyncio -import functools -import logging -import re -import textwrap -from collections import OrderedDict -from contextlib import suppress -from types import SimpleNamespace -from typing import Any, Callable, Optional, Tuple - -import discord -from bs4 import BeautifulSoup -from bs4.element import PageElement, Tag -from discord.errors import NotFound -from discord.ext import commands -from markdownify import MarkdownConverter -from requests import ConnectTimeout, ConnectionError, HTTPError -from sphinx.ext import intersphinx -from urllib3.exceptions import ProtocolError - -from bot.bot import Bot -from bot.constants import MODERATION_ROLES, RedirectOutput -from bot.converters import ValidPythonIdentifier, ValidURL -from bot.decorators import with_role -from bot.pagination import LinePaginator - - -log = logging.getLogger(__name__) -logging.getLogger('urllib3').setLevel(logging.WARNING) - -# Since Intersphinx is intended to be used with Sphinx, -# we need to mock its configuration. -SPHINX_MOCK_APP = SimpleNamespace( - config=SimpleNamespace( - intersphinx_timeout=3, - tls_verify=True, - user_agent="python3:python-discord/bot:1.0.0" - ) -) - -NO_OVERRIDE_GROUPS = ( - "2to3fixer", - "token", - "label", - "pdbcommand", - "term", -) -NO_OVERRIDE_PACKAGES = ( - "python", -) - -SEARCH_END_TAG_ATTRS = ( - "data", - "function", - "class", - "exception", - "seealso", - "section", - "rubric", - "sphinxsidebar", -) -UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") -WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") - -FAILED_REQUEST_RETRY_AMOUNT = 3 -NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay - - -def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - # Assign the cache to the function itself so we can clear it from outside. - async_cache.cache = OrderedDict() - - def decorator(function: Callable) -> Callable: - """Define the async_cache decorator.""" - @functools.wraps(function) - async def wrapper(*args) -> Any: - """Decorator wrapper for the caching logic.""" - key = ':'.join(args[arg_offset:]) - - value = async_cache.cache.get(key) - if value is None: - if len(async_cache.cache) > max_size: - async_cache.cache.popitem(last=False) - - async_cache.cache[key] = await function(*args) - return async_cache.cache[key] - return wrapper - return decorator - - -class DocMarkdownConverter(MarkdownConverter): - """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" - - def convert_code(self, el: PageElement, text: str) -> str: - """Undo `markdownify`s underscore escaping.""" - return f"`{text}`".replace('\\', '') - - def convert_pre(self, el: PageElement, text: str) -> str: - """Wrap any codeblocks in `py` for syntax highlighting.""" - code = ''.join(el.strings) - return f"```py\n{code}```" - - -def markdownify(html: str) -> DocMarkdownConverter: - """Create a DocMarkdownConverter object from the input html.""" - return DocMarkdownConverter(bullets='•').convert(html) - - -class InventoryURL(commands.Converter): - """ - Represents an Intersphinx inventory URL. - - This converter checks whether intersphinx accepts the given inventory URL, and raises - `BadArgument` if that is not the case. - - Otherwise, it simply passes through the given URL. - """ - - @staticmethod - async def convert(ctx: commands.Context, url: str) -> str: - """Convert url to Intersphinx inventory URL.""" - try: - intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) - except AttributeError: - raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") - except ConnectionError: - if url.startswith('https'): - raise commands.BadArgument( - f"Cannot establish a connection to `{url}`. Does it support HTTPS?" - ) - raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") - except ValueError: - raise commands.BadArgument( - f"Failed to read Intersphinx inventory from URL `{url}`. " - "Are you sure that it's a valid inventory file?" - ) - return url - - -class Doc(commands.Cog): - """A set of commands for querying & displaying documentation.""" - - def __init__(self, bot: Bot): - self.base_urls = {} - self.bot = bot - self.inventories = {} - self.renamed_symbols = set() - - self.bot.loop.create_task(self.init_refresh_inventory()) - - async def init_refresh_inventory(self) -> None: - """Refresh documentation inventory on cog initialization.""" - await self.bot.wait_until_guild_available() - await self.refresh_inventory() - - async def update_single( - self, package_name: str, base_url: str, inventory_url: str - ) -> None: - """ - Rebuild the inventory for a single package. - - Where: - * `package_name` is the package name to use, appears in the log - * `base_url` is the root documentation URL for the specified package, used to build - absolute paths that link to specific symbols - * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running - `intersphinx.fetch_inventory` in an executor on the bot's event loop - """ - self.base_urls[package_name] = base_url - - package = await self._fetch_inventory(inventory_url) - if not package: - return None - - for group, value in package.items(): - for symbol, (package_name, _version, relative_doc_url, _) in value.items(): - absolute_doc_url = base_url + relative_doc_url - - if symbol in self.inventories: - group_name = group.split(":")[1] - symbol_base_url = self.inventories[symbol].split("/", 3)[2] - if ( - group_name in NO_OVERRIDE_GROUPS - or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) - ): - - symbol = f"{group_name}.{symbol}" - # If renamed `symbol` already exists, add library name in front to differentiate between them. - if symbol in self.renamed_symbols: - # Split `package_name` because of packages like Pillow that have spaces in them. - symbol = f"{package_name.split()[0]}.{symbol}" - - self.inventories[symbol] = absolute_doc_url - self.renamed_symbols.add(symbol) - continue - - self.inventories[symbol] = absolute_doc_url - - log.trace(f"Fetched inventory for {package_name}.") - - async def refresh_inventory(self) -> None: - """Refresh internal documentation inventory.""" - log.debug("Refreshing documentation inventory...") - - # Clear the old base URLS and inventories to ensure - # that we start from a fresh local dataset. - # Also, reset the cache used for fetching documentation. - self.base_urls.clear() - self.inventories.clear() - self.renamed_symbols.clear() - async_cache.cache = OrderedDict() - - # Run all coroutines concurrently - since each of them performs a HTTP - # request, this speeds up fetching the inventory data heavily. - coros = [ - self.update_single( - package["package"], package["base_url"], package["inventory_url"] - ) for package in await self.bot.api_client.get('bot/documentation-links') - ] - await asyncio.gather(*coros) - - async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: - """ - Given a Python symbol, return its signature and description. - - The first tuple element is the signature of the given symbol as a markup-free string, and - the second tuple element is the description of the given symbol with HTML markup included. - - If the given symbol is a module, returns a tuple `(None, str)` - else if the symbol could not be found, returns `None`. - """ - url = self.inventories.get(symbol) - if url is None: - return None - - async with self.bot.http_session.get(url) as response: - html = await response.text(encoding='utf-8') - - # Find the signature header and parse the relevant parts. - symbol_id = url.split('#')[-1] - soup = BeautifulSoup(html, 'lxml') - symbol_heading = soup.find(id=symbol_id) - search_html = str(soup) - - if symbol_heading is None: - return None - - if symbol_id == f"module-{symbol}": - # Get page content from the module headerlink to the - # first tag that has its class in `SEARCH_END_TAG_ATTRS` - start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) - if start_tag is None: - return [], "" - - end_tag = start_tag.find_next(self._match_end_tag) - if end_tag is None: - return [], "" - - description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) - description_end_index = search_html.find(str(end_tag)) - description = search_html[description_start_index:description_end_index] - signatures = None - - else: - signatures = [] - description = str(symbol_heading.find_next_sibling("dd")) - description_pos = search_html.find(description) - # Get text of up to 3 signatures, remove unwanted symbols - for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): - signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) - if signature and search_html.find(str(element)) < description_pos: - signatures.append(signature) - - return signatures, description.replace('¶', '') - - @async_cache(arg_offset=1) - async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: - """ - Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. - - If the symbol is known, an Embed with documentation about it is returned. - """ - scraped_html = await self.get_symbol_html(symbol) - if scraped_html is None: - return None - - signatures = scraped_html[0] - permalink = self.inventories[symbol] - description = markdownify(scraped_html[1]) - - # Truncate the description of the embed to the last occurrence - # of a double newline (interpreted as a paragraph) before index 1000. - if len(description) > 1000: - shortened = description[:1000] - description_cutoff = shortened.rfind('\n\n', 100) - if description_cutoff == -1: - # Search the shortened version for cutoff points in decreasing desirability, - # cutoff at 1000 if none are found. - for string in (". ", ", ", ",", " "): - description_cutoff = shortened.rfind(string) - if description_cutoff != -1: - break - else: - description_cutoff = 1000 - description = description[:description_cutoff] - - # If there is an incomplete code block, cut it out - if description.count("```") % 2: - codeblock_start = description.rfind('```py') - description = description[:codeblock_start].rstrip() - description += f"... [read more]({permalink})" - - description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if signatures is None: - # If symbol is a module, don't show signature. - embed_description = description - - elif not signatures: - # It's some "meta-page", for example: - # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - embed_description = "This appears to be a generic page not tied to a specific symbol." - - else: - embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) - embed_description += f"\n{description}" - - embed = discord.Embed( - title=f'`{symbol}`', - url=permalink, - description=embed_description - ) - # Show all symbols with the same name that were renamed in the footer. - embed.set_footer( - text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) - ) - return embed - - @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) - async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """Lookup documentation for Python symbols.""" - await ctx.invoke(self.get_command, symbol) - - @docs_group.command(name='get', aliases=('g',)) - async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """ - Return a documentation embed for a given symbol. - - If no symbol is given, return a list of all available inventories. - - Examples: - !docs - !docs aiohttp - !docs aiohttp.ClientSession - !docs get aiohttp.ClientSession - """ - if symbol is None: - inventory_embed = discord.Embed( - title=f"All inventories (`{len(self.base_urls)}` total)", - colour=discord.Colour.blue() - ) - - lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) - if self.base_urls: - await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) - - else: - inventory_embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=inventory_embed) - - else: - # Fetching documentation for a symbol (at least for the first time, since - # caching is used) takes quite some time, so let's send typing to indicate - # that we got the command, but are still working on it. - async with ctx.typing(): - doc_embed = await self.get_symbol_embed(symbol) - - if doc_embed is None: - error_embed = discord.Embed( - description=f"Sorry, I could not find any documentation for `{symbol}`.", - colour=discord.Colour.red() - ) - error_message = await ctx.send(embed=error_embed) - with suppress(NotFound): - await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) - await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) - else: - await ctx.send(embed=doc_embed) - - @docs_group.command(name='set', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def set_command( - self, ctx: commands.Context, package_name: ValidPythonIdentifier, - base_url: ValidURL, inventory_url: InventoryURL - ) -> None: - """ - Adds a new documentation metadata object to the site's database. - - The database will update the object, should an existing item with the specified `package_name` already exist. - - Example: - !docs set \ - python \ - https://docs.python.org/3/ \ - https://docs.python.org/3/objects.inv - """ - body = { - 'package': package_name, - 'base_url': base_url, - 'inventory_url': inventory_url - } - await self.bot.api_client.post('bot/documentation-links', json=body) - - log.info( - f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" - f"Package name: {package_name}\n" - f"Base url: {base_url}\n" - f"Inventory URL: {inventory_url}" - ) - - # Rebuilding the inventory can take some time, so lets send out a - # typing event to show that the Bot is still working. - async with ctx.typing(): - await self.refresh_inventory() - await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") - - @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: - """ - Removes the specified package from the database. - - Examples: - !docs delete aiohttp - """ - await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') - - async with ctx.typing(): - # Rebuild the inventory to ensure that everything - # that was from this package is properly deleted. - await self.refresh_inventory() - await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") - - @docs_group.command(name="refresh", aliases=("rfsh", "r")) - @with_role(*MODERATION_ROLES) - async def refresh_command(self, ctx: commands.Context) -> None: - """Refresh inventories and send differences to channel.""" - old_inventories = set(self.base_urls) - with ctx.typing(): - await self.refresh_inventory() - # Get differences of added and removed inventories - added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) - if added: - added = f"+ {added}" - - removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) - if removed: - removed = f"- {removed}" - - embed = discord.Embed( - title="Inventories refreshed", - description=f"```diff\n{added}\n{removed}```" if added or removed else "" - ) - await ctx.send(embed=embed) - - async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: - """Get and return inventory from `inventory_url`. If fetching fails, return None.""" - fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) - for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): - try: - package = await self.bot.loop.run_in_executor(None, fetch_func) - except ConnectTimeout: - log.error( - f"Fetching of inventory {inventory_url} timed out," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except ProtocolError: - log.error( - f"Connection lost while fetching inventory {inventory_url}," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except HTTPError as e: - log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") - return None - except ConnectionError: - log.error(f"Couldn't establish connection to inventory {inventory_url}.") - return None - else: - return package - log.error(f"Fetching of inventory {inventory_url} failed.") - return None - - @staticmethod - def _match_end_tag(tag: Tag) -> bool: - """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" - for attr in SEARCH_END_TAG_ATTRS: - if attr in tag.get("class", ()): - return True - - return tag.name == "table" - - -def setup(bot: Bot) -> None: - """Load the Doc cog.""" - bot.add_cog(Doc(bot)) diff --git a/bot/cogs/info/help.py b/bot/cogs/info/help.py deleted file mode 100644 index 3d1d6fd10..000000000 --- a/bot/cogs/info/help.py +++ /dev/null @@ -1,375 +0,0 @@ -import itertools -import logging -from asyncio import TimeoutError -from collections import namedtuple -from contextlib import suppress -from typing import List, Union - -from discord import Colour, Embed, Member, Message, NotFound, Reaction, User -from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand -from fuzzywuzzy import fuzz, process -from fuzzywuzzy.utils import full_process - -from bot import constants -from bot.constants import Channels, Emojis, STAFF_ROLES -from bot.decorators import redirect_output -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -COMMANDS_PER_PAGE = 8 -DELETE_EMOJI = Emojis.trashcan -PREFIX = constants.Bot.prefix - -Category = namedtuple("Category", ["name", "description", "cogs"]) - - -async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: - """ - Runs the cleanup for the help command. - - Adds the :trashcan: reaction that, when clicked, will delete the help message. - After a 300 second timeout, the reaction will be removed. - """ - def check(reaction: Reaction, user: User) -> bool: - """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" - return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id - - await message.add_reaction(DELETE_EMOJI) - - with suppress(NotFound): - try: - await bot.wait_for("reaction_add", check=check, timeout=300) - await message.delete() - except TimeoutError: - await message.remove_reaction(DELETE_EMOJI, bot.user) - - -class HelpQueryNotFound(ValueError): - """ - Raised when a HelpSession Query doesn't match a command or cog. - - Contains the custom attribute of ``possible_matches``. - - Instances of this object contain a dictionary of any command(s) that were close to matching the - query, where keys are the possible matched command names and values are the likeness match scores. - """ - - def __init__(self, arg: str, possible_matches: dict = None): - super().__init__(arg) - self.possible_matches = possible_matches - - -class CustomHelpCommand(HelpCommand): - """ - An interactive instance for the bot help command. - - Cogs can be grouped into custom categories. All cogs with the same category will be displayed - under a single category name in the help output. Custom categories are defined inside the cogs - as a class attribute named `category`. A description can also be specified with the attribute - `category_description`. If a description is not found in at least one cog, the default will be - the regular description (class docstring) of the first cog found in the category. - """ - - def __init__(self): - super().__init__(command_attrs={"help": "Shows help for bot commands"}) - - @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) - async def command_callback(self, ctx: Context, *, command: str = None) -> None: - """Attempts to match the provided query with a valid command or cog.""" - # the only reason we need to tamper with this is because d.py does not support "categories", - # so we need to deal with them ourselves. - - bot = ctx.bot - - if command is None: - # quick and easy, send bot help if command is none - mapping = self.get_bot_mapping() - await self.send_bot_help(mapping) - return - - cog_matches = [] - description = None - for cog in bot.cogs.values(): - if hasattr(cog, "category") and cog.category == command: - cog_matches.append(cog) - if hasattr(cog, "category_description"): - description = cog.category_description - - if cog_matches: - category = Category(name=command, description=description, cogs=cog_matches) - await self.send_category_help(category) - return - - # it's either a cog, group, command or subcommand; let the parent class deal with it - await super().command_callback(ctx, command=command) - - async def get_all_help_choices(self) -> set: - """ - Get all the possible options for getting help in the bot. - - This will only display commands the author has permission to run. - - These include: - - Category names - - Cog names - - Group command names (and aliases) - - Command names (and aliases) - - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) - - Options and choices are case sensitive. - """ - # first get all commands including subcommands and full command name aliases - choices = set() - for command in await self.filter_commands(self.context.bot.walk_commands()): - # the the command or group name - choices.add(str(command)) - - if isinstance(command, Command): - # all aliases if it's just a command - choices.update(command.aliases) - else: - # otherwise we need to add the parent name in - choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) - - # all cog names - choices.update(self.context.bot.cogs) - - # all category names - choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) - return choices - - async def command_not_found(self, string: str) -> "HelpQueryNotFound": - """ - Handles when a query does not match a valid command, group, cog or category. - - Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. - """ - choices = await self.get_all_help_choices() - - # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty - # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters - if (processed := full_process(string)): - result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) - else: - result = [] - - return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) - - async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": - """ - Redirects the error to `command_not_found`. - - `command_not_found` deals with searching and getting best choices for both commands and subcommands. - """ - return await self.command_not_found(f"{command.qualified_name} {string}") - - async def send_error_message(self, error: HelpQueryNotFound) -> None: - """Send the error message to the channel.""" - embed = Embed(colour=Colour.red(), title=str(error)) - - if getattr(error, "possible_matches", None): - matches = "\n".join(f"`{match}`" for match in error.possible_matches) - embed.description = f"**Did you mean:**\n{matches}" - - await self.context.send(embed=embed) - - async def command_formatting(self, command: Command) -> Embed: - """ - Takes a command and turns it into an embed. - - It will add an author, command signature + help, aliases and a note if the user can't run the command. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - parent = command.full_parent_name - - name = str(command) if not parent else f"{parent} {command.name}" - command_details = f"**```{PREFIX}{name} {command.signature}```**\n" - - # show command aliases - aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) - if aliases: - command_details += f"**Can also use:** {aliases}\n\n" - - # check if the user is allowed to run this command - if not await command.can_run(self.context): - command_details += "***You cannot run this command.***\n\n" - - command_details += f"*{command.help or 'No details provided.'}*\n" - embed.description = command_details - - return embed - - async def send_command_help(self, command: Command) -> None: - """Send help for a single command.""" - embed = await self.command_formatting(command) - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: - """ - Formats the prefix, command name and signature, and short doc for an iterable of commands. - - return_as_list is helpful for passing these command details into the paginator as a list of command details. - """ - details = [] - for command in commands_: - signature = f" {command.signature}" if command.signature else "" - details.append( - f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" - ) - if return_as_list: - return details - else: - return "".join(details) - - async def send_group_help(self, group: Group) -> None: - """Sends help for a group command.""" - subcommands = group.commands - - if len(subcommands) == 0: - # no subcommands, just treat it like a regular command - await self.send_command_help(group) - return - - # remove commands that the user can't run and are hidden, and sort by name - commands_ = await self.filter_commands(subcommands, sort=True) - - embed = await self.command_formatting(group) - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n**Subcommands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - async def send_cog_help(self, cog: Cog) -> None: - """Send help for a cog.""" - # sort commands by name, and remove any the user cant run or are hidden. - commands_ = await self.filter_commands(cog.get_commands(), sort=True) - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n\n**Commands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def _category_key(command: Command) -> str: - """ - Returns a cog name of a given command for use as a key for `sorted` and `groupby`. - - A zero width space is used as a prefix for results with no cogs to force them last in ordering. - """ - if command.cog: - with suppress(AttributeError): - if command.cog.category: - return f"**{command.cog.category}**" - return f"**{command.cog_name}**" - else: - return "**\u200bNo Category:**" - - async def send_category_help(self, category: Category) -> None: - """ - Sends help for a bot category. - - This sends a brief help for all commands in all cogs registered to the category. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - all_commands = [] - for cog in category.cogs: - all_commands.extend(cog.get_commands()) - - filtered_commands = await self.filter_commands(all_commands, sort=True) - - command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) - description = f"**{category.name}**\n*{category.description}*" - - if command_detail_lines: - description += "\n\n**Commands:**" - - await LinePaginator.paginate( - command_detail_lines, - self.context, - embed, - prefix=description, - max_lines=COMMANDS_PER_PAGE, - max_size=2000, - ) - - async def send_bot_help(self, mapping: dict) -> None: - """Sends help for all bot commands and cogs.""" - bot = self.context.bot - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) - - cog_or_category_pages = [] - - for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): - sorted_commands = sorted(_commands, key=lambda c: c.name) - - if len(sorted_commands) == 0: - continue - - command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) - - # Split cogs or categories which have too many commands to fit in one page. - # The length of commands is included for later use when aggregating into pages for the paginator. - for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): - truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] - joined_lines = "".join(truncated_lines) - cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) - - pages = [] - counter = 0 - page = "" - for page_details, length in cog_or_category_pages: - counter += length - if counter > COMMANDS_PER_PAGE: - # force a new page on paginator even if it falls short of the max pages - # since we still want to group categories/cogs. - counter = length - pages.append(page) - page = f"{page_details}\n\n" - else: - page += f"{page_details}\n\n" - - if page: - # add any remaining command help that didn't get added in the last iteration above. - pages.append(page) - - await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) - - -class Help(Cog): - """Custom Embed Pagination Help feature.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.old_help_command = bot.help_command - bot.help_command = CustomHelpCommand() - bot.help_command.cog = self - - def cog_unload(self) -> None: - """Reset the help command when the cog is unloaded.""" - self.bot.help_command = self.old_help_command - - -def setup(bot: Bot) -> None: - """Load the Help cog.""" - bot.add_cog(Help(bot)) - log.info("Cog loaded: Help") diff --git a/bot/cogs/info/information.py b/bot/cogs/info/information.py deleted file mode 100644 index 8982196d1..000000000 --- a/bot/cogs/info/information.py +++ /dev/null @@ -1,422 +0,0 @@ -import colorsys -import logging -import pprint -import textwrap -from collections import Counter, defaultdict -from string import Template -from typing import Any, Mapping, Optional, Union - -from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils -from discord.abc import GuildChannel -from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group -from discord.utils import escape_markdown - -from bot import constants -from bot.bot import Bot -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - - -class Information(Cog): - """A cog with commands for generating embeds with server info, such as server stats and user info.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @staticmethod - def role_can_read(channel: GuildChannel, role: Role) -> bool: - """Return True if `role` can read messages in `channel`.""" - overwrites = channel.overwrites_for(role) - return overwrites.read_messages is True - - def get_staff_channel_count(self, guild: Guild) -> int: - """ - Get the number of channels that are staff-only. - - We need to know two things about a channel: - - Does the @everyone role have explicit read deny permissions? - - Do staff roles have explicit read allow permissions? - - If the answer to both of these questions is yes, it's a staff channel. - """ - channel_ids = set() - for channel in guild.channels: - if channel.type is ChannelType.category: - continue - - everyone_can_read = self.role_can_read(channel, guild.default_role) - - for role in constants.STAFF_ROLES: - role_can_read = self.role_can_read(channel, guild.get_role(role)) - if role_can_read and not everyone_can_read: - channel_ids.add(channel.id) - break - - return len(channel_ids) - - @staticmethod - def get_channel_type_counts(guild: Guild) -> str: - """Return the total amounts of the various types of channels in `guild`.""" - channel_counter = Counter(c.type for c in guild.channels) - channel_type_list = [] - for channel, count in channel_counter.items(): - channel_type = str(channel).title() - channel_type_list.append(f"{channel_type} channels: {count}") - - channel_type_list = sorted(channel_type_list) - return "\n".join(channel_type_list) - - @with_role(*constants.MODERATION_ROLES) - @command(name="roles") - async def roles_info(self, ctx: Context) -> None: - """Returns a list of all roles and their corresponding IDs.""" - # Sort the roles alphabetically and remove the @everyone role - roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) - - # Build a list - role_list = [] - for role in roles: - role_list.append(f"`{role.id}` - {role.mention}") - - # Build an embed - embed = Embed( - title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", - colour=Colour.blurple() - ) - - await LinePaginator.paginate(role_list, ctx, embed, empty=False) - - @with_role(*constants.MODERATION_ROLES) - @command(name="role") - async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: - """ - Return information on a role or list of roles. - - To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. - """ - parsed_roles = [] - failed_roles = [] - - for role_name in roles: - if isinstance(role_name, Role): - # Role conversion has already succeeded - parsed_roles.append(role_name) - continue - - role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) - - if not role: - failed_roles.append(role_name) - continue - - parsed_roles.append(role) - - if failed_roles: - await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") - - for role in parsed_roles: - h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) - - embed = Embed( - title=f"{role.name} info", - colour=role.colour, - ) - embed.add_field(name="ID", value=role.id, inline=True) - embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) - embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) - embed.add_field(name="Member count", value=len(role.members), inline=True) - embed.add_field(name="Position", value=role.position) - embed.add_field(name="Permission code", value=role.permissions.value, inline=True) - - await ctx.send(embed=embed) - - @command(name="server", aliases=["server_info", "guild", "guild_info"]) - async def server_info(self, ctx: Context) -> None: - """Returns an embed full of server information.""" - created = time_since(ctx.guild.created_at, precision="days") - features = ", ".join(ctx.guild.features) - region = ctx.guild.region - - roles = len(ctx.guild.roles) - member_count = ctx.guild.member_count - channel_counts = self.get_channel_type_counts(ctx.guild) - - # How many of each user status? - statuses = Counter(member.status for member in ctx.guild.members) - embed = Embed(colour=Colour.blurple()) - - # How many staff members and staff channels do we have? - staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) - staff_channel_count = self.get_staff_channel_count(ctx.guild) - - # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the - # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting - # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts - # after the dedent is made. - embed.description = Template( - textwrap.dedent(f""" - **Server information** - Created: {created} - Voice region: {region} - Features: {features} - - **Channel counts** - $channel_counts - Staff channels: {staff_channel_count} - - **Member counts** - Members: {member_count:,} - Staff members: {staff_member_count} - Roles: {roles} - - **Member statuses** - {constants.Emojis.status_online} {statuses[Status.online]:,} - {constants.Emojis.status_idle} {statuses[Status.idle]:,} - {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} - {constants.Emojis.status_offline} {statuses[Status.offline]:,} - """) - ).substitute({"channel_counts": channel_counts}) - embed.set_thumbnail(url=ctx.guild.icon_url) - - await ctx.send(embed=embed) - - @command(name="user", aliases=["user_info", "member", "member_info"]) - async def user_info(self, ctx: Context, user: Member = None) -> None: - """Returns info about a user.""" - if user is None: - user = ctx.author - - # Do a role check if this is being executed on someone other than the caller - elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): - await ctx.send("You may not use this command on users other than yourself.") - return - - # Non-staff may only do this in #bot-commands - if not with_role_check(ctx, *constants.STAFF_ROLES): - if not ctx.channel.id == constants.Channels.bot_commands: - raise InWhitelistCheckFailure(constants.Channels.bot_commands) - - embed = await self.create_user_embed(ctx, user) - - await ctx.send(embed=embed) - - async def create_user_embed(self, ctx: Context, user: Member) -> Embed: - """Creates an embed containing information on the `user`.""" - created = time_since(user.created_at, max_units=3) - - # Custom status - custom_status = '' - for activity in user.activities: - # Check activity.state for None value if user has a custom status set - # This guards against a custom status with an emoji but no text, which will cause - # escape_markdown to raise an exception - # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class - if activity.name == 'Custom Status' and activity.state: - state = escape_markdown(activity.state) - custom_status = f'Status: {state}\n' - - name = str(user) - if user.nick: - name = f"{user.nick} ({name})" - - joined = time_since(user.joined_at, max_units=3) - roles = ", ".join(role.mention for role in user.roles[1:]) - - description = [ - textwrap.dedent(f""" - **User Information** - Created: {created} - Profile: {user.mention} - ID: {user.id} - {custom_status} - **Member Information** - Joined: {joined} - Roles: {roles or None} - """).strip() - ] - - # Show more verbose output in moderation channels for infractions and nominations - if ctx.channel.id in constants.MODERATION_CHANNELS: - description.append(await self.expanded_user_infraction_counts(user)) - description.append(await self.user_nomination_counts(user)) - else: - description.append(await self.basic_user_infraction_counts(user)) - - # Let's build the embed now - embed = Embed( - title=name, - description="\n\n".join(description) - ) - - embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) - embed.colour = user.top_role.colour if roles else Colour.blurple() - - return embed - - async def basic_user_infraction_counts(self, member: Member) -> str: - """Gets the total and active infraction counts for the given `member`.""" - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'hidden': 'False', - 'user__id': str(member.id) - } - ) - - total_infractions = len(infractions) - active_infractions = sum(infraction['active'] for infraction in infractions) - - infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" - - return infraction_output - - async def expanded_user_infraction_counts(self, member: Member) -> str: - """ - Gets expanded infraction counts for the given `member`. - - The counts will be split by infraction type and the number of active infractions for each type will indicated - in the output as well. - """ - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'user__id': str(member.id) - } - ) - - infraction_output = ["**Infractions**"] - if not infractions: - infraction_output.append("This user has never received an infraction.") - else: - # Count infractions split by `type` and `active` status for this user - infraction_types = set() - infraction_counter = defaultdict(int) - for infraction in infractions: - infraction_type = infraction["type"] - infraction_active = 'active' if infraction["active"] else 'inactive' - - infraction_types.add(infraction_type) - infraction_counter[f"{infraction_active} {infraction_type}"] += 1 - - # Format the output of the infraction counts - for infraction_type in sorted(infraction_types): - active_count = infraction_counter[f"active {infraction_type}"] - total_count = active_count + infraction_counter[f"inactive {infraction_type}"] - - line = f"{infraction_type.capitalize()}s: {total_count}" - if active_count: - line += f" ({active_count} active)" - - infraction_output.append(line) - - return "\n".join(infraction_output) - - async def user_nomination_counts(self, member: Member) -> str: - """Gets the active and historical nomination counts for the given `member`.""" - nominations = await self.bot.api_client.get( - 'bot/nominations', - params={ - 'user__id': str(member.id) - } - ) - - output = ["**Nominations**"] - - if not nominations: - output.append("This user has never been nominated.") - else: - count = len(nominations) - is_currently_nominated = any(nomination["active"] for nomination in nominations) - nomination_noun = "nomination" if count == 1 else "nominations" - - if is_currently_nominated: - output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") - else: - output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") - - return "\n".join(output) - - def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: - """Format a mapping to be readable to a human.""" - # sorting is technically superfluous but nice if you want to look for a specific field - fields = sorted(mapping.items(), key=lambda item: item[0]) - - if field_width is None: - field_width = len(max(mapping.keys(), key=len)) - - out = '' - - for key, val in fields: - if isinstance(val, dict): - # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries - inner_width = int(field_width * 1.6) - val = '\n' + self.format_fields(val, field_width=inner_width) - - elif isinstance(val, str): - # split up text since it might be long - text = textwrap.fill(val, width=100, replace_whitespace=False) - - # indent it, I guess you could do this with `wrap` and `join` but this is nicer - val = textwrap.indent(text, ' ' * (field_width + len(': '))) - - # the first line is already indented so we `str.lstrip` it - val = val.lstrip() - - if key == 'color': - # makes the base 10 representation of a hex number readable to humans - val = hex(val) - - out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) - - # remove trailing whitespace - return out.rstrip() - - @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) - @group(invoke_without_command=True) - @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) - async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: - """Shows information about the raw API response.""" - # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling - # doing this extra request is also much easier than trying to convert everything back into a dictionary again - raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) - - paginator = Paginator() - - def add_content(title: str, content: str) -> None: - paginator.add_line(f'== {title} ==\n') - # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. - # we hope it's not close to 2000 - paginator.add_line(content.replace('```', '`` `')) - paginator.close_page() - - if message.content: - add_content('Raw message', message.content) - - transformer = pprint.pformat if json else self.format_fields - for field_name in ('embeds', 'attachments'): - data = raw_data[field_name] - - if not data: - continue - - total = len(data) - for current, item in enumerate(data, start=1): - title = f'Raw {field_name} ({current}/{total})' - add_content(title, transformer(item)) - - for page in paginator.pages: - await ctx.send(page) - - @raw.command() - async def json(self, ctx: Context, message: Message) -> None: - """Shows information about the raw API response in a copy-pasteable Python format.""" - await ctx.invoke(self.raw, message=message, json=True) - - -def setup(bot: Bot) -> None: - """Load the Information cog.""" - bot.add_cog(Information(bot)) diff --git a/bot/cogs/info/python_news.py b/bot/cogs/info/python_news.py deleted file mode 100644 index 0ab5738a4..000000000 --- a/bot/cogs/info/python_news.py +++ /dev/null @@ -1,232 +0,0 @@ -import logging -import typing as t -from datetime import date, datetime - -import discord -import feedparser -from bs4 import BeautifulSoup -from discord.ext.commands import Cog -from discord.ext.tasks import loop - -from bot import constants -from bot.bot import Bot -from bot.utils.webhooks import send_webhook - -PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" - -RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" -THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" -MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" -THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" - -AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - -log = logging.getLogger(__name__) - - -class PythonNews(Cog): - """Post new PEPs and Python News to `#python-news`.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_names = {} - self.webhook: t.Optional[discord.Webhook] = None - - self.bot.loop.create_task(self.get_webhook_names()) - self.bot.loop.create_task(self.get_webhook_and_channel()) - - async def start_tasks(self) -> None: - """Start the tasks for fetching new PEPs and mailing list messages.""" - self.fetch_new_media.start() - - @loop(minutes=20) - async def fetch_new_media(self) -> None: - """Fetch new mailing list messages and then new PEPs.""" - await self.post_maillist_news() - await self.post_pep_news() - - async def sync_maillists(self) -> None: - """Sync currently in-use maillists with API.""" - # Wait until guild is available to avoid running before everything is ready - await self.bot.wait_until_guild_available() - - response = await self.bot.api_client.get("bot/bot-settings/news") - for mail in constants.PythonNews.mail_lists: - if mail not in response["data"]: - response["data"][mail] = [] - - # Because we are handling PEPs differently, we don't include it to mail lists - if "pep" not in response["data"]: - response["data"]["pep"] = [] - - await self.bot.api_client.put("bot/bot-settings/news", json=response) - - async def get_webhook_names(self) -> None: - """Get webhook author names from maillist API.""" - await self.bot.wait_until_guild_available() - - async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: - lists = await resp.json() - - for mail in lists: - if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: - self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] - - async def post_pep_news(self) -> None: - """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" - # Wait until everything is ready and http_session available - await self.bot.wait_until_guild_available() - await self.sync_maillists() - - async with self.bot.http_session.get(PEPS_RSS_URL) as resp: - data = feedparser.parse(await resp.text("utf-8")) - - news_listing = await self.bot.api_client.get("bot/bot-settings/news") - payload = news_listing.copy() - pep_numbers = news_listing["data"]["pep"] - - # Reverse entries to send oldest first - data["entries"].reverse() - for new in data["entries"]: - try: - new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") - except ValueError: - log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") - continue - pep_nr = new["title"].split(":")[0].split()[1] - if ( - pep_nr in pep_numbers - or new_datetime.date() < date.today() - ): - continue - - # Build an embed and send a webhook - embed = discord.Embed( - title=new["title"], - description=new["summary"], - timestamp=new_datetime, - url=new["link"], - colour=constants.Colours.soft_green - ) - embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) - msg = await send_webhook( - webhook=self.webhook, - username=data["feed"]["title"], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"]["pep"].append(pep_nr) - - # Increase overall PEP new stat - self.bot.stats.incr("python_news.posted.pep") - - if msg.channel.is_news(): - log.trace("Publishing PEP annnouncement because it was in a news channel") - await msg.publish() - - # Apply new sent news to DB to avoid duplicate sending - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def post_maillist_news(self) -> None: - """Send new maillist threads to #python-news that is listed in configuration.""" - await self.bot.wait_until_guild_available() - await self.sync_maillists() - existing_news = await self.bot.api_client.get("bot/bot-settings/news") - payload = existing_news.copy() - - for maillist in constants.PythonNews.mail_lists: - async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: - recents = BeautifulSoup(await resp.text(), features="lxml") - - # When a

element is present in the response then the mailing list - # has not had any activity during the current month, so therefore it - # can be ignored. - if recents.p: - continue - - for thread in recents.html.body.div.find_all("a", href=True): - # We want only these threads that have identifiers - if "latest" in thread["href"]: - continue - - thread_information, email_information = await self.get_thread_and_first_mail( - maillist, thread["href"].split("/")[-2] - ) - - try: - new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") - except ValueError: - log.warning(f"Invalid datetime from Thread email: {email_information['date']}") - continue - - if ( - thread_information["thread_id"] in existing_news["data"][maillist] - or 'Re: ' in thread_information["subject"] - or new_date.date() < date.today() - ): - continue - - content = email_information["content"] - link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) - - # Build an embed and send a message to the webhook - embed = discord.Embed( - title=thread_information["subject"], - description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, - timestamp=new_date, - url=link, - colour=constants.Colours.soft_green - ) - embed.set_author( - name=f"{email_information['sender_name']} ({email_information['sender']['address']})", - url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), - ) - embed.set_footer( - text=f"Posted to {self.webhook_names[maillist]}", - icon_url=AVATAR_URL, - ) - msg = await send_webhook( - webhook=self.webhook, - username=self.webhook_names[maillist], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"][maillist].append(thread_information["thread_id"]) - - # Increase this specific maillist counter in stats - self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") - - if msg.channel.is_news(): - log.trace("Publishing mailing list message because it was in a news channel") - await msg.publish() - - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: - """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" - async with self.bot.http_session.get( - THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) - ) as resp: - thread_information = await resp.json() - - async with self.bot.http_session.get(thread_information["starting_email"]) as resp: - email_information = await resp.json() - return thread_information, email_information - - async def get_webhook_and_channel(self) -> None: - """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" - await self.bot.wait_until_guild_available() - self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) - - await self.start_tasks() - - def cog_unload(self) -> None: - """Stop news posting tasks on cog unload.""" - self.fetch_new_media.cancel() - - -def setup(bot: Bot) -> None: - """Add `News` cog.""" - bot.add_cog(PythonNews(bot)) diff --git a/bot/cogs/info/reddit.py b/bot/cogs/info/reddit.py deleted file mode 100644 index d853ab2ea..000000000 --- a/bot/cogs/info/reddit.py +++ /dev/null @@ -1,304 +0,0 @@ -import asyncio -import logging -import random -import textwrap -from collections import namedtuple -from datetime import datetime, timedelta -from typing import List - -from aiohttp import BasicAuth, ClientError -from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks -from bot.converters import Subreddit -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - - -class Reddit(Cog): - """Track subreddit posts and show detailed statistics about them.""" - - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} - URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 - - def __init__(self, bot: Bot): - self.bot = bot - - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - - bot.loop.create_task(self.init_reddit_ready()) - self.auto_poster_loop.start() - - def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" - self.auto_poster_loop.cancel() - if self.access_token and self.access_token.expires_at > datetime.utcnow(): - asyncio.create_task(self.revoke_access_token()) - - async def init_reddit_ready(self) -> None: - """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_guild_available() - if not self.webhook: - self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) - - @property - def channel(self) -> TextChannel: - """Get the #reddit channel object from the bot's cache.""" - return self.bot.get_channel(Channels.reddit) - - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: - """A helper method to fetch a certain amount of Reddit posts at a given route.""" - # Reddit's JSON responses only provide 25 posts at most. - if not 25 >= amount > 0: - raise ValueError("Invalid amount of subreddit posts requested.") - - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): - response = await self.bot.http_session.get( - url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, - params=params - ) - if response.status == 200 and response.content_type == 'application/json': - # Got appropriate response - process and return. - content = await response.json() - posts = content["data"]["children"] - return posts[:amount] - - await asyncio.sleep(3) - - log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") - return list() # Failed to get appropriate response within allowed number of retries. - - async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: - """ - Get the top amount of posts for a given subreddit within a specified timeframe. - - A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top - weekly posts. - - The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. - """ - embed = Embed(description="") - - posts = await self.fetch_posts( - route=f"{subreddit}/top", - amount=amount, - params={"t": time} - ) - - if not posts: - embed.title = random.choice(ERROR_REPLIES) - embed.colour = Colour.red() - embed.description = ( - "Sorry! We couldn't find any posts from that subreddit. " - "If this problem persists, please let us know." - ) - - return embed - - for post in posts: - data = post["data"] - - text = data["selftext"] - if text: - text = textwrap.shorten(text, width=128, placeholder="...") - text += "\n" # Add newline to separate embed info - - ups = data["ups"] - comments = data["num_comments"] - author = data["author"] - - title = textwrap.shorten(data["title"], width=64, placeholder="...") - link = self.URL + data["permalink"] - - embed.description += ( - f"**[{title}]({link})**\n" - f"{text}" - f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" - ) - - embed.colour = Colour.blurple() - return embed - - @loop() - async def auto_poster_loop(self) -> None: - """Post the top 5 posts daily, and the top 5 posts weekly.""" - # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter - now = datetime.utcnow() - tomorrow = now + timedelta(days=1) - midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) - seconds_until = (midnight_tomorrow - now).total_seconds() - - await asyncio.sleep(seconds_until) - - await self.bot.wait_until_guild_available() - if not self.webhook: - await self.bot.fetch_webhook(Webhooks.reddit) - - if datetime.utcnow().weekday() == 0: - await self.top_weekly_posts() - # if it's a monday send the top weekly posts - - for subreddit in RedditConfig.subreddits: - top_posts = await self.get_top_posts(subreddit=subreddit, time="day") - username = sub_clyde(f"{subreddit} Top Daily Posts") - message = await self.webhook.send(username=username, embed=top_posts, wait=True) - - if message.channel.is_news(): - await message.publish() - - async def top_weekly_posts(self) -> None: - """Post a summary of the top posts.""" - for subreddit in RedditConfig.subreddits: - # Send and pin the new weekly posts. - top_posts = await self.get_top_posts(subreddit=subreddit, time="week") - username = sub_clyde(f"{subreddit} Top Weekly Posts") - message = await self.webhook.send(wait=True, username=username, embed=top_posts) - - if subreddit.lower() == "r/python": - if not self.channel: - log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") - return - - # Remove the oldest pins so that only 12 remain at most. - pins = await self.channel.pins() - - while len(pins) >= 12: - await pins[-1].unpin() - del pins[-1] - - await message.pin() - - if message.channel.is_news(): - await message.publish() - - @group(name="reddit", invoke_without_command=True) - async def reddit_group(self, ctx: Context) -> None: - """View the top posts from various subreddits.""" - await ctx.send_help(ctx.command) - - @reddit_group.command(name="top") - async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of all time from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="all") - - await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - - @reddit_group.command(name="daily") - async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of today from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="day") - - await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - - @reddit_group.command(name="weekly") - async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of this week from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="week") - - await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - - @with_role(*STAFF_ROLES) - @reddit_group.command(name="subreddits", aliases=("subs",)) - async def subreddits_command(self, ctx: Context) -> None: - """Send a paginated embed of all the subreddits we're relaying.""" - embed = Embed() - embed.title = "Relayed subreddits." - embed.colour = Colour.blurple() - - await LinePaginator.paginate( - RedditConfig.subreddits, - ctx, embed, - footer_text="Use the reddit commands along with these to view their posts.", - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Reddit cog.""" - if not RedditConfig.secret or not RedditConfig.client_id: - log.error("Credentials not provided, cog not loaded.") - return - bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/info/site.py b/bot/cogs/info/site.py deleted file mode 100644 index ac29daa1d..000000000 --- a/bot/cogs/info/site.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import URLs -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" - - -class Site(Cog): - """Commands for linking to different parts of the site.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="site", aliases=("s",), invoke_without_command=True) - async def site_group(self, ctx: Context) -> None: - """Commands for getting info about our website.""" - await ctx.send_help(ctx.command) - - @site_group.command(name="home", aliases=("about",)) - async def site_main(self, ctx: Context) -> None: - """Info about the website itself.""" - url = f"{URLs.site_schema}{URLs.site}/" - - embed = Embed(title="Python Discord website") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - f"[Our official website]({url}) is an open-source community project " - "created with Python and Django. It contains information about the server " - "itself, lets you sign up for upcoming events, has its own wiki, contains " - "a list of valuable learning resources, and much more." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="resources") - async def site_resources(self, ctx: Context) -> None: - """Info about the site's Resources page.""" - learning_url = f"{PAGES_URL}/resources" - - embed = Embed(title="Resources") - embed.set_footer(text=f"{learning_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Resources page]({learning_url}) on our website contains a " - "list of hand-selected learning resources that we regularly recommend " - f"to both beginners and experts." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="tools") - async def site_tools(self, ctx: Context) -> None: - """Info about the site's Tools page.""" - tools_url = f"{PAGES_URL}/resources/tools" - - embed = Embed(title="Tools") - embed.set_footer(text=f"{tools_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Tools page]({tools_url}) on our website contains a " - f"couple of the most popular tools for programming in Python." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="help") - async def site_help(self, ctx: Context) -> None: - """Info about the site's Getting Help page.""" - url = f"{PAGES_URL}/resources/guides/asking-good-questions" - - embed = Embed(title="Asking Good Questions") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "Asking the right question about something that's new to you can sometimes be tricky. " - f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " - "It contains everything you need to get the very best help from our community." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="faq") - async def site_faq(self, ctx: Context) -> None: - """Info about the site's FAQ page.""" - url = f"{PAGES_URL}/frequently-asked-questions" - - embed = Embed(title="FAQ") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "As the largest Python community on Discord, we get hundreds of questions every day. " - "Many of these questions have been asked before. We've compiled a list of the most " - "frequently asked questions along with their answers, which can be found on " - f"our [FAQ page]({url})." - ) - - await ctx.send(embed=embed) - - @site_group.command(aliases=['r', 'rule'], name='rules') - async def site_rules(self, ctx: Context, *rules: int) -> None: - """Provides a link to all rules or, if specified, displays specific rule(s).""" - rules_embed = Embed(title='Rules', color=Colour.blurple()) - rules_embed.url = f"{PAGES_URL}/rules" - - if not rules: - # Rules were not submitted. Return the default description. - rules_embed.description = ( - "The rules and guidelines that apply to this community can be found on" - f" our [rules page]({PAGES_URL}/rules). We expect" - " all members of the community to have read and understood these." - ) - - await ctx.send(embed=rules_embed) - return - - full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) - invalid_indices = tuple( - pick - for pick in rules - if pick < 1 or pick > len(full_rules) - ) - - if invalid_indices: - indices = ', '.join(map(str, invalid_indices)) - await ctx.send(f":x: Invalid rule indices: {indices}") - return - - for rule in rules: - self.bot.stats.incr(f"rule_uses.{rule}") - - final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) - - await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) - - -def setup(bot: Bot) -> None: - """Load the Site cog.""" - bot.add_cog(Site(bot)) diff --git a/bot/cogs/info/source.py b/bot/cogs/info/source.py deleted file mode 100644 index 205e0ba81..000000000 --- a/bot/cogs/info/source.py +++ /dev/null @@ -1,141 +0,0 @@ -import inspect -from pathlib import Path -from typing import Optional, Tuple, Union - -from discord import Embed -from discord.ext import commands - -from bot.bot import Bot -from bot.constants import URLs - -SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] - - -class SourceConverter(commands.Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - async def convert(self, ctx: commands.Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower().startswith("help"): - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - - raise commands.BadArgument( - f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." - ) - - -class BotSource(commands.Cog): - """Displays information about the bot's source code.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command(name="source", aliases=("src",)) - async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: - """Display information and a GitHub link to the source code of a command, tag, or cog.""" - if not source_item: - embed = Embed(title="Bot's GitHub Repository") - embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") - embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") - await ctx.send(embed=embed) - return - - embed = await self.build_embed(source_item) - await ctx.send(embed=embed) - - def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: - """ - Build GitHub link of source item, return this link, file location and first line number. - - Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). - """ - if isinstance(source_item, commands.Command): - if source_item.cog_name == "Alias": - cmd_name = source_item.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - src = cmd.callback.__code__ - filename = src.co_filename - else: - src = source_item.callback.__code__ - filename = src.co_filename - elif isinstance(source_item, str): - tags_cog = self.bot.get_cog("Tags") - filename = tags_cog._cache[source_item]["location"] - else: - src = type(source_item) - try: - filename = inspect.getsourcefile(src) - except TypeError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - if not isinstance(source_item, str): - try: - lines, first_line_no = inspect.getsourcelines(src) - except OSError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" - else: - first_line_no = None - lines_extension = "" - - # Handle tag file location differently than others to avoid errors in some cases - if not first_line_no: - file_location = Path(filename).relative_to("/bot/") - else: - file_location = Path(filename).relative_to(Path.cwd()).as_posix() - - url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" - - return url, file_location, first_line_no or None - - async def build_embed(self, source_object: SourceType) -> Optional[Embed]: - """Build embed based on source object.""" - url, location, first_line = self.get_source_link(source_object) - - if isinstance(source_object, commands.HelpCommand): - title = "Help Command" - description = source_object.__doc__.splitlines()[1] - elif isinstance(source_object, commands.Command): - if source_object.cog_name == "Alias": - cmd_name = source_object.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - description = cmd.short_doc - else: - description = source_object.short_doc - - title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, str): - title = f"Tag: {source_object}" - description = "" - else: - title = f"Cog: {source_object.qualified_name}" - description = source_object.description.splitlines()[0] - - embed = Embed(title=title, description=description) - embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") - line_text = f":{first_line}" if first_line else "" - embed.set_footer(text=f"{location}{line_text}") - - return embed - - -def setup(bot: Bot) -> None: - """Load the BotSource cog.""" - bot.add_cog(BotSource(bot)) diff --git a/bot/cogs/info/stats.py b/bot/cogs/info/stats.py deleted file mode 100644 index d42f55466..000000000 --- a/bot/cogs/info/stats.py +++ /dev/null @@ -1,129 +0,0 @@ -import string -from datetime import datetime - -from discord import Member, Message, Status -from discord.ext.commands import Cog, Context -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Categories, Channels, Guild, Stats as StatConf - - -CHANNEL_NAME_OVERRIDES = { - Channels.off_topic_0: "off_topic_0", - Channels.off_topic_1: "off_topic_1", - Channels.off_topic_2: "off_topic_2", - Channels.staff_lounge: "staff_lounge" -} - -ALLOWED_CHARS = string.ascii_letters + string.digits + "_" - - -class Stats(Cog): - """A cog which provides a way to hook onto Discord events and forward to stats.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.last_presence_update = None - self.update_guild_boost.start() - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Report message events in the server to statsd.""" - if message.guild is None: - return - - if message.guild.id != Guild.id: - return - - cat = getattr(message.channel, "category", None) - if cat is not None and cat.id == Categories.modmail: - if message.channel.id != Channels.incidents: - # Do not report modmail channels to stats, there are too many - # of them for interesting statistics to be drawn out of this. - return - - reformatted_name = message.channel.name.replace('-', '_') - - if CHANNEL_NAME_OVERRIDES.get(message.channel.id): - reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) - - reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) - - stat_name = f"channels.{reformatted_name}" - self.bot.stats.incr(stat_name) - - # Increment the total message count - self.bot.stats.incr("messages") - - @Cog.listener() - async def on_command_completion(self, ctx: Context) -> None: - """Report completed commands to statsd.""" - command_name = ctx.command.qualified_name.replace(" ", "_") - - self.bot.stats.incr(f"commands.{command_name}") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Update member count stat on member join.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_leave(self, member: Member) -> None: - """Update member count stat on member leave.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_update(self, _before: Member, after: Member) -> None: - """Update presence estimates on member update.""" - if after.guild.id != Guild.id: - return - - if self.last_presence_update: - if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: - return - - self.last_presence_update = datetime.now() - - online = 0 - idle = 0 - dnd = 0 - offline = 0 - - for member in after.guild.members: - if member.status is Status.online: - online += 1 - elif member.status is Status.dnd: - dnd += 1 - elif member.status is Status.idle: - idle += 1 - elif member.status is Status.offline: - offline += 1 - - self.bot.stats.gauge("guild.status.online", online) - self.bot.stats.gauge("guild.status.idle", idle) - self.bot.stats.gauge("guild.status.do_not_disturb", dnd) - self.bot.stats.gauge("guild.status.offline", offline) - - @loop(hours=1) - async def update_guild_boost(self) -> None: - """Post the server boost level and tier every hour.""" - await self.bot.wait_until_guild_available() - g = self.bot.get_guild(Guild.id) - self.bot.stats.gauge("boost.amount", g.premium_subscription_count) - self.bot.stats.gauge("boost.tier", g.premium_tier) - - def cog_unload(self) -> None: - """Stop the boost statistic task on unload of the Cog.""" - self.update_guild_boost.stop() - - -def setup(bot: Bot) -> None: - """Load the stats cog.""" - bot.add_cog(Stats(bot)) diff --git a/bot/cogs/info/tags.py b/bot/cogs/info/tags.py deleted file mode 100644 index 3d76c5c08..000000000 --- a/bot/cogs/info/tags.py +++ /dev/null @@ -1,277 +0,0 @@ -import logging -import re -import time -from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot import constants -from bot.bot import Bot -from bot.converters import TagNameConverter -from bot.pagination import LinePaginator -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -TEST_CHANNELS = ( - constants.Channels.bot_commands, - constants.Channels.helpers -) - -REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) -FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." - - -class Tags(Cog): - """Save new tags and fetch existing tags.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.tag_cooldowns = {} - self._cache = self.get_tags() - - @staticmethod - def get_tags() -> dict: - """Get all tags.""" - cache = {} - - base_path = Path("bot", "resources", "tags") - for file in base_path.glob("**/*"): - if file.is_file(): - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text(encoding="utf8"), - }, - "restricted_to": "developers", - "location": f"/bot/{file}" - } - - # Convert to a list to allow negative indexing. - parents = list(file.relative_to(base_path).parents) - if len(parents) > 1: - # -1 would be '.' hence -2 is used as the index. - tag["restricted_to"] = parents[-2].name - - cache[tag_title] = tag - - return cache - - @staticmethod - def check_accessibility(user: Member, tag: dict) -> bool: - """Check if user can access a tag.""" - return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] - - @staticmethod - def _fuzzy_search(search: str, target: str) -> float: - """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" - current, index = 0, 0 - _search = REGEX_NON_ALPHABET.sub('', search.lower()) - _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) - _target = next(_targets) - try: - while True: - while index < len(_target) and _search[current] == _target[index]: - current += 1 - index += 1 - index, _target = 0, next(_targets) - except (StopIteration, IndexError): - pass - return current / len(_search) * 100 - - def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: - """Return a list of suggested tags.""" - scores: Dict[str, int] = { - tag_title: Tags._fuzzy_search(tag_name, tag['title']) - for tag_title, tag in self._cache.items() - } - - thresholds = thresholds or [100, 90, 80, 70, 60] - - for threshold in thresholds: - suggestions = [ - self._cache[tag_title] - for tag_title, matching_score in scores.items() - if matching_score >= threshold - ] - if suggestions: - return suggestions - - return [] - - def _get_tag(self, tag_name: str) -> list: - """Get a specific tag.""" - found = [self._cache.get(tag_name.lower(), None)] - if not found[0]: - return self._get_suggestions(tag_name) - return found - - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: - """ - Search for tags via contents. - - `predicate` will be the built-in any, all, or a custom callable. Must return a bool. - """ - keywords_processed: List[str] = [] - for keyword in keywords.split(','): - keyword_sanitized = keyword.strip().casefold() - if not keyword_sanitized: - # this happens when there are leading / trailing / consecutive comma. - continue - keywords_processed.append(keyword_sanitized) - - if not keywords_processed: - # after sanitizing, we can end up with an empty list, for example when keywords is ',' - # in that case, we simply want to search for such keywords directly instead. - keywords_processed = [keywords] - - matching_tags = [] - for tag in self._cache.values(): - matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) - if self.check_accessibility(user, tag) and check(matches): - matching_tags.append(tag) - - return matching_tags - - async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: - """Send the result of matching tags to user.""" - if not matching_tags: - pass - elif len(matching_tags) == 1: - await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) - else: - is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 - embed = Embed( - title=f"Here are the tags containing the given keyword{'s' * is_plural}:", - description='\n'.join(tag['title'] for tag in matching_tags[:10]) - ) - await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in matching_tags), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) - async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Show all known tags, a single tag, or run a subcommand.""" - await ctx.invoke(self.get_command, tag_name=tag_name) - - @tags_group.group(name='search', invoke_without_command=True) - async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Only search for tags that has ALL the keywords. - """ - matching_tags = self._get_tags_via_content(all, keywords, ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @search_tag_content.command(name='any') - async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Search for tags that has ANY of the keywords. - """ - matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @tags_group.command(name='get', aliases=('show', 'g')) - async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Get a specified tag, or a list of all tags if no tag is specified.""" - - def _command_on_cooldown(tag_name: str) -> bool: - """ - Check if the command is currently on cooldown, on a per-tag, per-channel basis. - - The cooldown duration is set in constants.py. - """ - now = time.time() - - cooldown_conditions = ( - tag_name - and tag_name in self.tag_cooldowns - and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags - and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id - ) - - if cooldown_conditions: - return True - return False - - if _command_on_cooldown(tag_name): - time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] - time_left = constants.Cooldowns.tags - time_elapsed - log.info( - f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds." - ) - return - - if tag_name is not None: - temp_founds = self._get_tag(tag_name) - - founds = [] - - for found_tag in temp_founds: - if self.check_accessibility(ctx.author, found_tag): - founds.append(found_tag) - - if len(founds) == 1: - tag = founds[0] - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } - - self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") - - await wait_for_deletion( - await ctx.send(embed=Embed.from_dict(tag['embed'])), - [ctx.author.id], - client=self.bot - ) - elif founds and len(tag_name) >= 3: - await wait_for_deletion( - await ctx.send( - embed=Embed( - title='Did you mean ...', - description='\n'.join(tag['title'] for tag in founds[:10]) - ) - ), - [ctx.author.id], - client=self.bot - ) - - else: - tags = self._cache.values() - if not tags: - await ctx.send(embed=Embed( - description="**There are no tags in the database!**", - colour=Colour.red() - )) - else: - embed: Embed = Embed(title="**Current tags**") - await LinePaginator.paginate( - sorted( - f"**»** {tag['title']}" for tag in tags - if self.check_accessibility(ctx.author, tag) - ), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Tags cog.""" - bot.add_cog(Tags(bot)) diff --git a/bot/cogs/info/wolfram.py b/bot/cogs/info/wolfram.py deleted file mode 100644 index e6cae3bb8..000000000 --- a/bot/cogs/info/wolfram.py +++ /dev/null @@ -1,280 +0,0 @@ -import logging -from io import BytesIO -from typing import Callable, List, Optional, Tuple -from urllib import parse - -import discord -from dateutil.relativedelta import relativedelta -from discord import Embed -from discord.ext import commands -from discord.ext.commands import BucketType, Cog, Context, check, group - -from bot.bot import Bot -from bot.constants import Colours, STAFF_ROLES, Wolfram -from bot.pagination import ImagePaginator -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -APPID = Wolfram.key -DEFAULT_OUTPUT_FORMAT = "JSON" -QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" -WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" - -MAX_PODS = 20 - -# Allows for 10 wolfram calls pr user pr day -usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) - -# Allows for max api requests / days in month per day for the entire guild (Temporary) -guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) - - -async def send_embed( - ctx: Context, - message_txt: str, - colour: int = Colours.soft_red, - footer: str = None, - img_url: str = None, - f: discord.File = None -) -> None: - """Generate & send a response embed with Wolfram as the author.""" - embed = Embed(colour=colour) - embed.description = message_txt - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - if footer: - embed.set_footer(text=footer) - - if img_url: - embed.set_image(url=img_url) - - await ctx.send(embed=embed, file=f) - - -def custom_cooldown(*ignore: List[int]) -> Callable: - """ - Implement per-user and per-guild cooldowns for requests to the Wolfram API. - - A list of roles may be provided to ignore the per-user cooldown - """ - async def predicate(ctx: Context) -> bool: - if ctx.invoked_with == 'help': - # if the invoked command is help we don't want to increase the ratelimits since it's not actually - # invoking the command/making a request, so instead just check if the user/guild are on cooldown. - guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown - if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored - return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 - return guild_cooldown - - user_bucket = usercd.get_bucket(ctx.message) - - if all(role.id not in ignore for role in ctx.author.roles): - user_rate = user_bucket.update_rate_limit() - - if user_rate: - # Can't use api; cause: member limit - delta = relativedelta(seconds=int(user_rate)) - cooldown = humanize_delta(delta) - message = ( - "You've used up your limit for Wolfram|Alpha requests.\n" - f"Cooldown: {cooldown}" - ) - await send_embed(ctx, message) - return False - - guild_bucket = guildcd.get_bucket(ctx.message) - guild_rate = guild_bucket.update_rate_limit() - - # Repr has a token attribute to read requests left - log.debug(guild_bucket) - - if guild_rate: - # Can't use api; cause: guild limit - message = ( - "The max limit of requests for the server has been reached for today.\n" - f"Cooldown: {int(guild_rate)}" - ) - await send_embed(ctx, message) - return False - - return True - return check(predicate) - - -async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: - """Get the Wolfram API pod pages for the provided query.""" - async with ctx.channel.typing(): - url_str = parse.urlencode({ - "input": query, - "appid": APPID, - "output": DEFAULT_OUTPUT_FORMAT, - "format": "image,plaintext" - }) - request_url = QUERY.format(request="query", data=url_str) - - async with bot.http_session.get(request_url) as response: - json = await response.json(content_type='text/plain') - - result = json["queryresult"] - - if result["error"]: - # API key not set up correctly - if result["error"]["msg"] == "Invalid appid": - message = "Wolfram API key is invalid or missing." - log.warning( - "API key seems to be missing, or invalid when " - f"processing a wolfram request: {url_str}, Response: {json}" - ) - await send_embed(ctx, message) - return - - message = "Something went wrong internally with your request, please notify staff!" - log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") - await send_embed(ctx, message) - return - - if not result["success"]: - message = f"I couldn't find anything for {query}." - await send_embed(ctx, message) - return - - if not result["numpods"]: - message = "Could not find any results." - await send_embed(ctx, message) - return - - pods = result["pods"] - pages = [] - for pod in pods[:MAX_PODS]: - subs = pod.get("subpods") - - for sub in subs: - title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") - img = sub["img"]["src"] - pages.append((title, img)) - return pages - - -class Wolfram(Cog): - """Commands for interacting with the Wolfram|Alpha API.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_command(self, ctx: Context, *, query: str) -> None: - """Requests all answers on a single image, sends an image of all related pods.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="simple", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - image_bytes = await response.read() - - f = discord.File(BytesIO(image_bytes), filename="image.png") - image_url = "attachment://image.png" - - if status == 501: - message = "Failed to get response" - footer = "" - color = Colours.soft_red - elif status == 400: - message = "No input found" - footer = "" - color = Colours.soft_red - elif status == 403: - message = "Wolfram API key is invalid or missing." - footer = "" - color = Colours.soft_red - else: - message = "" - footer = "View original for a bigger picture." - color = Colours.soft_orange - - # Sends a "blank" embed if no request is received, unsure how to fix - await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) - - @wolfram_command.command(name="page", aliases=("pa", "p")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - embed = Embed() - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - embed.colour = Colours.soft_orange - - await ImagePaginator.paginate(pages, ctx, embed) - - @wolfram_command.command(name="cut", aliases=("c",)) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - if len(pages) >= 2: - page = pages[1] - else: - page = pages[0] - - await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) - - @wolfram_command.command(name="short", aliases=("sh", "s")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: - """Requests an answer to a simple question.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="result", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - response_text = await response.text() - - if status == 501: - message = "Failed to get response" - color = Colours.soft_red - elif status == 400: - message = "No input found" - color = Colours.soft_red - elif response_text == "Error 1: Invalid appid": - message = "Wolfram API key is invalid or missing." - color = Colours.soft_red - else: - message = response_text - color = Colours.soft_orange - - await send_embed(ctx, message, color) - - -def setup(bot: Bot) -> None: - """Load the Wolfram cog.""" - bot.add_cog(Wolfram(bot)) diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/moderation/defcon.py b/bot/cogs/moderation/defcon.py deleted file mode 100644 index e78435a7d..000000000 --- a/bot/cogs/moderation/defcon.py +++ /dev/null @@ -1,258 +0,0 @@ -from __future__ import annotations - -import logging -from collections import namedtuple -from datetime import datetime, timedelta -from enum import Enum - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -REJECTION_MESSAGE = """ -Hi, {user} - Thanks for your interest in our server! - -Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since -your account is relatively new, we're unable to provide access to the server at this time. - -Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation -will be resolved soon. In the meantime, please feel free to peruse the resources on our site at -, and have a nice day! -""" - -BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" - - -class Action(Enum): - """Defcon Action.""" - - ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) - - ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") - DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") - UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") - - -class Defcon(Cog): - """Time-sensitive server defense mechanisms.""" - - days = None # type: timedelta - enabled = False # type: bool - - def __init__(self, bot: Bot): - self.bot = bot - self.channel = None - self.days = timedelta(days=0) - - self.bot.loop.create_task(self.sync_settings()) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def sync_settings(self) -> None: - """On cog load, try to synchronize DEFCON settings to the API.""" - await self.bot.wait_until_guild_available() - self.channel = await self.bot.fetch_channel(Channels.defcon) - - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - except Exception: # Yikes! - log.exception("Unable to get DEFCON settings!") - await self.bot.get_channel(Channels.dev_log).send( - f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" - ) - - else: - if data["enabled"]: - self.enabled = True - self.days = timedelta(days=data["days"]) - log.info(f"DEFCON enabled: {self.days.days} days") - - else: - self.enabled = False - self.days = timedelta(days=0) - log.info("DEFCON disabled") - - await self.update_channel_topic() - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" - if self.enabled and self.days.days > 0: - now = datetime.utcnow() - - if now - member.created_at < self.days: - log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") - - message_sent = False - - try: - await member.send(REJECTION_MESSAGE.format(user=member.mention)) - - message_sent = True - except Exception: - log.exception(f"Unable to send rejection message to user: {member}") - - await member.kick(reason="DEFCON active, user is too new") - self.bot.stats.incr("defcon.leaves") - - message = ( - f"{member} (`{member.id}`) was denied entry because their account is too new." - ) - - if not message_sent: - message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." - - await self.mod_log.send_log_message( - Icons.defcon_denied, Colours.soft_red, "Entry denied", - message, member.avatar_url_as(static_format="png") - ) - - @group(name='defcon', aliases=('dc',), invoke_without_command=True) - @with_role(Roles.admins, Roles.owners) - async def defcon_group(self, ctx: Context) -> None: - """Check the DEFCON status or run a subcommand.""" - await ctx.send_help(ctx.command) - - async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: - """Providing a structured way to do an defcon action.""" - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - if "enable_date" in data and action is Action.DISABLED: - enabled = datetime.fromisoformat(data["enable_date"]) - - delta = datetime.now() - enabled - - self.bot.stats.timing("defcon.enabled", delta) - except Exception: - pass - - error = None - try: - await self.bot.api_client.put( - 'bot/bot-settings/defcon', - json={ - 'name': 'defcon', - 'data': { - # TODO: retrieve old days count - 'days': days, - 'enabled': action is not Action.DISABLED, - 'enable_date': datetime.now().isoformat() - } - } - ) - except Exception as err: - log.exception("Unable to update DEFCON settings.") - error = err - finally: - await ctx.send(self.build_defcon_msg(action, error)) - await self.send_defcon_log(action, ctx.author, error) - - self.bot.stats.gauge("defcon.threshold", days) - - @defcon_group.command(name='enable', aliases=('on', 'e')) - @with_role(Roles.admins, Roles.owners) - async def enable_command(self, ctx: Context) -> None: - """ - Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! - - Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, - in days. - """ - self.enabled = True - await self._defcon_action(ctx, days=0, action=Action.ENABLED) - await self.update_channel_topic() - - @defcon_group.command(name='disable', aliases=('off', 'd')) - @with_role(Roles.admins, Roles.owners) - async def disable_command(self, ctx: Context) -> None: - """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" - self.enabled = False - await self._defcon_action(ctx, days=0, action=Action.DISABLED) - await self.update_channel_topic() - - @defcon_group.command(name='status', aliases=('s',)) - @with_role(Roles.admins, Roles.owners) - async def status_command(self, ctx: Context) -> None: - """Check the current status of DEFCON mode.""" - embed = Embed( - colour=Colour.blurple(), title="DEFCON Status", - description=f"**Enabled:** {self.enabled}\n" - f"**Days:** {self.days.days}" - ) - - await ctx.send(embed=embed) - - @defcon_group.command(name='days') - @with_role(Roles.admins, Roles.owners) - async def days_command(self, ctx: Context, days: int) -> None: - """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" - self.days = timedelta(days=days) - self.enabled = True - await self._defcon_action(ctx, days=days, action=Action.UPDATED) - await self.update_channel_topic() - - async def update_channel_topic(self) -> None: - """Update the #defcon channel topic with the current DEFCON status.""" - if self.enabled: - day_str = "days" if self.days.days > 1 else "day" - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" - else: - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" - - self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) - await self.channel.edit(topic=new_topic) - - def build_defcon_msg(self, action: Action, e: Exception = None) -> str: - """Build in-channel response string for DEFCON action.""" - if action is Action.ENABLED: - msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" - elif action is Action.DISABLED: - msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" - elif action is Action.UPDATED: - msg = ( - f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " - f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" - ) - - if e: - msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - return msg - - async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: - """Send log message for DEFCON action.""" - info = action.value - log_msg: str = ( - f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" - f"{info.template.format(days=self.days.days)}" - ) - status_msg = f"DEFCON {action.name.lower()}" - - if e: - log_msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) - - -def setup(bot: Bot) -> None: - """Load the Defcon cog.""" - bot.add_cog(Defcon(bot)) diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py deleted file mode 100644 index e49913552..000000000 --- a/bot/cogs/moderation/incidents.py +++ /dev/null @@ -1,412 +0,0 @@ -import asyncio -import logging -import typing as t -from datetime import datetime -from enum import Enum - -import discord -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, Colours, Emojis, Guild, Webhooks -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -# Amount of messages for `crawl_task` to process at most on start-up - limited to 50 -# as in practice, there should never be this many messages, and if there are, -# something has likely gone very wrong -CRAWL_LIMIT = 50 - -# Seconds for `crawl_task` to sleep after adding reactions to a message -CRAWL_SLEEP = 2 - - -class Signal(Enum): - """ - Recognized incident status signals. - - This binds emoji to actions. The bot will only react to emoji linked here. - All other signals are seen as invalid. - """ - - ACTIONED = Emojis.incident_actioned - NOT_ACTIONED = Emojis.incident_unactioned - INVESTIGATING = Emojis.incident_investigating - - -# Reactions from non-mod roles will be removed -ALLOWED_ROLES: t.Set[int] = set(Guild.moderation_roles) - -# Message must have all of these emoji to pass the `has_signals` check -ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} - -# An embed coupled with an optional file to be dispatched -# If the file is not None, the embed attempts to show it in its body -FileEmbed = t.Tuple[discord.Embed, t.Optional[discord.File]] - - -async def download_file(attachment: discord.Attachment) -> t.Optional[discord.File]: - """ - Download & return `attachment` file. - - If the download fails, the reason is logged and None will be returned. - 404 and 403 errors are only logged at debug level. - """ - log.debug(f"Attempting to download attachment: {attachment.filename}") - try: - return await attachment.to_file() - except (discord.NotFound, discord.Forbidden) as exc: - log.debug(f"Failed to download attachment: {exc}") - except Exception: - log.exception("Failed to download attachment") - - -async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: - """ - Create an embed representation of `incident` for the #incidents-archive channel. - - The name & discriminator of `actioned_by` and `outcome` will be presented in the - embed footer. Additionally, the embed is coloured based on `outcome`. - - The author of `incident` is not shown in the embed. It is assumed that this piece - of information will be relayed in other ways, e.g. webhook username. - - As mentions in embeds do not ping, we do not need to use `incident.clean_content`. - - If `incident` contains attachments, the first attachment will be downloaded and - returned alongside the embed. The embed attempts to display the attachment. - Should the download fail, we fallback on linking the `proxy_url`, which should - remain functional for some time after the original message is deleted. - """ - log.trace(f"Creating embed for {incident.id=}") - - if outcome is Signal.ACTIONED: - colour = Colours.soft_green - footer = f"Actioned by {actioned_by}" - else: - colour = Colours.soft_red - footer = f"Rejected by {actioned_by}" - - embed = discord.Embed( - description=incident.content, - timestamp=datetime.utcnow(), - colour=colour, - ) - embed.set_footer(text=footer, icon_url=actioned_by.avatar_url) - - if incident.attachments: - attachment = incident.attachments[0] # User-sent messages can only contain one attachment - file = await download_file(attachment) - - if file is not None: - embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file - else: - embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file - else: - file = None - - return embed, file - - -def is_incident(message: discord.Message) -> bool: - """True if `message` qualifies as an incident, False otherwise.""" - conditions = ( - message.channel.id == Channels.incidents, # Message sent in #incidents - not message.author.bot, # Not by a bot - not message.content.startswith("#"), # Doesn't start with a hash - not message.pinned, # And isn't header - ) - return all(conditions) - - -def own_reactions(message: discord.Message) -> t.Set[str]: - """Get the set of reactions placed on `message` by the bot itself.""" - return {str(reaction.emoji) for reaction in message.reactions if reaction.me} - - -def has_signals(message: discord.Message) -> bool: - """True if `message` already has all `Signal` reactions, False otherwise.""" - return ALL_SIGNALS.issubset(own_reactions(message)) - - -async def add_signals(incident: discord.Message) -> None: - """ - Add `Signal` member emoji to `incident` as reactions. - - If the emoji has already been placed on `incident` by the bot, it will be skipped. - """ - existing_reacts = own_reactions(incident) - - for signal_emoji in Signal: - if signal_emoji.value in existing_reacts: # This would not raise, but it is a superfluous API call - log.trace(f"Skipping emoji as it's already been placed: {signal_emoji}") - else: - log.trace(f"Adding reaction: {signal_emoji}") - await incident.add_reaction(signal_emoji.value) - - -class Incidents(Cog): - """ - Automation for the #incidents channel. - - This cog does not provide a command API, it only reacts to the following events. - - On start-up: - * Crawl #incidents and add missing `Signal` emoji where appropriate - * This is to retro-actively add the available options for messages which - were sent while the bot wasn't listening - * Pinned messages and message starting with # do not qualify as incidents - * See: `crawl_incidents` - - On message: - * Add `Signal` member emoji if message qualifies as an incident - * Ignore messages starting with # - * Use this if verbal communication is necessary - * Each such message must be deleted manually once appropriate - * See: `on_message` - - On reaction: - * Remove reaction if not permitted - * User does not have any of the roles in `ALLOWED_ROLES` - * Used emoji is not a `Signal` member - * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to - relay the incident message to #incidents-archive - * If relay successful, delete original message - * See: `on_raw_reaction_add` - - Please refer to function docstrings for implementation details. - """ - - def __init__(self, bot: Bot) -> None: - """Prepare `event_lock` and schedule `crawl_task` on start-up.""" - self.bot = bot - - self.event_lock = asyncio.Lock() - self.crawl_task = self.bot.loop.create_task(self.crawl_incidents()) - - async def crawl_incidents(self) -> None: - """ - Crawl #incidents and add missing emoji where necessary. - - This is to catch-up should an incident be reported while the bot wasn't listening. - After adding each reaction, we take a short break to avoid drowning in ratelimits. - - Once this task is scheduled, listeners that change messages should await it. - The crawl assumes that the channel history doesn't change as we go over it. - - Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. - """ - await self.bot.wait_until_guild_available() - incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) - - log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") - async for message in incidents.history(limit=CRAWL_LIMIT): - - if not is_incident(message): - log.trace(f"Skipping message {message.id}: not an incident") - continue - - if has_signals(message): - log.trace(f"Skipping message {message.id}: already has all signals") - continue - - await add_signals(message) - await asyncio.sleep(CRAWL_SLEEP) - - log.debug("Crawl task finished!") - - async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: - """ - Relay an embed representation of `incident` to the #incidents-archive channel. - - The following pieces of information are relayed: - * Incident message content (as embed description) - * Incident attachment (if image, shown in archive embed) - * Incident author name (as webhook author) - * Incident author avatar (as webhook avatar) - * Resolution signal `outcome` (as embed colour & footer) - * Moderator `actioned_by` (name & discriminator shown in footer) - - If `incident` contains an attachment, we try to add it to the archive embed. There is - no handing of extensions / file types - we simply dispatch the attachment file with the - webhook, and try to display it in the embed. Testing indicates that if the attachment - cannot be displayed (e.g. a text file), it's invisible in the embed, with no error. - - Return True if the relay finishes successfully. If anything goes wrong, meaning - not all information was relayed, return False. This signals that the original - message is not safe to be deleted, as we will lose some information. - """ - log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") - embed, attachment_file = await make_embed(incident, outcome, actioned_by) - - try: - webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) - await webhook.send( - embed=embed, - username=sub_clyde(incident.author.name), - avatar_url=incident.author.avatar_url, - file=attachment_file, - ) - except Exception: - log.exception(f"Failed to archive incident {incident.id} to #incidents-archive") - return False - else: - log.trace("Message archived successfully!") - return True - - def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: - """ - Create a task to wait `timeout` seconds for `incident` to be deleted. - - If `timeout` passes, this will raise `asyncio.TimeoutError`, signaling that we haven't - been able to confirm that the message was deleted. - """ - log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") - - def check(payload: discord.RawReactionActionEvent) -> bool: - return payload.message_id == incident.id - - coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) - return self.bot.loop.create_task(coroutine) - - async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: - """ - Process a `reaction_add` event in #incidents. - - First, we check that the reaction is a recognized `Signal` member, and that it was sent by - a permitted user (at least one role in `ALLOWED_ROLES`). If not, the reaction is removed. - - If the reaction was either `Signal.ACTIONED` or `Signal.NOT_ACTIONED`, we attempt to relay - the report to #incidents-archive. If successful, the original message is deleted. - - We do not release `event_lock` until we receive the corresponding `message_delete` event. - This ensures that if there is a racing event awaiting the lock, it will fail to find the - message, and will abort. There is a `timeout` to ensure that this doesn't hold the lock - forever should something go wrong. - """ - members_roles: t.Set[int] = {role.id for role in member.roles} - if not members_roles & ALLOWED_ROLES: # Intersection is truthy on at least 1 common element - log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") - await incident.remove_reaction(reaction, member) - return - - try: - signal = Signal(reaction) - except ValueError: - log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") - await incident.remove_reaction(reaction, member) - return - - log.trace(f"Received signal: {signal}") - - if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED): - log.debug("Reaction was valid, but no action is currently defined for it") - return - - relay_successful = await self.archive(incident, signal, actioned_by=member) - if not relay_successful: - log.trace("Original message will not be deleted as we failed to relay it to the archive") - return - - timeout = 5 # Seconds - confirmation_task = self.make_confirmation_task(incident, timeout) - - log.trace("Deleting original message") - await incident.delete() - - log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") - try: - await confirmation_task - except asyncio.TimeoutError: - log.warning(f"Did not receive incident deletion confirmation within {timeout} seconds!") - else: - log.trace("Deletion was confirmed") - - async def resolve_message(self, message_id: int) -> t.Optional[discord.Message]: - """ - Get `discord.Message` for `message_id` from cache, or API. - - We first look into the local cache to see if the message is present. - - If not, we try to fetch the message from the API. This is necessary for messages - which were sent before the bot's current session. - - In an edge-case, it is also possible that the message was already deleted, and - the API will respond with a 404. In such a case, None will be returned. - This signals that the event for `message_id` should be ignored. - """ - await self.bot.wait_until_guild_available() # First make sure that the cache is ready - log.trace(f"Resolving message for: {message_id=}") - message: t.Optional[discord.Message] = self.bot._connection._get_message(message_id) - - if message is not None: - log.trace("Message was found in cache") - return message - - log.trace("Message not found, attempting to fetch") - try: - message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) - except discord.NotFound: - log.trace("Message doesn't exist, it was likely already relayed") - except Exception: - log.exception(f"Failed to fetch message {message_id}!") - else: - log.trace("Message fetched successfully!") - return message - - @Cog.listener() - async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: - """ - Pre-process `payload` and pass it to `process_event` if appropriate. - - We abort instantly if `payload` doesn't relate to a message sent in #incidents, - or if it was sent by a bot. - - If `payload` relates to a message in #incidents, we first ensure that `crawl_task` has - finished, to make sure we don't mutate channel state as we're crawling it. - - Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. - - Once we have the lock, the `discord.Message` object for this event must be resolved. - If the lock was previously held by an event which successfully relayed the incident, - this will fail and we abort the current event. - - Finally, with both the lock and the `discord.Message` instance in our hands, we delegate - to `process_event` to handle the event. - - The justification for using a raw listener is the need to receive events for messages - which were not cached in the current session. As a result, a certain amount of - complexity is introduced, but at the moment this doesn't appear to be avoidable. - """ - if payload.channel_id != Channels.incidents or payload.member.bot: - return - - log.trace(f"Received reaction add event in #incidents, waiting for crawler: {self.crawl_task.done()=}") - await self.crawl_task - - log.trace(f"Acquiring event lock: {self.event_lock.locked()=}") - async with self.event_lock: - message = await self.resolve_message(payload.message_id) - - if message is None: - log.debug("Listener will abort as related message does not exist!") - return - - if not is_incident(message): - log.debug("Ignoring event for a non-incident message") - return - - await self.process_event(str(payload.emoji), message, payload.member) - log.trace("Releasing event lock") - - @Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Pass `message` to `add_signals` if and only if it satisfies `is_incident`.""" - if is_incident(message): - await add_signals(message) - - -def setup(bot: Bot) -> None: - """Load the Incidents cog.""" - bot.add_cog(Incidents(bot)) diff --git a/bot/cogs/moderation/infraction/__init__.py b/bot/cogs/moderation/infraction/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/moderation/infraction/_scheduler.py b/bot/cogs/moderation/infraction/_scheduler.py deleted file mode 100644 index 33944a8db..000000000 --- a/bot/cogs/moderation/infraction/_scheduler.py +++ /dev/null @@ -1,463 +0,0 @@ -import logging -import textwrap -import typing as t -from abc import abstractmethod -from datetime import datetime -from gettext import ngettext - -import dateutil.parser -import discord -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Colours, STAFF_CHANNELS -from bot.utils import time -from bot.utils.scheduling import Scheduler -from . import _utils -from ._utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class InfractionScheduler: - """Handles the application, pardoning, and expiration of infractions.""" - - def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - @property - def mod_log(self) -> ModLog: - """Get the currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: - """Schedule expiration for previous infractions.""" - await self.bot.wait_until_guild_available() - - log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") - - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={'active': 'true'} - ) - for infraction in infractions: - if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_expiration(infraction) - - async def reapply_infraction( - self, - infraction: _utils.Infraction, - apply_coro: t.Optional[t.Awaitable] - ) -> None: - """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" - # Calculate the time remaining, in seconds, for the mute. - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - delta = (expiry - datetime.utcnow()).total_seconds() - - # Mark as inactive if less than a minute remains. - if delta < 60: - log.info( - "Infraction will be deactivated instead of re-applied " - "because less than 1 minute remains." - ) - await self.deactivate_infraction(infraction) - return - - # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") - - async def apply_infraction( - self, - ctx: Context, - infraction: _utils.Infraction, - user: UserSnowflake, - action_coro: t.Optional[t.Awaitable] = None - ) -> None: - """Apply an infraction to the user, log the infraction, and optionally notify the user.""" - infr_type = infraction["type"] - icon = _utils.INFRACTION_ICONS[infr_type][0] - reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) - id_ = infraction['id'] - - log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - - # Default values for the confirmation message and mod log. - confirm_msg = ":ok_hand: applied" - - # Specifying an expiry for a note or warning makes no sense. - if infr_type in ("note", "warning"): - expiry_msg = "" - else: - expiry_msg = f" until {expiry}" if expiry else " permanently" - - dm_result = "" - dm_log_text = "" - expiry_log_text = f"\nExpires: {expiry}" if expiry else "" - log_title = "applied" - log_content = None - failed = False - - # DM the user about the infraction if it's not a shadow/hidden infraction. - # This needs to happen before we apply the infraction, as the bot cannot - # send DMs to user that it doesn't share a guild with. If we were to - # apply kick/ban infractions first, this would mean that we'd make it - # impossible for us to deliver a DM. See python-discord/bot#982. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" - - end_msg = "" - if infraction["actor"] == self.bot.user.id: - log.trace( - f"Infraction #{id_} actor is bot; including the reason in the confirmation message." - ) - if reason: - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" - elif ctx.channel.id not in STAFF_CHANNELS: - log.trace( - f"Infraction #{id_} context is not in a staff channel; omitting infraction count." - ) - else: - log.trace(f"Fetching total infraction count for {user}.") - - infractions = await self.bot.api_client.get( - "bot/infractions", - params={"user__id": str(user.id)} - ) - total = len(infractions) - end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" - - # Execute the necessary actions to apply the infraction on Discord. - if action_coro: - log.trace(f"Awaiting the infraction #{id_} application action coroutine.") - try: - await action_coro - if expiry: - # Schedule the expiration of the infraction. - self.schedule_expiration(infraction) - except discord.HTTPException as e: - # Accordingly display that applying the infraction failed. - confirm_msg = ":x: failed to apply" - expiry_msg = "" - log_content = ctx.author.mention - log_title = "failed to apply" - - log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): - log.warning(f"{log_msg}: bot lacks permissions.") - else: - log.exception(log_msg) - failed = True - - if failed: - log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") - try: - await self.bot.api_client.delete(f"bot/infractions/{id_}") - except ResponseCodeError as e: - confirm_msg += " and failed to delete" - log_title += " and failed to delete" - log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") - infr_message = "" - else: - infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" - - # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") - - # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=icon, - colour=Colours.soft_red, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} - Reason: {reason} - """), - content=log_content, - footer=f"ID {infraction['id']}" - ) - - log.info(f"Applied {infr_type} infraction #{id_} to {user}.") - - async def pardon_infraction( - self, - ctx: Context, - infr_type: str, - user: UserSnowflake, - send_msg: bool = True - ) -> None: - """ - Prematurely end an infraction for a user and log the action in the mod log. - - If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. - """ - log.trace(f"Pardoning {infr_type} infraction for {user}.") - - # Check the current active infraction - log.trace(f"Fetching active {infr_type} infractions for {user}.") - response = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': user.id - } - ) - - if not response: - log.debug(f"No active {infr_type} infraction found for {user}.") - await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") - return - - # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["Actor"] = str(ctx.message.author) - log_content = None - id_ = response[0]['id'] - footer = f"ID: {id_}" - - # If multiple active infractions were found, mark them as inactive in the database - # and cancel their expiration tasks. - if len(response) > 1: - log.info( - f"Found more than one active {infr_type} infraction for user {user.id}; " - "deactivating the extra active infractions too." - ) - - footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" - - log_note = f"Found multiple **active** {infr_type} infractions in the database." - if "Note" in log_text: - log_text["Note"] = f" {log_note}" - else: - log_text["Note"] = log_note - - # deactivate_infraction() is not called again because: - # 1. Discord cannot store multiple active bans or assign multiples of the same role - # 2. It would send a pardon DM for each active infraction, which is redundant - for infraction in response[1:]: - id_ = infraction['id'] - try: - # Mark infraction as inactive in the database. - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") - # This is simpler and cleaner than trying to concatenate all the errors. - log_text["Failure"] = "See bot's logs for details." - - # Cancel pending expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Accordingly display whether the user was successfully notified via DM. - dm_emoji = "" - if log_text.get("DM") == "Sent": - dm_emoji = ":incoming_envelope: " - elif "DM" in log_text: - dm_emoji = f"{constants.Emojis.failmail} " - - # Accordingly display whether the pardon failed. - if "Failure" in log_text: - confirm_msg = ":x: failed to pardon" - log_title = "pardon failed" - log_content = ctx.author.mention - - log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") - else: - confirm_msg = ":ok_hand: pardoned" - log_title = "pardoned" - - log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") - - # Send a confirmation message to the invoking context. - if send_msg: - log.trace(f"Sending infraction #{id_} pardon confirmation message.") - await ctx.send( - f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " - f"{log_text.get('Failure', '')}" - ) - - # Move reason to end of entry to avoid cutting out some keys - log_text["Reason"] = log_text.pop("Reason") - - # Send a log message to the mod log. - await self.mod_log.send_log_message( - icon_url=_utils.INFRACTION_ICONS[infr_type][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=footer, - content=log_content, - ) - - async def deactivate_infraction( - self, - infraction: _utils.Infraction, - send_log: bool = True - ) -> t.Dict[str, str]: - """ - Deactivate an active infraction and return a dictionary of lines to send in a mod log. - - The infraction is removed from Discord, marked as inactive in the database, and has its - expiration task cancelled. If `send_log` is True, a mod log is sent for the - deactivation of the infraction. - - Infractions of unsupported types will raise a ValueError. - """ - guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderators) - user_id = infraction["user"] - actor = infraction["actor"] - type_ = infraction["type"] - id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] - - log.info(f"Marking infraction #{id_} as inactive (expired).") - - expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) - - log_content = None - log_text = { - "Member": f"<@{user_id}>", - "Actor": str(self.bot.get_user(actor) or actor), - "Reason": infraction["reason"], - "Created": created, - } - - try: - log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) - - if returned_log is not None: - log_text = {**log_text, **returned_log} # Merge the logs together - else: - raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" - ) - except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") - log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" - log_content = mod_role.mention - except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." - log_content = mod_role.mention - - # Check if the user is currently being watched by Big Brother. - try: - log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") - - active_watch = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "watch", - "user__id": user_id - } - ) - - log_text["Watching"] = "Yes" if active_watch else "No" - except ResponseCodeError: - log.exception(f"Failed to fetch watch status for user {user_id}") - log_text["Watching"] = "Unknown - failed to fetch watch status." - - try: - # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{id_} as inactive in the database.") - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_line = f"API request failed with code {e.status}." - log_content = mod_role.mention - - # Append to an existing failure message if possible - if "Failure" in log_text: - log_text["Failure"] += f" {log_line}" - else: - log_text["Failure"] = log_line - - # Cancel the expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Send a log message to the mod log. - if send_log: - log_title = "expiration failed" if "Failure" in log_text else "expired" - - user = self.bot.get_user(user_id) - avatar = user.avatar_url_as(static_format="png") if user else None - - # Move reason to end so when reason is too long, this is not gonna cut out required items. - log_text["Reason"] = log_text.pop("Reason") - - log.trace(f"Sending deactivation mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=_utils.INFRACTION_ICONS[type_][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {type_}", - thumbnail=avatar, - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {id_}", - content=log_content, - ) - - return log_text - - @abstractmethod - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - raise NotImplementedError - - def schedule_expiration(self, infraction: _utils.Infraction) -> None: - """ - Marks an infraction expired after the delay from time of scheduling to time of expiration. - - At the time of expiration, the infraction is marked as inactive on the website and the - expiration task is cancelled. - """ - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/_utils.py b/bot/cogs/moderation/infraction/_utils.py deleted file mode 100644 index fb55287b6..000000000 --- a/bot/cogs/moderation/infraction/_utils.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext.commands import Context - -from bot.api import ResponseCodeError -from bot.constants import Colours, Icons - -log = logging.getLogger(__name__) - -# apply icon, pardon icon -INFRACTION_ICONS = { - "ban": (Icons.user_ban, Icons.user_unban), - "kick": (Icons.sign_out, None), - "mute": (Icons.user_mute, Icons.user_unmute), - "note": (Icons.user_warn, None), - "superstar": (Icons.superstarify, Icons.unsuperstarify), - "warning": (Icons.user_warn, None), -} -RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") - -# Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object] -Infraction = t.Dict[str, t.Union[str, int, bool]] - - -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: - """ - Create a new user in the database. - - Used when an infraction needs to be applied on a user absent in the guild. - """ - log.trace(f"Attempting to add user {user.id} to the database.") - - if not isinstance(user, (discord.Member, discord.User)): - log.debug("The user being added to the DB is not a Member or User object.") - - payload = { - 'discriminator': int(getattr(user, 'discriminator', 0)), - 'id': user.id, - 'in_guild': False, - 'name': getattr(user, 'name', 'Name unknown'), - 'roles': [] - } - - try: - response = await ctx.bot.api_client.post('bot/users', json=payload) - log.info(f"User {user.id} added to the DB.") - return response - except ResponseCodeError as e: - log.error(f"Failed to add user {user.id} to the DB. {e}") - await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") - - -async def post_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - reason: str, - expires_at: datetime = None, - hidden: bool = False, - active: bool = True -) -> t.Optional[dict]: - """Posts an infraction to the API.""" - log.trace(f"Posting {infr_type} infraction for {user} to the API.") - - payload = { - "actor": ctx.message.author.id, - "hidden": hidden, - "reason": reason, - "type": infr_type, - "user": user.id, - "active": active - } - if expires_at: - payload['expires_at'] = expires_at.isoformat() - - # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. - for should_post_user in (True, False): - try: - response = await ctx.bot.api_client.post('bot/infractions', json=payload) - return response - except ResponseCodeError as e: - if e.status == 400 and 'user' in e.response_json: - # Only one attempt to add the user to the database, not two: - if not should_post_user or await post_user(ctx, user) is None: - return - else: - log.exception(f"Unexpected error while adding an infraction for {user}:") - await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") - return - - -async def get_active_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - send_msg: bool = True -) -> t.Optional[dict]: - """ - Retrieves an active infraction of the given type for the user. - - If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, - then a message for the moderator will be sent to the context channel letting them know. - Otherwise, no message will be sent. - """ - log.trace(f"Checking if {user} has active infractions of type {infr_type}.") - - active_infractions = await ctx.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': str(user.id) - } - ) - if active_infractions: - # Checks to see if the moderator should be told there is an active infraction - if send_msg: - log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) - return active_infractions[0] - else: - log.trace(f"{user} does not have active infractions of type {infr_type}.") - - -async def notify_infraction( - user: UserObject, - infr_type: str, - expires_at: t.Optional[str] = None, - reason: t.Optional[str] = None, - icon_url: str = Icons.token_removed -) -> bool: - """DM a user about their new infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) - - embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), - colour=Colours.soft_red - ) - - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" - embed.url = RULES_URL - - if infr_type in APPEALABLE_INFRACTIONS: - embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" - ) - - return await send_private_embed(user, embed) - - -async def notify_pardon( - user: UserObject, - title: str, - content: str, - icon_url: str = Icons.user_verified -) -> bool: - """DM a user about their pardoned infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their pardoned infraction.") - - embed = discord.Embed( - description=content, - colour=Colours.soft_green - ) - - embed.set_author(name=title, icon_url=icon_url) - - return await send_private_embed(user, embed) - - -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: - """ - A helper method for sending an embed to a user's DMs. - - Returns a boolean indicator of DM success. - """ - try: - await user.send(embed=embed) - return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): - log.debug( - f"Infraction-related information could not be sent to user {user} ({user.id}). " - "The user either could not be retrieved or probably disabled their DMs." - ) - return False diff --git a/bot/cogs/moderation/infraction/infractions.py b/bot/cogs/moderation/infraction/infractions.py deleted file mode 100644 index cb459b447..000000000 --- a/bot/cogs/moderation/infraction/infractions.py +++ /dev/null @@ -1,375 +0,0 @@ -import logging -import textwrap -import typing as t - -import discord -from discord import Member -from discord.ext import commands -from discord.ext.commands import Context, command - -from bot import constants -from bot.bot import Bot -from bot.constants import Event -from bot.converters import Expiry, FetchedMember -from bot.decorators import respect_role_hierarchy -from bot.utils.checks import with_role_check -from . import _utils -from ._scheduler import InfractionScheduler -from ._utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class Infractions(InfractionScheduler, commands.Cog): - """Apply and pardon infractions on users for moderation purposes.""" - - category = "Moderation" - category_description = "Server moderation tools." - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) - - self.category = "Moderation" - self._muted_role = discord.Object(constants.Roles.muted) - - @commands.Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active mute infractions for returning members.""" - active_mutes = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "mute", - "user__id": member.id - } - ) - - if active_mutes: - reason = f"Re-applying active mute: {active_mutes[0]['id']}" - action = member.add_roles(self._muted_role, reason=reason) - - await self.reapply_infraction(active_mutes[0], action) - - # region: Permanent infractions - - @command() - async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Warn a user for the given reason.""" - infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command() - async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason.""" - await self.apply_kick(ctx, user, reason) - - @command() - async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason and stop watching them with Big Brother.""" - await self.apply_ban(ctx, user, reason) - - # endregion - # region: Temporary infractions - - @command(aliases=["mute"]) - async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: - """ - Temporarily mute a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration) - - @command() - async def tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration) - - # endregion - # region: Permanent shadow infractions - - @command(hidden=True) - async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Create a private note for a user with the given reason without notifying the user.""" - infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command(hidden=True, aliases=['shadowkick', 'skick']) - async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason without notifying the user.""" - await self.apply_kick(ctx, user, reason, hidden=True) - - @command(hidden=True, aliases=['shadowban', 'sban']) - async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason without notifying the user.""" - await self.apply_ban(ctx, user, reason, hidden=True) - - # endregion - # region: Temporary shadow infractions - - @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) - async def shadow_tempmute( - self, ctx: Context, - user: Member, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily mute a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) - - @command(hidden=True, aliases=["shadowtempban, stempban"]) - async def shadow_tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) - - # endregion - # region: Remove infractions (un- commands) - - @command() - async def unmute(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active mute infraction for the user.""" - await self.pardon_infraction(ctx, "mute", user) - - @command() - async def unban(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active ban infraction for the user.""" - await self.pardon_infraction(ctx, "ban", user) - - # endregion - # region: Base apply functions - - async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await _utils.get_active_infraction(ctx, user, "mute"): - return - - infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_update, user.id) - - async def action() -> None: - await user.add_roles(self._muted_role, reason=reason) - - log.trace(f"Attempting to kick {user} from voice because they've been muted.") - await user.move_to(None, reason=reason) - - await self.apply_infraction(ctx, infraction, user, action()) - - @respect_role_hierarchy() - async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a kick infraction with kwargs passed to `post_infraction`.""" - infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = user.kick(reason=reason) - await self.apply_infraction(ctx, infraction, user, action) - - @respect_role_hierarchy() - async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: - """ - Apply a ban infraction with kwargs passed to `post_infraction`. - - Will also remove the banned user from the Big Brother watch list if applicable. - """ - # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active - is_temporary = kwargs.get("expires_at") is not None - active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary) - - if active_infraction: - if is_temporary: - log.trace("Tempban ignored as it cannot overwrite an active ban.") - return - - if active_infraction.get('expires_at') is None: - log.trace("Permaban already exists, notify.") - await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") - return - - log.trace("Old tempban is being replaced by new permaban.") - await self.pardon_infraction(ctx, "ban", user, is_temporary) - - infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = ctx.guild.ban(user, reason=reason, delete_message_days=0) - await self.apply_infraction(ctx, infraction, user, action) - - if infraction.get('expires_at') is not None: - log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") - return - - bb_cog = self.bot.get_cog("Big Brother") - if not bb_cog: - log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") - return - - log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") - - bb_reason = "User has been permanently banned from the server. Automatically removed." - await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) - - # endregion - # region: Base pardon functions - - async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's muted role, DM them a notification, and return a log dict.""" - user = guild.get_member(user_id) - log_text = {} - - if user: - # Remove the muted role. - self.mod_log.ignore(Event.member_update, user.id) - await user.remove_roles(self._muted_role, reason=reason) - - # DM the user about the expiration. - notified = await _utils.notify_pardon( - user=user, - title="You have been unmuted", - content="You may now send messages in the server.", - icon_url=_utils.INFRACTION_ICONS["mute"][1] - ) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["DM"] = "Sent" if notified else "**Failed**" - else: - log.info(f"Failed to unmute user {user_id}: user not found") - log_text["Failure"] = "User was not found in the guild." - - return log_text - - async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's ban on the Discord guild and return a log dict.""" - user = discord.Object(user_id) - log_text = {} - - self.mod_log.ignore(Event.member_unban, user_id) - - try: - await guild.unban(user, reason=reason) - except discord.NotFound: - log.info(f"Failed to unban user {user_id}: no active ban found on Discord") - log_text["Note"] = "No active ban found on Discord." - - return log_text - - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - guild = self.bot.get_guild(constants.Guild.id) - user_id = infraction["user"] - reason = f"Infraction #{infraction['id']} expired or was pardoned." - - if infraction["type"] == "mute": - return await self.pardon_mute(user_id, guild, reason) - elif infraction["type"] == "ban": - return await self.pardon_ban(user_id, guild, reason) - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters or discord.Member in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the Infractions cog.""" - bot.add_cog(Infractions(bot)) diff --git a/bot/cogs/moderation/infraction/management.py b/bot/cogs/moderation/infraction/management.py deleted file mode 100644 index 9e7ae8113..000000000 --- a/bot/cogs/moderation/infraction/management.py +++ /dev/null @@ -1,310 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext import commands -from discord.ext.commands import Context - -from bot import constants -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user -from bot.pagination import LinePaginator -from bot.utils import time -from bot.utils.checks import in_whitelist_check, with_role_check -from . import _utils -from .infractions import Infractions - -log = logging.getLogger(__name__) - - -class ModManagement(commands.Cog): - """Management of infractions.""" - - category = "Moderation" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @property - def infractions_cog(self) -> Infractions: - """Get currently loaded Infractions cog instance.""" - return self.bot.get_cog("Infractions") - - # region: Edit infraction commands - - @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) - async def infraction_group(self, ctx: Context) -> None: - """Infraction manipulation commands.""" - await ctx.send_help(ctx.command) - - @infraction_group.command(name='edit') - async def infraction_edit( - self, - ctx: Context, - infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 - duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 - *, - reason: str = None - ) -> None: - """ - Edit the duration and/or the reason of an infraction. - - Durations are relative to the time of updating and should be appended with a unit of time. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction - authored by the command invoker should be edited. - - Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 - timestamp can be provided for the duration. - """ - if duration is None and reason is None: - # Unlike UserInputError, the error handler will show a specified message for BadArgument - raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") - - # Retrieve the previous infraction for its information. - if isinstance(infraction_id, str): - params = { - "actor__id": ctx.author.id, - "ordering": "-inserted_at" - } - infractions = await self.bot.api_client.get("bot/infractions", params=params) - - if infractions: - old_infraction = infractions[0] - infraction_id = old_infraction["id"] - else: - await ctx.send( - ":x: Couldn't find most recent infraction; you have never given an infraction." - ) - return - else: - old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") - - request_data = {} - confirm_messages = [] - log_text = "" - - if duration is not None and not old_infraction['active']: - if reason is None: - await ctx.send(":x: Cannot edit the expiration of an expired infraction.") - return - confirm_messages.append("expiry unchanged (infraction already expired)") - elif isinstance(duration, str): - request_data['expires_at'] = None - confirm_messages.append("marked as permanent") - elif duration is not None: - request_data['expires_at'] = duration.isoformat() - expiry = time.format_infraction_with_duration(request_data['expires_at']) - confirm_messages.append(f"set to expire on {expiry}") - else: - confirm_messages.append("expiry unchanged") - - if reason: - request_data['reason'] = reason - confirm_messages.append("set a new reason") - log_text += f""" - Previous reason: {old_infraction['reason']} - New reason: {reason} - """.rstrip() - else: - confirm_messages.append("reason unchanged") - - # Update the infraction - new_infraction = await self.bot.api_client.patch( - f'bot/infractions/{infraction_id}', - json=request_data, - ) - - # Re-schedule infraction if the expiration has been updated - if 'expires_at' in request_data: - # A scheduled task should only exist if the old infraction wasn't permanent - if old_infraction['expires_at']: - self.infractions_cog.scheduler.cancel(new_infraction['id']) - - # If the infraction was not marked as permanent, schedule a new expiration task - if request_data['expires_at']: - self.infractions_cog.schedule_expiration(new_infraction) - - log_text += f""" - Previous expiry: {old_infraction['expires_at'] or "Permanent"} - New expiry: {new_infraction['expires_at'] or "Permanent"} - """.rstrip() - - changes = ' & '.join(confirm_messages) - await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") - - # Get information about the infraction's user - user_id = new_infraction['user'] - user = ctx.guild.get_member(user_id) - - if user: - user_text = f"{user.mention} (`{user.id}`)" - thumbnail = user.avatar_url_as(static_format="png") - else: - user_text = f"`{user_id}`" - thumbnail = None - - # The infraction's actor - actor_id = new_infraction['actor'] - actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" - - await self.mod_log.send_log_message( - icon_url=constants.Icons.pencil, - colour=discord.Colour.blurple(), - title="Infraction edited", - thumbnail=thumbnail, - text=textwrap.dedent(f""" - Member: {user_text} - Actor: {actor} - Edited by: {ctx.message.author}{log_text} - """) - ) - - # endregion - # region: Search infractions - - @infraction_group.group(name="search", invoke_without_command=True) - async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: - """Searches for infractions in the database.""" - if isinstance(query, discord.User): - await ctx.invoke(self.search_user, query) - else: - await ctx.invoke(self.search_reason, query) - - @infraction_search_group.command(name="user", aliases=("member", "id")) - async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: - """Search for infractions by member.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'user__id': str(user.id)} - ) - embed = discord.Embed( - title=f"Infractions for {user} ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) - async def search_reason(self, ctx: Context, reason: str) -> None: - """Search for infractions by their reason. Use Re2 for matching.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'search': reason} - ) - embed = discord.Embed( - title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - # endregion - # region: Utility functions - - async def send_infraction_list( - self, - ctx: Context, - embed: discord.Embed, - infractions: t.Iterable[_utils.Infraction] - ) -> None: - """Send a paginated embed of infractions for the specified user.""" - if not infractions: - await ctx.send(":warning: No infractions could be found for that query.") - return - - lines = tuple( - self.infraction_to_string(infraction) - for infraction in infractions - ) - - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - def infraction_to_string(self, infraction: _utils.Infraction) -> str: - """Convert the infraction object to a string representation.""" - actor_id = infraction["actor"] - guild = self.bot.get_guild(constants.Guild.id) - actor = guild.get_member(actor_id) - active = infraction["active"] - user_id = infraction["user"] - hidden = infraction["hidden"] - created = time.format_infraction(infraction["inserted_at"]) - - if active: - remaining = time.until_expiration(infraction["expires_at"]) or "Expired" - else: - remaining = "Inactive" - - if infraction["expires_at"] is None: - expires = "*Permanent*" - else: - date_from = datetime.strptime(created, time.INFRACTION_FORMAT) - expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) - - lines = textwrap.dedent(f""" - {"**===============**" if active else "==============="} - Status: {"__**Active**__" if active else "Inactive"} - User: {self.bot.get_user(user_id)} (`{user_id}`) - Type: **{infraction["type"]}** - Shadow: {hidden} - Created: {created} - Expires: {expires} - Remaining: {remaining} - Actor: {actor.mention if actor else actor_id} - ID: `{infraction["id"]}` - Reason: {infraction["reason"] or "*None*"} - {"**===============**" if active else "==============="} - """) - - return lines.strip() - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators inside moderator channels to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=constants.MODERATION_CHANNELS, - categories=[constants.Categories.modmail], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the ModManagement cog.""" - bot.add_cog(ModManagement(bot)) diff --git a/bot/cogs/moderation/infraction/superstarify.py b/bot/cogs/moderation/infraction/superstarify.py deleted file mode 100644 index 7dc5b4691..000000000 --- a/bot/cogs/moderation/infraction/superstarify.py +++ /dev/null @@ -1,244 +0,0 @@ -import json -import logging -import random -import textwrap -import typing as t -from pathlib import Path - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.converters import Expiry -from bot.utils.checks import with_role_check -from bot.utils.time import format_infraction -from . import _utils -from ._scheduler import InfractionScheduler - -log = logging.getLogger(__name__) -NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" - -with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: - STAR_NAMES = json.load(stars_file) - - -class Superstarify(InfractionScheduler, Cog): - """A set of commands to moderate terrible nicknames.""" - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"superstar"}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Revert nickname edits if the user has an active superstarify infraction.""" - if before.display_name == after.display_name: - return # User didn't change their nickname. Abort! - - log.trace( - f"{before} ({before.display_name}) is trying to change their nickname to " - f"{after.display_name}. Checking if the user is in superstar-prison..." - ) - - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": str(before.id) - } - ) - - if not active_superstarifies: - log.trace(f"{before} has no active superstar infractions.") - return - - infraction = active_superstarifies[0] - forced_nick = self.get_nick(infraction["id"], before.id) - if after.display_name == forced_nick: - return # Nick change was triggered by this event. Ignore. - - log.info( - f"{after.display_name} ({after.id}) tried to escape superstar prison. " - f"Changing the nick back to {before.display_name}." - ) - await after.edit( - nick=forced_nick, - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - notified = await _utils.notify_infraction( - user=after, - infr_type="Superstarify", - expires_at=format_infraction(infraction["expires_at"]), - reason=( - "You have tried to change your nickname on the **Python Discord** server " - f"from **{before.display_name}** to **{after.display_name}**, but as you " - "are currently in superstar-prison, you do not have permission to do so." - ), - icon_url=_utils.INFRACTION_ICONS["superstar"][0] - ) - - if not notified: - log.info("Failed to DM user about why they cannot change their nickname.") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active superstar infractions for returning members.""" - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": member.id - } - ) - - if active_superstarifies: - infraction = active_superstarifies[0] - action = member.edit( - nick=self.get_nick(infraction["id"], member.id), - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - await self.reapply_infraction(infraction, action) - - @command(name="superstarify", aliases=("force_nick", "star")) - async def superstarify( - self, - ctx: Context, - member: Member, - duration: Expiry, - *, - reason: str = None, - ) -> None: - """ - Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - - An optional reason can be provided. If no reason is given, the original name will be shown - in a generated reason. - """ - if await _utils.get_active_infraction(ctx, member, "superstar"): - return - - # Post the infraction to the API - reason = reason or f"old nick: {member.display_name}" - infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) - id_ = infraction["id"] - - old_nick = member.display_name - forced_nick = self.get_nick(id_, member.id) - expiry_str = format_infraction(infraction["expires_at"]) - - # Apply the infraction and schedule the expiration task. - log.debug(f"Changing nickname of {member} to {forced_nick}.") - self.mod_log.ignore(constants.Event.member_update, member.id) - await member.edit(nick=forced_nick, reason=reason) - self.schedule_expiration(infraction) - - # Send a DM to the user to notify them of their new infraction. - await _utils.notify_infraction( - user=member, - infr_type="Superstarify", - expires_at=expiry_str, - icon_url=_utils.INFRACTION_ICONS["superstar"][0], - reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." - ) - - # Send an embed with the infraction information to the invoking context. - log.trace(f"Sending superstar #{id_} embed.") - embed = Embed( - title="Congratulations!", - colour=constants.Colours.soft_orange, - description=( - f"Your previous nickname, **{old_nick}**, " - f"was so bad that we have decided to change it. " - f"Your new nickname will be **{forced_nick}**.\n\n" - f"You will be unable to change your nickname until **{expiry_str}**.\n\n" - "If you're confused by this, please read our " - f"[official nickname policy]({NICKNAME_POLICY_URL})." - ) - ) - await ctx.send(embed=embed) - - # Log to the mod log channel. - log.trace(f"Sending apply mod log for superstar #{id_}.") - await self.mod_log.send_log_message( - icon_url=_utils.INFRACTION_ICONS["superstar"][0], - colour=Colour.gold(), - title="Member achieved superstardom", - thumbnail=member.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {member.mention} (`{member.id}`) - Actor: {ctx.message.author} - Expires: {expiry_str} - Old nickname: `{old_nick}` - New nickname: `{forced_nick}` - Reason: {reason} - """), - footer=f"ID {id_}" - ) - - @command(name="unsuperstarify", aliases=("release_nick", "unstar")) - async def unsuperstarify(self, ctx: Context, member: Member) -> None: - """Remove the superstarify infraction and allow the user to change their nickname.""" - await self.pardon_infraction(ctx, "superstar", member) - - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """Pardon a superstar infraction and return a log dict.""" - if infraction["type"] != "superstar": - return - - guild = self.bot.get_guild(constants.Guild.id) - user = guild.get_member(infraction["user"]) - - # Don't bother sending a notification if the user left the guild. - if not user: - log.debug( - "User left the guild and therefore won't be notified about superstar " - f"{infraction['id']} pardon." - ) - return {} - - # DM the user about the expiration. - notified = await _utils.notify_pardon( - user=user, - title="You are no longer superstarified", - content="You may now change your nickname on the server.", - icon_url=_utils.INFRACTION_ICONS["superstar"][1] - ) - - return { - "Member": f"{user.mention}(`{user.id}`)", - "DM": "Sent" if notified else "**Failed**" - } - - @staticmethod - def get_nick(infraction_id: int, member_id: int) -> str: - """Randomly select a nickname from the Superstarify nickname list.""" - log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") - - rng = random.Random(str(infraction_id) + str(member_id)) - return rng.choice(STAR_NAMES) - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the Superstarify cog.""" - bot.add_cog(Superstarify(bot)) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py deleted file mode 100644 index c86f04b9d..000000000 --- a/bot/cogs/moderation/modlog.py +++ /dev/null @@ -1,837 +0,0 @@ -import asyncio -import difflib -import itertools -import logging -import typing as t -from datetime import datetime -from itertools import zip_longest - -import discord -from dateutil.relativedelta import relativedelta -from deepdiff import DeepDiff -from discord import Colour -from discord.abc import GuildChannel -from discord.ext.commands import Cog, Context -from discord.utils import escape_markdown - -from bot.bot import Bot -from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] - -CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) -CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") -ROLE_CHANGES_UNSUPPORTED = ("colour", "permissions") - -VOICE_STATE_ATTRIBUTES = { - "channel.name": "Channel", - "self_stream": "Streaming", - "self_video": "Broadcasting", -} - - -class ModLog(Cog, name="ModLog"): - """Logging for server events and staff actions.""" - - def __init__(self, bot: Bot): - self.bot = bot - self._ignored = {event: [] for event in Event} - - self._cached_deletes = [] - self._cached_edits = [] - - async def upload_log( - self, - messages: t.Iterable[discord.Message], - actor_id: int, - attachments: t.Iterable[t.List[str]] = None - ) -> str: - """Upload message logs to the database and return a URL to a page for viewing the logs.""" - if attachments is None: - attachments = [] - - response = await self.bot.api_client.post( - 'bot/deleted-messages', - json={ - 'actor': actor_id, - 'creation': datetime.utcnow().isoformat(), - 'deletedmessage_set': [ - { - 'id': message.id, - 'author': message.author.id, - 'channel_id': message.channel.id, - 'content': message.content, - 'embeds': [embed.to_dict() for embed in message.embeds], - 'attachments': attachment, - } - for message, attachment in zip_longest(messages, attachments, fillvalue=[]) - ] - } - ) - - return f"{URLs.site_logs_view}/{response['id']}" - - def ignore(self, event: Event, *items: int) -> None: - """Add event to ignored events to suppress log emission.""" - for item in items: - if item not in self._ignored[event]: - self._ignored[event].append(item) - - async def send_log_message( - self, - icon_url: t.Optional[str], - colour: t.Union[discord.Colour, int], - title: t.Optional[str], - text: str, - thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, - channel_id: int = Channels.mod_log, - ping_everyone: bool = False, - files: t.Optional[t.List[discord.File]] = None, - content: t.Optional[str] = None, - additional_embeds: t.Optional[t.List[discord.Embed]] = None, - additional_embeds_msg: t.Optional[str] = None, - timestamp_override: t.Optional[datetime] = None, - footer: t.Optional[str] = None, - ) -> Context: - """Generate log embed and send to logging channel.""" - # Truncate string directly here to avoid removing newlines - embed = discord.Embed( - description=text[:2045] + "..." if len(text) > 2048 else text - ) - - if title and icon_url: - embed.set_author(name=title, icon_url=icon_url) - - embed.colour = colour - embed.timestamp = timestamp_override or datetime.utcnow() - - if footer: - embed.set_footer(text=footer) - - if thumbnail: - embed.set_thumbnail(url=thumbnail) - - if ping_everyone: - if content: - content = f"@everyone\n{content}" - else: - content = "@everyone" - - channel = self.bot.get_channel(channel_id) - log_message = await channel.send( - content=content, - embed=embed, - files=files, - allowed_mentions=discord.AllowedMentions(everyone=True) - ) - - if additional_embeds: - if additional_embeds_msg: - await channel.send(additional_embeds_msg) - for additional_embed in additional_embeds: - await channel.send(embed=additional_embed) - - return await self.bot.get_context(log_message) # Optionally return for use with antispam - - @Cog.listener() - async def on_guild_channel_create(self, channel: GUILD_CHANNEL) -> None: - """Log channel create event to mod log.""" - if channel.guild.id != GuildConstant.id: - return - - if isinstance(channel, discord.CategoryChannel): - title = "Category created" - message = f"{channel.name} (`{channel.id}`)" - elif isinstance(channel, discord.VoiceChannel): - title = "Voice channel created" - - if channel.category: - message = f"{channel.category}/{channel.name} (`{channel.id}`)" - else: - message = f"{channel.name} (`{channel.id}`)" - else: - title = "Text channel created" - - if channel.category: - message = f"{channel.category}/{channel.name} (`{channel.id}`)" - else: - message = f"{channel.name} (`{channel.id}`)" - - await self.send_log_message(Icons.hash_green, Colours.soft_green, title, message) - - @Cog.listener() - async def on_guild_channel_delete(self, channel: GUILD_CHANNEL) -> None: - """Log channel delete event to mod log.""" - if channel.guild.id != GuildConstant.id: - return - - if isinstance(channel, discord.CategoryChannel): - title = "Category deleted" - elif isinstance(channel, discord.VoiceChannel): - title = "Voice channel deleted" - else: - title = "Text channel deleted" - - if channel.category and not isinstance(channel, discord.CategoryChannel): - message = f"{channel.category}/{channel.name} (`{channel.id}`)" - else: - message = f"{channel.name} (`{channel.id}`)" - - await self.send_log_message( - Icons.hash_red, Colours.soft_red, - title, message - ) - - @Cog.listener() - async def on_guild_channel_update(self, before: GUILD_CHANNEL, after: GuildChannel) -> None: - """Log channel update event to mod log.""" - if before.guild.id != GuildConstant.id: - return - - if before.id in self._ignored[Event.guild_channel_update]: - self._ignored[Event.guild_channel_update].remove(before.id) - return - - # Two channel updates are sent for a single edit: 1 for topic and 1 for category change. - # TODO: remove once support is added for ignoring multiple occurrences for the same channel. - help_categories = (Categories.help_available, Categories.help_dormant, Categories.help_in_use) - if after.category and after.category.id in help_categories: - return - - diff = DeepDiff(before, after) - changes = [] - done = [] - - diff_values = diff.get("values_changed", {}) - diff_values.update(diff.get("type_changes", {})) - - for key, value in diff_values.items(): - if not key: # Not sure why, but it happens - continue - - key = key[5:] # Remove "root." prefix - - if "[" in key: - key = key.split("[", 1)[0] - - if "." in key: - key = key.split(".", 1)[0] - - if key in done or key in CHANNEL_CHANGES_SUPPRESSED: - continue - - if key in CHANNEL_CHANGES_UNSUPPORTED: - changes.append(f"**{key.title()}** updated") - else: - new = value["new_value"] - old = value["old_value"] - - # Discord does not treat consecutive backticks ("``") as an empty inline code block, so the markdown - # formatting is broken when `new` and/or `old` are empty values. "None" is used for these cases so - # formatting is preserved. - changes.append(f"**{key.title()}:** `{old or 'None'}` **→** `{new or 'None'}`") - - done.append(key) - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - if after.category: - message = f"**{after.category}/#{after.name} (`{after.id}`)**\n{message}" - else: - message = f"**#{after.name}** (`{after.id}`)\n{message}" - - await self.send_log_message( - Icons.hash_blurple, Colour.blurple(), - "Channel updated", message - ) - - @Cog.listener() - async def on_guild_role_create(self, role: discord.Role) -> None: - """Log role create event to mod log.""" - if role.guild.id != GuildConstant.id: - return - - await self.send_log_message( - Icons.crown_green, Colours.soft_green, - "Role created", f"`{role.id}`" - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: discord.Role) -> None: - """Log role delete event to mod log.""" - if role.guild.id != GuildConstant.id: - return - - await self.send_log_message( - Icons.crown_red, Colours.soft_red, - "Role removed", f"{role.name} (`{role.id}`)" - ) - - @Cog.listener() - async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: - """Log role update event to mod log.""" - if before.guild.id != GuildConstant.id: - return - - diff = DeepDiff(before, after) - changes = [] - done = [] - - diff_values = diff.get("values_changed", {}) - diff_values.update(diff.get("type_changes", {})) - - for key, value in diff_values.items(): - if not key: # Not sure why, but it happens - continue - - key = key[5:] # Remove "root." prefix - - if "[" in key: - key = key.split("[", 1)[0] - - if "." in key: - key = key.split(".", 1)[0] - - if key in done or key == "color": - continue - - if key in ROLE_CHANGES_UNSUPPORTED: - changes.append(f"**{key.title()}** updated") - else: - new = value["new_value"] - old = value["old_value"] - - changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") - - done.append(key) - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - message = f"**{after.name}** (`{after.id}`)\n{message}" - - await self.send_log_message( - Icons.crown_blurple, Colour.blurple(), - "Role updated", message - ) - - @Cog.listener() - async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: - """Log guild update event to mod log.""" - if before.id != GuildConstant.id: - return - - diff = DeepDiff(before, after) - changes = [] - done = [] - - diff_values = diff.get("values_changed", {}) - diff_values.update(diff.get("type_changes", {})) - - for key, value in diff_values.items(): - if not key: # Not sure why, but it happens - continue - - key = key[5:] # Remove "root." prefix - - if "[" in key: - key = key.split("[", 1)[0] - - if "." in key: - key = key.split(".", 1)[0] - - if key in done: - continue - - new = value["new_value"] - old = value["old_value"] - - changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") - - done.append(key) - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - message = f"**{after.name}** (`{after.id}`)\n{message}" - - await self.send_log_message( - Icons.guild_update, Colour.blurple(), - "Guild updated", message, - thumbnail=after.icon_url_as(format="png") - ) - - @Cog.listener() - async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: - """Log ban event to user log.""" - if guild.id != GuildConstant.id: - return - - if member.id in self._ignored[Event.member_ban]: - self._ignored[Event.member_ban].remove(member.id) - return - - await self.send_log_message( - Icons.user_ban, Colours.soft_red, - "User banned", f"{member} (`{member.id}`)", - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_member_join(self, member: discord.Member) -> None: - """Log member join event to user log.""" - if member.guild.id != GuildConstant.id: - return - - member_str = escape_markdown(str(member)) - message = f"{member_str} (`{member.id}`)" - now = datetime.utcnow() - difference = abs(relativedelta(now, member.created_at)) - - message += "\n\n**Account age:** " + humanize_delta(difference) - - if difference.days < 1 and difference.months < 1 and difference.years < 1: # New user account! - message = f"{Emojis.new} {message}" - - await self.send_log_message( - Icons.sign_in, Colours.soft_green, - "User joined", message, - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_member_remove(self, member: discord.Member) -> None: - """Log member leave event to user log.""" - if member.guild.id != GuildConstant.id: - return - - if member.id in self._ignored[Event.member_remove]: - self._ignored[Event.member_remove].remove(member.id) - return - - member_str = escape_markdown(str(member)) - await self.send_log_message( - Icons.sign_out, Colours.soft_red, - "User left", f"{member_str} (`{member.id}`)", - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: - """Log member unban event to mod log.""" - if guild.id != GuildConstant.id: - return - - if member.id in self._ignored[Event.member_unban]: - self._ignored[Event.member_unban].remove(member.id) - return - - member_str = escape_markdown(str(member)) - await self.send_log_message( - Icons.user_unban, Colour.blurple(), - "User unbanned", f"{member_str} (`{member.id}`)", - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.mod_log - ) - - @staticmethod - def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: - """Return a list of strings describing the roles added and removed.""" - changes = [] - before_roles = set(before) - after_roles = set(after) - - for role in (before_roles - after_roles): - changes.append(f"**Role removed:** {role.name} (`{role.id}`)") - - for role in (after_roles - before_roles): - changes.append(f"**Role added:** {role.name} (`{role.id}`)") - - return changes - - @Cog.listener() - async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: - """Log member update event to user log.""" - if before.guild.id != GuildConstant.id: - return - - if before.id in self._ignored[Event.member_update]: - self._ignored[Event.member_update].remove(before.id) - return - - changes = self.get_role_diff(before.roles, after.roles) - - # The regex is a simple way to exclude all sequence and mapping types. - diff = DeepDiff(before, after, exclude_regex_paths=r".*\[.*") - - # A type change seems to always take precedent over a value change. Furthermore, it will - # include the value change along with the type change anyway. Therefore, it's OK to - # "overwrite" values_changed; in practice there will never even be anything to overwrite. - diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} - - for attr, value in diff_values.items(): - if not attr: # Not sure why, but it happens. - continue - - attr = attr[5:] # Remove "root." prefix. - attr = attr.replace("_", " ").replace(".", " ").capitalize() - - new = value.get("new_value") - old = value.get("old_value") - - changes.append(f"**{attr}:** `{old}` **→** `{new}`") - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - member_str = escape_markdown(str(after)) - message = f"**{member_str}** (`{after.id}`)\n{message}" - - await self.send_log_message( - icon_url=Icons.user_update, - colour=Colour.blurple(), - title="Member updated", - text=message, - thumbnail=after.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_message_delete(self, message: discord.Message) -> None: - """Log message delete event to message change log.""" - channel = message.channel - author = message.author - - # Ignore DMs. - if not message.guild: - return - - if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist: - return - - self._cached_deletes.append(message.id) - - if message.id in self._ignored[Event.message_delete]: - self._ignored[Event.message_delete].remove(message.id) - return - - if author.bot: - return - - author_str = escape_markdown(str(author)) - if channel.category: - response = ( - f"**Author:** {author_str} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - ) - else: - response = ( - f"**Author:** {author_str} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - ) - - if message.attachments: - # Prepend the message metadata with the number of attachments - response = f"**Attachments:** {len(message.attachments)}\n" + response - - # Shorten the message content if necessary - content = message.clean_content - remaining_chars = 2040 - len(response) - - if len(content) > remaining_chars: - botlog_url = await self.upload_log(messages=[message], actor_id=message.author.id) - ending = f"\n\nMessage truncated, [full message here]({botlog_url})." - truncation_point = remaining_chars - len(ending) - content = f"{content[:truncation_point]}...{ending}" - - response += f"{content}" - - await self.send_log_message( - Icons.message_delete, Colours.soft_red, - "Message deleted", - response, - channel_id=Channels.message_log - ) - - @Cog.listener() - async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: - """Log raw message delete event to message change log.""" - if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist: - return - - await asyncio.sleep(1) # Wait here in case the normal event was fired - - if event.message_id in self._cached_deletes: - # It was in the cache and the normal event was fired, so we can just ignore it - self._cached_deletes.remove(event.message_id) - return - - if event.message_id in self._ignored[Event.message_delete]: - self._ignored[Event.message_delete].remove(event.message_id) - return - - channel = self.bot.get_channel(event.channel_id) - - if channel.category: - response = ( - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{event.message_id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - else: - response = ( - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{event.message_id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - await self.send_log_message( - Icons.message_delete, Colours.soft_red, - "Message deleted", - response, - channel_id=Channels.message_log - ) - - @Cog.listener() - async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: - """Log message edit event to message change log.""" - if ( - not msg_before.guild - or msg_before.guild.id != GuildConstant.id - or msg_before.channel.id in GuildConstant.modlog_blacklist - or msg_before.author.bot - ): - return - - self._cached_edits.append(msg_before.id) - - if msg_before.content == msg_after.content: - return - - author = msg_before.author - author_str = escape_markdown(str(author)) - - channel = msg_before.channel - channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - - # Getting the difference per words and group them by type - add, remove, same - # Note that this is intended grouping without sorting - diff = difflib.ndiff(msg_before.clean_content.split(), msg_after.clean_content.split()) - diff_groups = tuple( - (diff_type, tuple(s[2:] for s in diff_words)) - for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) - ) - - content_before: t.List[str] = [] - content_after: t.List[str] = [] - - for index, (diff_type, words) in enumerate(diff_groups): - sub = ' '.join(words) - if diff_type == '-': - content_before.append(f"[{sub}](http://o.hi)") - elif diff_type == '+': - content_after.append(f"[{sub}](http://o.hi)") - elif diff_type == ' ': - if len(words) > 2: - sub = ( - f"{words[0] if index > 0 else ''}" - " ... " - f"{words[-1] if index < len(diff_groups) - 1 else ''}" - ) - content_before.append(sub) - content_after.append(sub) - - response = ( - f"**Author:** {author_str} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{msg_before.id}`\n" - "\n" - f"**Before**:\n{' '.join(content_before)}\n" - f"**After**:\n{' '.join(content_after)}\n" - "\n" - f"[Jump to message]({msg_after.jump_url})" - ) - - if msg_before.edited_at: - # Message was previously edited, to assist with self-bot detection, use the edited_at - # datetime as the baseline and create a human-readable delta between this edit event - # and the last time the message was edited - timestamp = msg_before.edited_at - delta = humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) - footer = f"Last edited {delta} ago" - else: - # Message was not previously edited, use the created_at datetime as the baseline, no - # delta calculation needed - timestamp = msg_before.created_at - footer = None - - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited", response, - channel_id=Channels.message_log, timestamp_override=timestamp, footer=footer - ) - - @Cog.listener() - async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: - """Log raw message edit event to message change log.""" - try: - channel = self.bot.get_channel(int(event.data["channel_id"])) - message = await channel.fetch_message(event.message_id) - except discord.NotFound: # Was deleted before we got the event - return - - if ( - not message.guild - or message.guild.id != GuildConstant.id - or message.channel.id in GuildConstant.modlog_blacklist - or message.author.bot - ): - return - - await asyncio.sleep(1) # Wait here in case the normal event was fired - - if event.message_id in self._cached_edits: - # It was in the cache and the normal event was fired, so we can just ignore it - self._cached_edits.remove(event.message_id) - return - - author = message.author - channel = message.channel - channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - f"{message.clean_content}" - ) - - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (Before)", - before_response, channel_id=Channels.message_log - ) - - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (After)", - after_response, channel_id=Channels.message_log - ) - - @Cog.listener() - async def on_voice_state_update( - self, - member: discord.Member, - before: discord.VoiceState, - after: discord.VoiceState - ) -> None: - """Log member voice state changes to the voice log channel.""" - if ( - member.guild.id != GuildConstant.id - or (before.channel and before.channel.id in GuildConstant.modlog_blacklist) - ): - return - - if member.id in self._ignored[Event.voice_state_update]: - self._ignored[Event.voice_state_update].remove(member.id) - return - - # Exclude all channel attributes except the name. - diff = DeepDiff( - before, - after, - exclude_paths=("root.session_id", "root.afk"), - exclude_regex_paths=r"root\.channel\.(?!name)", - ) - - # A type change seems to always take precedent over a value change. Furthermore, it will - # include the value change along with the type change anyway. Therefore, it's OK to - # "overwrite" values_changed; in practice there will never even be anything to overwrite. - diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} - - icon = Icons.voice_state_blue - colour = Colour.blurple() - changes = [] - - for attr, values in diff_values.items(): - if not attr: # Not sure why, but it happens. - continue - - old = values["old_value"] - new = values["new_value"] - - attr = attr[5:] # Remove "root." prefix. - attr = VOICE_STATE_ATTRIBUTES.get(attr, attr.replace("_", " ").capitalize()) - - changes.append(f"**{attr}:** `{old}` **→** `{new}`") - - # Set the embed icon and colour depending on which attribute changed. - if any(name in attr for name in ("Channel", "deaf", "mute")): - if new is None or new is True: - # Left a channel or was muted/deafened. - icon = Icons.voice_state_red - colour = Colours.soft_red - elif old is None or old is True: - # Joined a channel or was unmuted/undeafened. - icon = Icons.voice_state_green - colour = Colours.soft_green - - if not changes: - return - - member_str = escape_markdown(str(member)) - message = "\n".join(f"{Emojis.bullet} {item}" for item in sorted(changes)) - message = f"**{member_str}** (`{member.id}`)\n{message}" - - await self.send_log_message( - icon_url=icon, - colour=colour, - title="Voice state updated", - text=message, - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.voice_log - ) - - -def setup(bot: Bot) -> None: - """Load the ModLog cog.""" - bot.add_cog(ModLog(bot)) diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py deleted file mode 100644 index 4af87c724..000000000 --- a/bot/cogs/moderation/silence.py +++ /dev/null @@ -1,170 +0,0 @@ -import asyncio -import logging -from contextlib import suppress -from typing import Optional - -from discord import TextChannel -from discord.ext import commands, tasks -from discord.ext.commands import Context - -from bot.bot import Bot -from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles -from bot.converters import HushDurationConverter -from bot.utils.checks import with_role_check -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - - -class SilenceNotifier(tasks.Loop): - """Loop notifier for posting notices to `alert_channel` containing added channels.""" - - def __init__(self, alert_channel: TextChannel): - super().__init__(self._notifier, seconds=1, minutes=0, hours=0, count=None, reconnect=True, loop=None) - self._silenced_channels = {} - self._alert_channel = alert_channel - - def add_channel(self, channel: TextChannel) -> None: - """Add channel to `_silenced_channels` and start loop if not launched.""" - if not self._silenced_channels: - self.start() - log.info("Starting notifier loop.") - self._silenced_channels[channel] = self._current_loop - - def remove_channel(self, channel: TextChannel) -> None: - """Remove channel from `_silenced_channels` and stop loop if no channels remain.""" - with suppress(KeyError): - del self._silenced_channels[channel] - if not self._silenced_channels: - self.stop() - log.info("Stopping notifier loop.") - - async def _notifier(self) -> None: - """Post notice of `_silenced_channels` with their silenced duration to `_alert_channel` periodically.""" - # Wait for 15 minutes between notices with pause at start of loop. - if self._current_loop and not self._current_loop/60 % 15: - log.debug( - f"Sending notice with channels: " - f"{', '.join(f'#{channel} ({channel.id})' for channel in self._silenced_channels)}." - ) - channels_text = ', '.join( - f"{channel.mention} for {(self._current_loop-start)//60} min" - for channel, start in self._silenced_channels.items() - ) - await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") - - -class Silence(commands.Cog): - """Commands for stopping channel messages for `verified` role in a channel.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - self.muted_channels = set() - - self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) - self._get_instance_vars_event = asyncio.Event() - - async def _get_instance_vars(self) -> None: - """Get instance variables after they're available to get from the guild.""" - await self.bot.wait_until_guild_available() - guild = self.bot.get_guild(Guild.id) - self._verified_role = guild.get_role(Roles.verified) - self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) - self._mod_log_channel = self.bot.get_channel(Channels.mod_log) - self.notifier = SilenceNotifier(self._mod_log_channel) - self._get_instance_vars_event.set() - - @commands.command(aliases=("hush",)) - async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: - """ - Silence the current channel for `duration` minutes or `forever`. - - Duration is capped at 15 minutes, passing forever makes the silence indefinite. - Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. - """ - await self._get_instance_vars_event.wait() - log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") - if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): - await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") - return - if duration is None: - await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") - return - - await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") - - self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) - - @commands.command(aliases=("unhush",)) - async def unsilence(self, ctx: Context) -> None: - """ - Unsilence the current channel. - - If the channel was silenced indefinitely, notifications for the channel will stop. - """ - await self._get_instance_vars_event.wait() - log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") - if not await self._unsilence(ctx.channel): - await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") - else: - await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") - - async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: - """ - Silence `channel` for `self._verified_role`. - - If `persistent` is `True` add `channel` to notifier. - `duration` is only used for logging; if None is passed `persistent` should be True to not log None. - Return `True` if channel permissions were changed, `False` otherwise. - """ - current_overwrite = channel.overwrites_for(self._verified_role) - if current_overwrite.send_messages is False: - log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") - return False - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) - self.muted_channels.add(channel) - if persistent: - log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") - self.notifier.add_channel(channel) - return True - - log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") - return True - - async def _unsilence(self, channel: TextChannel) -> bool: - """ - Unsilence `channel`. - - Check if `channel` is silenced through a `PermissionOverwrite`, - if it is unsilence it and remove it from the notifier. - Return `True` if channel permissions were changed, `False` otherwise. - """ - current_overwrite = channel.overwrites_for(self._verified_role) - if current_overwrite.send_messages is False: - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) - log.info(f"Unsilenced channel #{channel} ({channel.id}).") - self.scheduler.cancel(channel.id) - self.notifier.remove_channel(channel) - self.muted_channels.discard(channel) - return True - log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") - return False - - def cog_unload(self) -> None: - """Send alert with silenced channels and cancel scheduled tasks on unload.""" - self.scheduler.cancel_all() - if self.muted_channels: - channels_string = ''.join(channel.mention for channel in self.muted_channels) - message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" - asyncio.create_task(self._mod_alerts_channel.send(message)) - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the Silence cog.""" - bot.add_cog(Silence(bot)) diff --git a/bot/cogs/moderation/slowmode.py b/bot/cogs/moderation/slowmode.py deleted file mode 100644 index 1d055afac..000000000 --- a/bot/cogs/moderation/slowmode.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -from datetime import datetime -from typing import Optional - -from dateutil.relativedelta import relativedelta -from discord import TextChannel -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import Emojis, MODERATION_ROLES -from bot.converters import DurationDelta -from bot.decorators import with_role_check -from bot.utils import time - -log = logging.getLogger(__name__) - -SLOWMODE_MAX_DELAY = 21600 # seconds - - -class Slowmode(Cog): - """Commands for getting and setting slowmode delays of text channels.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @group(name='slowmode', aliases=['sm'], invoke_without_command=True) - async def slowmode_group(self, ctx: Context) -> None: - """Get or set the slowmode delay for the text channel this was invoked in or a given text channel.""" - await ctx.send_help(ctx.command) - - @slowmode_group.command(name='get', aliases=['g']) - async def get_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: - """Get the slowmode delay for a text channel.""" - # Use the channel this command was invoked in if one was not given - if channel is None: - channel = ctx.channel - - delay = relativedelta(seconds=channel.slowmode_delay) - humanized_delay = time.humanize_delta(delay) - - await ctx.send(f'The slowmode delay for {channel.mention} is {humanized_delay}.') - - @slowmode_group.command(name='set', aliases=['s']) - async def set_slowmode(self, ctx: Context, channel: Optional[TextChannel], delay: DurationDelta) -> None: - """Set the slowmode delay for a text channel.""" - # Use the channel this command was invoked in if one was not given - if channel is None: - channel = ctx.channel - - # Convert `dateutil.relativedelta.relativedelta` to `datetime.timedelta` - # Must do this to get the delta in a particular unit of time - utcnow = datetime.utcnow() - slowmode_delay = (utcnow + delay - utcnow).total_seconds() - - humanized_delay = time.humanize_delta(delay) - - # Ensure the delay is within discord's limits - if slowmode_delay <= SLOWMODE_MAX_DELAY: - log.info(f'{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.') - - await channel.edit(slowmode_delay=slowmode_delay) - await ctx.send( - f'{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}.' - ) - - else: - log.info( - f'{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, ' - 'which is not between 0 and 6 hours.' - ) - - await ctx.send( - f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.' - ) - - @slowmode_group.command(name='reset', aliases=['r']) - async def reset_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: - """Reset the slowmode delay for a text channel to 0 seconds.""" - # Use the channel this command was invoked in if one was not given - if channel is None: - channel = ctx.channel - - log.info(f'{ctx.author} reset the slowmode delay for #{channel} to 0 seconds.') - - await channel.edit(slowmode_delay=0) - await ctx.send( - f'{Emojis.check_mark} The slowmode delay for {channel.mention} has been reset to 0 seconds.' - ) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the Slowmode cog.""" - bot.add_cog(Slowmode(bot)) diff --git a/bot/cogs/moderation/verification.py b/bot/cogs/moderation/verification.py deleted file mode 100644 index ba95ab5e4..000000000 --- a/bot/cogs/moderation/verification.py +++ /dev/null @@ -1,191 +0,0 @@ -import logging -from contextlib import suppress - -from discord import Colour, Forbidden, Message, NotFound, Object -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.decorators import in_whitelist, without_role -from bot.utils.checks import InWhitelistCheckFailure, without_role_check - -log = logging.getLogger(__name__) - -WELCOME_MESSAGE = f""" -Hello! Welcome to the server, and thanks for verifying yourself! - -For your records, these are the documents you accepted: - -`1)` Our rules, here: -`2)` Our privacy policy, here: - you can find information on how to have \ -your information removed here as well. - -Feel free to review them at any point! - -Additionally, if you'd like to receive notifications for the announcements \ -we post in <#{constants.Channels.announcements}> -from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ -to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. - -If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ -<#{constants.Channels.bot_commands}>. -""" - -BOT_MESSAGE_DELETE_DELAY = 10 - - -class Verification(Cog): - """User verification and role self-management.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Check new message event for messages to the checkpoint channel & process.""" - if message.channel.id != constants.Channels.verification: - return # Only listen for #checkpoint messages - - if message.author.bot: - # They're a bot, delete their message after the delay. - await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) - return - - # if a user mentions a role or guild member - # alert the mods in mod-alerts channel - if message.mentions or message.role_mentions: - log.debug( - f"{message.author} mentioned one or more users " - f"and/or roles in {message.channel.name}" - ) - - embed_text = ( - f"{message.author.mention} sent a message in " - f"{message.channel.mention} that contained user and/or role mentions." - f"\n\n**Original message:**\n>>> {message.content}" - ) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=constants.Icons.filtering, - colour=Colour(constants.Colours.soft_red), - title=f"User/Role mentioned in {message.channel.name}", - text=embed_text, - thumbnail=message.author.avatar_url_as(static_format="png"), - channel_id=constants.Channels.mod_alerts, - ) - - ctx: Context = await self.bot.get_context(message) - if ctx.command is not None and ctx.command.name == "accept": - return - - if any(r.id == constants.Roles.verified for r in ctx.author.roles): - log.info( - f"{ctx.author} posted '{ctx.message.content}' " - "in the verification channel, but is already verified." - ) - return - - log.debug( - f"{ctx.author} posted '{ctx.message.content}' in the verification " - "channel. We are providing instructions how to verify." - ) - await ctx.send( - f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " - f"and gain access to the rest of the server.", - delete_after=20 - ) - - log.trace(f"Deleting the message posted by {ctx.author}") - with suppress(NotFound): - await ctx.message.delete() - - @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) - @without_role(constants.Roles.verified) - @in_whitelist(channels=(constants.Channels.verification,)) - async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Accept our rules and gain access to the rest of the server.""" - log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") - await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") - try: - await ctx.author.send(WELCOME_MESSAGE) - except Forbidden: - log.info(f"Sending welcome message failed for {ctx.author}.") - finally: - log.trace(f"Deleting accept message by {ctx.author}.") - with suppress(NotFound): - self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) - await ctx.message.delete() - - @command(name='subscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Subscribe to announcement notifications by assigning yourself the role.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if has_role: - await ctx.send(f"{ctx.author.mention} You're already subscribed!") - return - - log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") - await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", - ) - - @command(name='unsubscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Unsubscribe from announcement notifications by removing the role from yourself.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if not has_role: - await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") - return - - log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") - await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." - ) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InWhitelistCheckFailure.""" - if isinstance(error, InWhitelistCheckFailure): - error.handled = True - - @staticmethod - def bot_check(ctx: Context) -> bool: - """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): - return ctx.command.name == "accept" - else: - return True - - -def setup(bot: Bot) -> None: - """Load the Verification cog.""" - bot.add_cog(Verification(bot)) diff --git a/bot/cogs/moderation/watchchannels/__init__.py b/bot/cogs/moderation/watchchannels/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/moderation/watchchannels/_watchchannel.py b/bot/cogs/moderation/watchchannels/_watchchannel.py deleted file mode 100644 index 488ae704d..000000000 --- a/bot/cogs/moderation/watchchannels/_watchchannel.py +++ /dev/null @@ -1,348 +0,0 @@ -import asyncio -import logging -import re -import textwrap -from abc import abstractmethod -from collections import defaultdict, deque -from dataclasses import dataclass -from typing import Optional - -import dateutil.parser -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - -URL_RE = re.compile(r"(https?://[^\s]+)") - - -@dataclass -class MessageHistory: - """Represents a watch channel's message history.""" - - last_author: Optional[int] = None - last_channel: Optional[int] = None - message_count: int = 0 - - -class WatchChannel(metaclass=CogABCMeta): - """ABC with functionality for relaying users' messages to a certain channel.""" - - @abstractmethod - def __init__( - self, - bot: Bot, - destination: int, - webhook_id: int, - api_endpoint: str, - api_default_params: dict, - logger: logging.Logger - ) -> None: - self.bot = bot - - self.destination = destination # E.g., Channels.big_brother_logs - self.webhook_id = webhook_id # E.g., Webhooks.big_brother - self.api_endpoint = api_endpoint # E.g., 'bot/infractions' - self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} - self.log = logger # Logger of the child cog for a correct name in the logs - - self._consume_task = None - self.watched_users = defaultdict(dict) - self.message_queue = defaultdict(lambda: defaultdict(deque)) - self.consumption_queue = {} - self.retries = 5 - self.retry_delay = 10 - self.channel = None - self.webhook = None - self.message_history = MessageHistory() - - self._start = self.bot.loop.create_task(self.start_watchchannel()) - - @property - def modlog(self) -> ModLog: - """Provides access to the ModLog cog for alert purposes.""" - return self.bot.get_cog("ModLog") - - @property - def consuming_messages(self) -> bool: - """Checks if a consumption task is currently running.""" - if self._consume_task is None: - return False - - if self._consume_task.done(): - exc = self._consume_task.exception() - if exc: - self.log.exception( - "The message queue consume task has failed with:", - exc_info=exc - ) - return False - - return True - - async def start_watchchannel(self) -> None: - """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_guild_available() - - try: - self.channel = await self.bot.fetch_channel(self.destination) - except HTTPException: - self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - if self.channel is None or self.webhook is None: - self.log.error("Failed to start the watch channel; unloading the cog.") - - message = textwrap.dedent( - f""" - An error occurred while loading the text channel or webhook. - - TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} - Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} - - The Cog has been unloaded. - """ - ) - - await self.modlog.send_log_message( - title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", - text=message, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - if not await self.fetch_user_cache(): - await self.modlog.send_log_message( - title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", - text="Could not retrieve the list of watched users from the API and messages will not be relayed.", - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - async def fetch_user_cache(self) -> bool: - """ - Fetches watched users from the API and updates the watched user cache accordingly. - - This function returns `True` if the update succeeded. - """ - try: - data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except ResponseCodeError as err: - self.log.exception("Failed to fetch the watched users from the API", exc_info=err) - return False - - self.watched_users = defaultdict(dict) - - for entry in data: - user_id = entry.pop('user') - self.watched_users[user_id] = entry - - return True - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Queues up messages sent by watched users.""" - if msg.author.id in self.watched_users: - if not self.consuming_messages: - self._consume_task = self.bot.loop.create_task(self.consume_messages()) - - self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") - self.message_queue[msg.author.id][msg.channel.id].append(msg) - - async def consume_messages(self, delay_consumption: bool = True) -> None: - """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) - - self.log.trace("Started consuming the message queue") - - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() - - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() - - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) - - self.consumption_queue.clear() - - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") - - async def webhook_send( - self, - content: Optional[str] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - embed: Optional[Embed] = None, - ) -> None: - """Sends a message to the webhook with the specified kwargs.""" - username = messages.sub_clyde(username) - try: - await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send a message to the webhook", - exc_info=exc - ) - - async def relay_message(self, msg: Message) -> None: - """Relays the message to the relevant watch channel.""" - limit = BigBrotherConfig.header_message_limit - - if ( - msg.author.id != self.message_history.last_author - or msg.channel.id != self.message_history.last_channel - or self.message_history.message_count >= limit - ): - self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) - - await self.send_header(msg) - - cleaned_content = msg.clean_content - - if cleaned_content: - # Put all non-media URLs in a code block to prevent embeds - media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} - for url in URL_RE.findall(cleaned_content): - if url not in media_urls: - cleaned_content = cleaned_content.replace(url, f"`{url}`") - await self.webhook_send( - cleaned_content, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - - if msg.attachments: - try: - await messages.send_attachments(msg, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.webhook_send( - embed=e, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send an attachment to the webhook", - exc_info=exc - ) - - self.message_history.message_count += 1 - - async def send_header(self, msg: Message) -> None: - """Sends a header embed with information about the relayed messages to the watch channel.""" - user_id = msg.author.id - - guild = self.bot.get_guild(GuildConfig.id) - actor = guild.get_member(self.watched_users[user_id]['actor']) - actor = actor.display_name if actor else self.watched_users[user_id]['actor'] - - inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - - reason = self.watched_users[user_id]['reason'] - - if isinstance(msg.channel, DMChannel): - # If a watched user DMs the bot there won't be a channel name or jump URL - # This could technically include a GroupChannel but bot's can't be in those - message_jump = "via DM" - else: - message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" - - footer = f"Added {time_delta} by {actor} | Reason: {reason}" - embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) - - await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) - - async def list_watched_users( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Gives an overview of the watched user list for this channel. - - The optional kwarg `oldest_first` orders the list by oldest entry. - - The optional kwarg `update_cache` specifies whether the cache should - be refreshed by polling the API. - """ - if update_cache: - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") - update_cache = False - - lines = [] - for user_id, user_data in self.watched_users.items(): - inserted_at = user_data['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - lines.append(f"• <@{user_id}> (added {time_delta})") - - if oldest_first: - lines.reverse() - - lines = lines or ("There's nothing here yet.",) - - embed = Embed( - title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", - color=Color.blue() - ) - await LinePaginator.paginate(lines, ctx, embed, empty=False) - - @staticmethod - def _get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) - time_delta = time_since(date_time, precision="minutes", max_units=1) - - return time_delta - - def _remove_user(self, user_id: int) -> None: - """Removes a user from a watch channel.""" - self.watched_users.pop(user_id, None) - self.message_queue.pop(user_id, None) - self.consumption_queue.pop(user_id, None) - - def cog_unload(self) -> None: - """Takes care of unloading the cog and canceling the consumption task.""" - self.log.trace("Unloading the cog") - if self._consume_task and not self._consume_task.done(): - self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) diff --git a/bot/cogs/moderation/watchchannels/bigbrother.py b/bot/cogs/moderation/watchchannels/bigbrother.py deleted file mode 100644 index 7db34bcf2..000000000 --- a/bot/cogs/moderation/watchchannels/bigbrother.py +++ /dev/null @@ -1,170 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.infraction._utils import post_infraction -from bot.constants import Channels, MODERATION_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from ._watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class BigBrother(WatchChannel, Cog, name="Big Brother"): - """Monitors users by relaying their messages to a watch channel to assist with moderation.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.big_brother_logs, - webhook_id=Webhooks.big_brother, - api_endpoint='bot/infractions', - api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, - logger=log - ) - - @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def bigbrother_group(self, ctx: Context) -> None: - """Monitors users by relaying their messages to the Big Brother watch channel.""" - await ctx.send_help(ctx.command) - - @bigbrother_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored by Big Brother. - - The optional kwarg `oldest_first` can be used to order the list by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @bigbrother_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows Big Brother monitored users ordered by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @bigbrother_group.command(name='watch', aliases=('w',)) - @with_role(*MODERATION_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#big-brother` channel. - - A `reason` for adding the user to Big Brother is required and will be displayed - in the header when relaying messages of this user to the watchchannel. - """ - await self.apply_watch(ctx, user, reason) - - @bigbrother_group.command(name='unwatch', aliases=('uw',)) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Stop relaying messages by the given `user`.""" - await self.apply_unwatch(ctx, user, reason) - - async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: - """ - Add `user` to watched users and apply a watch infraction with `reason`. - - A message indicating the result of the operation is sent to `ctx`. - The message will include `user`'s previous watch infraction history, if it exists. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched.") - return - - response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) - - if response is not None: - self.watched_users[user.id] = response - msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - 'type': 'watch', - 'ordering': '-inserted_at' - } - ) - - if len(history) > 1: - total = f"({len(history) // 2} previous infractions in total)" - end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") - start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - else: - msg = ":x: Failed to post the infraction: response was empty." - - await ctx.send(msg) - - async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: - """ - Remove `user` from watched users and mark their infraction as inactive with `reason`. - - If `send_message` is True, a message indicating the result of the operation is sent to - `ctx`. - """ - active_watches = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - if active_watches: - log.trace("Active watches for user found. Attempting to remove.") - [infraction] = active_watches - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{infraction['id']}", - json={'active': False} - ) - - await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) - - self._remove_user(user.id) - - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"Perma-banned user {user} was unwatched.") - return - log.trace("User is not banned. Sending message to channel") - message = f":white_check_mark: Messages sent by {user} will no longer be relayed." - - else: - log.trace("No active watches found for user.") - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"{user} was not on the watch list; no removal necessary.") - return - log.trace("User is not perma banned. Send the error message.") - message = ":x: The specified user is currently not being watched." - - await ctx.send(message) - - -def setup(bot: Bot) -> None: - """Load the BigBrother cog.""" - bot.add_cog(BigBrother(bot)) diff --git a/bot/cogs/moderation/watchchannels/talentpool.py b/bot/cogs/moderation/watchchannels/talentpool.py deleted file mode 100644 index 2972f56e1..000000000 --- a/bot/cogs/moderation/watchchannels/talentpool.py +++ /dev/null @@ -1,269 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord import Color, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils import time -from ._watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class TalentPool(WatchChannel, Cog, name="Talentpool"): - """Relays messages of helper candidates to a watch channel to observe them.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.talent_pool, - webhook_id=Webhooks.talent_pool, - api_endpoint='bot/nominations', - api_default_params={'active': 'true', 'ordering': '-inserted_at'}, - logger=log, - ) - - @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_group(self, ctx: Context) -> None: - """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" - await ctx.send_help(ctx.command) - - @nomination_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored in the talent pool. - - The optional kwarg `oldest_first` can be used to order the list by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @nomination_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows talent pool monitored users ordered by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) - @with_role(*STAFF_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#talent-pool` channel. - - A `reason` for adding the user to the talent pool is required and will be displayed - in the header when relaying messages of this user to the channel. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): - await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update the user cache; can't add {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched in the talent pool") - return - - # Manual request with `raise_for_status` as False because we want the actual response - session = self.bot.api_client.session - url = self.bot.api_client._url_for(self.api_endpoint) - kwargs = { - 'json': { - 'actor': ctx.author.id, - 'reason': reason, - 'user': user.id - }, - 'raise_for_status': False, - } - async with session.post(url, **kwargs) as resp: - response_data = await resp.json() - - if resp.status == 400 and response_data.get('user', False): - await ctx.send(":x: The specified user can't be found in the database tables") - return - else: - resp.raise_for_status() - - self.watched_users[user.id] = response_data - msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - "ordering": "-inserted_at" - } - ) - - if history: - total = f"({len(history)} previous nominations in total)" - start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" - end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - - await ctx.send(msg) - - @nomination_group.command(name='history', aliases=('info', 'search')) - @with_role(*MODERATION_ROLES) - async def history_command(self, ctx: Context, user: FetchedMember) -> None: - """Shows the specified user's nomination history.""" - result = await self.bot.api_client.get( - self.api_endpoint, - params={ - 'user__id': str(user.id), - 'ordering': "-active,-inserted_at" - } - ) - if not result: - await ctx.send(":warning: This user has never been nominated") - return - - embed = Embed( - title=f"Nominations for {user.display_name} `({user.id})`", - color=Color.blue() - ) - lines = [self._nomination_to_string(nomination) for nomination in result] - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - @nomination_group.command(name='unwatch', aliases=('end', )) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Ends the active nomination of the specified user with the given reason. - - Providing a `reason` is required. - """ - active_nomination = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - - if not active_nomination: - await ctx.send(":x: The specified user does not have an active nomination") - return - - [nomination] = active_nomination - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination['id']}", - json={'end_reason': reason, 'active': False} - ) - await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") - self._remove_user(user.id) - - @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_edit_group(self, ctx: Context) -> None: - """Commands to edit nominations.""" - await ctx.send_help(ctx.command) - - @nomination_edit_group.command(name='reason') - @with_role(*MODERATION_ROLES) - async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: - """ - Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. - - If the nomination is active, the reason for nominating the user will be edited; - If the nomination is no longer active, the reason for ending the nomination will be edited instead. - """ - try: - nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") - except ResponseCodeError as e: - if e.response.status == 404: - self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") - await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") - return - else: - raise - - field = "reason" if nomination["active"] else "end_reason" - - self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination_id}", - json={field: reason} - ) - - await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") - - def _nomination_to_string(self, nomination_object: dict) -> str: - """Creates a string representation of a nomination.""" - guild = self.bot.get_guild(Guild.id) - - actor_id = nomination_object["actor"] - actor = guild.get_member(actor_id) - - active = nomination_object["active"] - log.debug(active) - log.debug(type(nomination_object["inserted_at"])) - - start_date = time.format_infraction(nomination_object["inserted_at"]) - if active: - lines = textwrap.dedent( - f""" - =============== - Status: **Active** - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - else: - end_date = time.format_infraction(nomination_object["ended_at"]) - lines = textwrap.dedent( - f""" - =============== - Status: Inactive - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - - End date: {end_date} - Unwatch reason: {nomination_object["end_reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - - return lines.strip() - - -def setup(bot: Bot) -> None: - """Load the TalentPool cog.""" - bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py deleted file mode 100644 index ce95450e0..000000000 --- a/bot/cogs/off_topic_names.py +++ /dev/null @@ -1,162 +0,0 @@ -import asyncio -import difflib -import logging -from datetime import datetime, timedelta - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES -from bot.converters import OffTopicName -from bot.decorators import with_role -from bot.pagination import LinePaginator - -CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) -log = logging.getLogger(__name__) - - -async def update_names(bot: Bot) -> None: - """Background updater task that performs the daily channel name update.""" - while True: - # Since we truncate the compute timedelta to seconds, we add one second to ensure - # we go past midnight in the `seconds_to_sleep` set below. - today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) - next_midnight = today_at_midnight + timedelta(days=1) - seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 - await asyncio.sleep(seconds_to_sleep) - - try: - channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( - 'bot/off-topic-channel-names', params={'random_items': 3} - ) - except ResponseCodeError as e: - log.error(f"Failed to get new off topic channel names: code {e.response.status}") - continue - channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) - - await channel_0.edit(name=f'ot0-{channel_0_name}') - await channel_1.edit(name=f'ot1-{channel_1_name}') - await channel_2.edit(name=f'ot2-{channel_2_name}') - log.debug( - "Updated off-topic channel names to" - f" {channel_0_name}, {channel_1_name} and {channel_2_name}" - ) - - -class OffTopicNames(Cog): - """Commands related to managing the off-topic category channel names.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.updater_task = None - - self.bot.loop.create_task(self.init_offtopic_updater()) - - def cog_unload(self) -> None: - """Cancel any running updater tasks on cog unload.""" - if self.updater_task is not None: - self.updater_task.cancel() - - async def init_offtopic_updater(self) -> None: - """Start off-topic channel updating event loop if it hasn't already started.""" - await self.bot.wait_until_guild_available() - if self.updater_task is None: - coro = update_names(self.bot) - self.updater_task = self.bot.loop.create_task(coro) - - @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def otname_group(self, ctx: Context) -> None: - """Add or list items from the off-topic channel name rotation.""" - await ctx.send_help(ctx.command) - - @otname_group.command(name='add', aliases=('a',)) - @with_role(*MODERATION_ROLES) - async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """ - Adds a new off-topic name to the rotation. - - The name is not added if it is too similar to an existing name. - """ - existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') - close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) - - if close_match: - match = close_match[0] - log.info( - f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" - ) - await ctx.send( - f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " - "Use `!otn forceadd` to override this check." - ) - else: - await self._add_name(ctx, name) - - @otname_group.command(name='forceadd', aliases=('fa',)) - @with_role(*MODERATION_ROLES) - async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Forcefully adds a new off-topic name to the rotation.""" - await self._add_name(ctx, name) - - async def _add_name(self, ctx: Context, name: str) -> None: - """Adds an off-topic channel name to the site storage.""" - await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) - - log.info(f"{ctx.author} added the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Added `{name}` to the names list.") - - @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Removes a off-topic name from the rotation.""" - await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') - - log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Removed `{name}` from the names list.") - - @otname_group.command(name='list', aliases=('l',)) - @with_role(*MODERATION_ROLES) - async def list_command(self, ctx: Context) -> None: - """ - Lists all currently known off-topic channel names in a paginator. - - Restricted to Moderator and above to not spoil the surprise. - """ - result = await self.bot.api_client.get('bot/off-topic-channel-names') - lines = sorted(f"• {name}" for name in result) - embed = Embed( - title=f"Known off-topic names (`{len(result)}` total)", - colour=Colour.blue() - ) - if result: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - - @otname_group.command(name='search', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: - """Search for an off-topic name.""" - result = await self.bot.api_client.get('bot/off-topic-channel-names') - in_matches = {name for name in result if query in name} - close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) - lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) - embed = Embed( - title="Query results", - colour=Colour.blue() - ) - - if lines: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Nothing found." - await ctx.send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the OffTopicNames cog.""" - bot.add_cog(OffTopicNames(bot)) diff --git a/bot/cogs/utils/__init__.py b/bot/cogs/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/utils/bot.py b/bot/cogs/utils/bot.py deleted file mode 100644 index 71ed54f60..000000000 --- a/bot/cogs/utils/bot.py +++ /dev/null @@ -1,385 +0,0 @@ -import ast -import logging -import re -import time -from typing import Optional, Tuple - -from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Cog, Context, command, group - -from bot.bot import Bot -from bot.cogs.filters.token_remover import TokenRemover -from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.decorators import with_role -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -RE_MARKDOWN = re.compile(r'([*_~`|>])') - - -class BotCog(Cog, name="Bot"): - """Bot information commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - # Stores allowed channels plus epoch time since last call. - self.channel_cooldowns = { - Channels.python_discussion: 0, - } - - # These channels will also work, but will not be subject to cooldown - self.channel_whitelist = ( - Channels.bot_commands, - ) - - # Stores improperly formatted Python codeblock message ids and the corresponding bot message - self.codeblock_message_ids = {} - - @group(invoke_without_command=True, name="bot", hidden=True) - @with_role(Roles.verified) - async def botinfo_group(self, ctx: Context) -> None: - """Bot informational commands.""" - await ctx.send_help(ctx.command) - - @botinfo_group.command(name='about', aliases=('info',), hidden=True) - @with_role(Roles.verified) - async def about_command(self, ctx: Context) -> None: - """Get information about the bot.""" - embed = Embed( - description="A utility bot designed just for the Python server! Try `!help` for more info.", - url="https://github.com/python-discord/bot" - ) - - embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=URLs.bot_avatar - ) - - await ctx.send(embed=embed) - - @command(name='echo', aliases=('print',)) - @with_role(*MODERATION_ROLES) - async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Repeat the given message in either a specified channel or the current channel.""" - if channel is None: - await ctx.send(text) - else: - await channel.send(text) - - @command(name='embed') - @with_role(*MODERATION_ROLES) - async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Send the input within an embed to either a specified channel or the current channel.""" - embed = Embed(description=text) - - if channel is None: - await ctx.send(embed=embed) - else: - await channel.send(embed=embed) - - def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: - """ - Strip msg in order to find Python code. - - Tries to strip out Python code out of msg and returns the stripped block or - None if the block is a valid Python codeblock. - """ - if msg.count("\n") >= 3: - # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. - if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: - log.trace( - "Someone wrote a message that was already a " - "valid Python syntax highlighted code block. No action taken." - ) - return None - - else: - # Stripping backticks from every line of the message. - log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") - content = "" - for line in msg.splitlines(keepends=True): - content += line.strip("`") - - content = content.strip() - - # Remove "Python" or "Py" from start of the message if it exists. - log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") - pycode = False - if content.lower().startswith("python"): - content = content[6:] - pycode = True - elif content.lower().startswith("py"): - content = content[2:] - pycode = True - - if pycode: - content = content.splitlines(keepends=True) - - # Check if there might be code in the first line, and preserve it. - first_line = content[0] - if " " in content[0]: - first_space = first_line.index(" ") - content[0] = first_line[first_space:] - content = "".join(content) - - # If there's no code we can just get rid of the first line. - else: - content = "".join(content[1:]) - - # Strip it again to remove any leading whitespace. This is neccessary - # if the first line of the message looked like ```python - old = content.strip() - - # Strips REPL code out of the message if there is any. - content, repl_code = self.repl_stripping(old) - if old != content: - return (content, old), repl_code - - # Try to apply indentation fixes to the code. - content = self.fix_indentation(content) - - # Check if the code contains backticks, if it does ignore the message. - if "`" in content: - log.trace("Detected ` inside the code, won't reply") - return None - else: - log.trace(f"Returning message.\n\n{content}\n\n") - return (content,), repl_code - - def fix_indentation(self, msg: str) -> str: - """Attempts to fix badly indented code.""" - def unindent(code: str, skip_spaces: int = 0) -> str: - """Unindents all code down to the number of spaces given in skip_spaces.""" - final = "" - current = code[0] - leading_spaces = 0 - - # Get numbers of spaces before code in the first line. - while current == " ": - current = code[leading_spaces + 1] - leading_spaces += 1 - leading_spaces -= skip_spaces - - # If there are any, remove that number of spaces from every line. - if leading_spaces > 0: - for line in code.splitlines(keepends=True): - line = line[leading_spaces:] - final += line - return final - else: - return code - - # Apply fix for "all lines are overindented" case. - msg = unindent(msg) - - # If the first line does not end with a colon, we can be - # certain the next line will be on the same indentation level. - # - # If it does end with a colon, we will need to indent all successive - # lines one additional level. - first_line = msg.splitlines()[0] - code = "".join(msg.splitlines(keepends=True)[1:]) - if not first_line.endswith(":"): - msg = f"{first_line}\n{unindent(code)}" - else: - msg = f"{first_line}\n{unindent(code, 4)}" - return msg - - def repl_stripping(self, msg: str) -> Tuple[str, bool]: - """ - Strip msg in order to extract Python code out of REPL output. - - Tries to strip out REPL Python code out of msg and returns the stripped msg. - - Returns True for the boolean if REPL code was found in the input msg. - """ - final = "" - for line in msg.splitlines(keepends=True): - if line.startswith(">>>") or line.startswith("..."): - final += line[4:] - log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") - if not final: - log.trace(f"Found no REPL code in \n\n{msg}\n\n") - return msg, False - else: - log.trace(f"Found REPL code in \n\n{msg}\n\n") - return final.rstrip(), True - - def has_bad_ticks(self, msg: Message) -> bool: - """Check to see if msg contains ticks that aren't '`'.""" - not_backticks = [ - "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", - "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", - "\u3003\u3003\u3003" - ] - - return msg.content[:3] in not_backticks - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Detect poorly formatted Python code in new messages. - - If poorly formatted code is detected, send the user a helpful message explaining how to do - properly formatted Python syntax highlighting codeblocks. - """ - is_help_channel = ( - getattr(msg.channel, "category", None) - and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) - ) - parse_codeblock = ( - ( - is_help_channel - or msg.channel.id in self.channel_cooldowns - or msg.channel.id in self.channel_whitelist - ) - and not msg.author.bot - and len(msg.content.splitlines()) > 3 - and not TokenRemover.find_token_in_message(msg) - ) - - if parse_codeblock: # no token in the msg - on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 - if not on_cooldown or DEBUG_MODE: - try: - if self.has_bad_ticks(msg): - ticks = msg.content[:3] - content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) - if content is None: - return - - content, repl_code = content - - if len(content) == 2: - content = content[1] - else: - content = content[0] - - space_left = 204 - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto = ( - "It looks like you are trying to paste code into this channel.\n\n" - "You seem to be using the wrong symbols to indicate where the codeblock should start. " - f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" - "**Here is an example of how it should look:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - else: - howto = "" - content = self.codeblock_stripping(msg.content, False) - if content is None: - return - - content, repl_code = content - # Attempts to parse the message into an AST node. - # Invalid Python code will raise a SyntaxError. - tree = ast.parse(content[0]) - - # Multiple lines of single words could be interpreted as expressions. - # This check is to avoid all nodes being parsed as expressions. - # (e.g. words over multiple lines) - if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: - # Shorten the code to 10 lines and/or 204 characters. - space_left = 204 - if content and repl_code: - content = content[1] - else: - content = content[0] - - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto += ( - "It looks like you're trying to paste code into this channel.\n\n" - "Discord has support for Markdown, which allows you to post code with full " - "syntax highlighting. Please use these whenever you paste code, as this " - "helps improve the legibility and makes it easier for us to help you.\n\n" - f"**To do this, use the following method:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - log.debug(f"{msg.author} posted something that needed to be put inside python code " - "blocks. Sending the user some instructions.") - else: - log.trace("The code consists only of expressions, not sending instructions") - - if howto != "": - # Increase amount of codeblock correction in stats - self.bot.stats.incr("codeblock_corrections") - howto_embed = Embed(description=howto) - bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) - self.codeblock_message_ids[msg.id] = bot_message.id - - self.bot.loop.create_task( - wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) - ) - else: - return - - if msg.channel.id not in self.channel_whitelist: - self.channel_cooldowns[msg.channel.id] = time.time() - - except SyntaxError: - log.trace( - f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " - "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " - f"The message that was posted was:\n\n{msg.content}\n\n" - ) - - @Cog.listener() - async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: - """Check to see if an edited message (previously called out) still contains poorly formatted code.""" - if ( - # Checks to see if the message was called out by the bot - payload.message_id not in self.codeblock_message_ids - # Makes sure that there is content in the message - or payload.data.get("content") is None - # Makes sure there's a channel id in the message payload - or payload.data.get("channel_id") is None - ): - return - - # Retrieve channel and message objects for use later - channel = self.bot.get_channel(int(payload.data.get("channel_id"))) - user_message = await channel.fetch_message(payload.message_id) - - # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None - has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) - - # If the message is fixed, delete the bot message and the entry from the id dictionary - if has_fixed_codeblock is None: - bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - await bot_message.delete() - del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") - - -def setup(bot: Bot) -> None: - """Load the Bot cog.""" - bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/utils/clean.py b/bot/cogs/utils/clean.py deleted file mode 100644 index c156ff02e..000000000 --- a/bot/cogs/utils/clean.py +++ /dev/null @@ -1,272 +0,0 @@ -import logging -import random -import re -from typing import Iterable, Optional - -from discord import Colour, Embed, Message, TextChannel, User -from discord.ext import commands -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import ( - Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES -) -from bot.decorators import with_role - -log = logging.getLogger(__name__) - - -class Clean(Cog): - """ - A cog that allows messages to be deleted in bulk, while applying various filters. - - You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a - specific regular expression. - - The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be - used to view the messages in the Discord dark theme style. - """ - - def __init__(self, bot: Bot): - self.bot = bot - self.cleaning = False - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def _clean_messages( - self, - amount: int, - ctx: Context, - channels: Iterable[TextChannel], - bots_only: bool = False, - user: User = None, - regex: Optional[str] = None, - until_message: Optional[Message] = None, - ) -> None: - """A helper function that does the actual message cleaning.""" - def predicate_bots_only(message: Message) -> bool: - """Return True if the message was sent by a bot.""" - return message.author.bot - - def predicate_specific_user(message: Message) -> bool: - """Return True if the message was sent by the user provided in the _clean_messages call.""" - return message.author == user - - def predicate_regex(message: Message) -> bool: - """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" - content = [message.content] - - # Add the content for all embed attributes - for embed in message.embeds: - content.append(embed.title) - content.append(embed.description) - content.append(embed.footer.text) - content.append(embed.author.name) - for field in embed.fields: - content.append(field.name) - content.append(field.value) - - # Get rid of empty attributes and turn it into a string - content = [attr for attr in content if attr] - content = "\n".join(content) - - # Now let's see if there's a regex match - if not content: - return False - else: - return bool(re.search(regex.lower(), content.lower())) - - # Is this an acceptable amount of messages to clean? - if amount > CleanMessages.message_limit: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description=f"You cannot clean more than {CleanMessages.message_limit} messages." - ) - await ctx.send(embed=embed) - return - - # Are we already performing a clean? - if self.cleaning: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description="Please wait for the currently ongoing clean operation to complete." - ) - await ctx.send(embed=embed) - return - - # Set up the correct predicate - if bots_only: - predicate = predicate_bots_only # Delete messages from bots - elif user: - predicate = predicate_specific_user # Delete messages from specific user - elif regex: - predicate = predicate_regex # Delete messages that match regex - else: - predicate = None # Delete all messages - - # Default to using the invoking context's channel - if not channels: - channels = [ctx.channel] - - # Delete the invocation first - self.mod_log.ignore(Event.message_delete, ctx.message.id) - await ctx.message.delete() - - messages = [] - message_ids = [] - self.cleaning = True - - # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. - for channel in channels: - async for message in channel.history(limit=amount): - - # If at any point the cancel command is invoked, we should stop. - if not self.cleaning: - return - - # If we are looking for specific message. - if until_message: - - # we could use ID's here however in case if the message we are looking for gets deleted, - # we won't have a way to figure that out thus checking for datetime should be more reliable - if message.created_at < until_message.created_at: - # means we have found the message until which we were supposed to be deleting. - break - - # Since we will be using `delete_messages` method of a TextChannel and we need message objects to - # use it as well as to send logs we will start appending messages here instead adding them from - # purge. - messages.append(message) - - # If the message passes predicate, let's save it. - if predicate is None or predicate(message): - message_ids.append(message.id) - - self.cleaning = False - - # Now let's delete the actual messages with purge. - self.mod_log.ignore(Event.message_delete, *message_ids) - for channel in channels: - if until_message: - for i in range(0, len(messages), 100): - # while purge automatically handles the amount of messages - # delete_messages only allows for up to 100 messages at once - # thus we need to paginate the amount to always be <= 100 - await channel.delete_messages(messages[i:i + 100]) - else: - messages += await channel.purge(limit=amount, check=predicate) - - # Reverse the list to restore chronological order - if messages: - messages = reversed(messages) - log_url = await self.mod_log.upload_log(messages, ctx.author.id) - else: - # Can't build an embed, nothing to clean! - embed = Embed( - color=Colour(Colours.soft_red), - description="No matching messages could be found." - ) - await ctx.send(embed=embed, delete_after=10) - return - - # Build the embed and send it - target_channels = ", ".join(channel.mention for channel in channels) - - message = ( - f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" - f"A log of the deleted messages can be found [here]({log_url})." - ) - - await self.mod_log.send_log_message( - icon_url=Icons.message_bulk_delete, - colour=Colour(Colours.soft_red), - title="Bulk message delete", - text=message, - channel_id=Channels.mod_log, - ) - - @group(invoke_without_command=True, name="clean", aliases=["purge"]) - @with_role(*MODERATION_ROLES) - async def clean_group(self, ctx: Context) -> None: - """Commands for cleaning messages in channels.""" - await ctx.send_help(ctx.command) - - @clean_group.command(name="user", aliases=["users"]) - @with_role(*MODERATION_ROLES) - async def clean_user( - self, - ctx: Context, - user: User, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user, channels=channels) - - @clean_group.command(name="all", aliases=["everything"]) - @with_role(*MODERATION_ROLES) - async def clean_all( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, channels=channels) - - @clean_group.command(name="bots", aliases=["bot"]) - @with_role(*MODERATION_ROLES) - async def clean_bots( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True, channels=channels) - - @clean_group.command(name="regex", aliases=["word", "expression"]) - @with_role(*MODERATION_ROLES) - async def clean_regex( - self, - ctx: Context, - regex: str, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex, channels=channels) - - @clean_group.command(name="message", aliases=["messages"]) - @with_role(*MODERATION_ROLES) - async def clean_message(self, ctx: Context, message: Message) -> None: - """Delete all messages until certain message, stop cleaning after hitting the `message`.""" - await self._clean_messages( - CleanMessages.message_limit, - ctx, - channels=[message.channel], - until_message=message - ) - - @clean_group.command(name="stop", aliases=["cancel", "abort"]) - @with_role(*MODERATION_ROLES) - async def clean_cancel(self, ctx: Context) -> None: - """If there is an ongoing cleaning process, attempt to immediately cancel it.""" - self.cleaning = False - - embed = Embed( - color=Colour.blurple(), - description="Clean interrupted." - ) - await ctx.send(embed=embed, delete_after=10) - - -def setup(bot: Bot) -> None: - """Load the Clean cog.""" - bot.add_cog(Clean(bot)) diff --git a/bot/cogs/utils/eval.py b/bot/cogs/utils/eval.py deleted file mode 100644 index eb8bfb1cf..000000000 --- a/bot/cogs/utils/eval.py +++ /dev/null @@ -1,202 +0,0 @@ -import contextlib -import inspect -import logging -import pprint -import re -import textwrap -import traceback -from io import StringIO -from typing import Any, Optional, Tuple - -import discord -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role -from bot.interpreter import Interpreter - -log = logging.getLogger(__name__) - - -class CodeEval(Cog): - """Owner and admin feature that evaluates code and returns the result to the channel.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.env = {} - self.ln = 0 - self.stdout = StringIO() - - self.interpreter = Interpreter(bot) - - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: - """Format the eval output into a string & attempt to format it into an Embed.""" - self._ = out - - res = "" - - # Erase temp input we made - if inp.startswith("_ = "): - inp = inp[4:] - - # Get all non-empty lines - lines = [line for line in inp.split("\n") if line.strip()] - if len(lines) != 1: - lines += [""] - - # Create the input dialog - for i, line in enumerate(lines): - if i == 0: - # Start dialog - start = f"In [{self.ln}]: " - - else: - # Indent the 3 dots correctly; - # Normally, it's something like - # In [X]: - # ...: - # - # But if it's - # In [XX]: - # ...: - # - # You can see it doesn't look right. - # This code simply indents the dots - # far enough to align them. - # we first `str()` the line number - # then we get the length - # and use `str.rjust()` - # to indent it. - start = "...: ".rjust(len(str(self.ln)) + 7) - - if i == len(lines) - 2: - if line.startswith("return"): - line = line[6:].strip() - - # Combine everything - res += (start + line + "\n") - - self.stdout.seek(0) - text = self.stdout.read() - self.stdout.close() - self.stdout = StringIO() - - if text: - res += (text + "\n") - - if out is None: - # No output, return the input statement - return (res, None) - - res += f"Out[{self.ln}]: " - - if isinstance(out, discord.Embed): - # We made an embed? Send that as embed - res += "" - res = (res, out) - - else: - if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): - # Leave out the traceback message - out = "\n" + "\n".join(out.split("\n")[1:]) - - if isinstance(out, str): - pretty = out - else: - pretty = pprint.pformat(out, compact=True, width=60) - - if pretty != str(out): - # We're using the pretty version, start on the next line - res += "\n" - - if pretty.count("\n") > 20: - # Text too long, shorten - li = pretty.split("\n") - - pretty = ("\n".join(li[:3]) # First 3 lines - + "\n ...\n" # Ellipsis to indicate removed lines - + "\n".join(li[-3:])) # last 3 lines - - # Add the output - res += pretty - res = (res, None) - - return res # Return (text, embed) - - async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: - """Eval the input code string & send an embed to the invoking context.""" - self.ln += 1 - - if code.startswith("exit"): - self.ln = 0 - self.env = {} - return await ctx.send("```Reset history!```") - - env = { - "message": ctx.message, - "author": ctx.message.author, - "channel": ctx.channel, - "guild": ctx.guild, - "ctx": ctx, - "self": self, - "bot": self.bot, - "inspect": inspect, - "discord": discord, - "contextlib": contextlib - } - - self.env.update(env) - - # Ignore this code, it works - code_ = """ -async def func(): # (None,) -> Any - try: - with contextlib.redirect_stdout(self.stdout): -{0} - if '_' in locals(): - if inspect.isawaitable(_): - _ = await _ - return _ - finally: - self.env.update(locals()) -""".format(textwrap.indent(code, ' ')) - - try: - exec(code_, self.env) # noqa: B102,S102 - func = self.env['func'] - res = await func() - - except Exception: - res = traceback.format_exc() - - out, embed = self._format(code, res) - await ctx.send(f"```py\n{out}```", embed=embed) - - @group(name='internal', aliases=('int',)) - @with_role(Roles.owners, Roles.admins) - async def internal_group(self, ctx: Context) -> None: - """Internal commands. Top secret!""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @internal_group.command(name='eval', aliases=('e',)) - @with_role(Roles.admins, Roles.owners) - async def eval(self, ctx: Context, *, code: str) -> None: - """Run eval in a REPL-like format.""" - code = code.strip("`") - if re.match('py(thon)?\n', code): - code = "\n".join(code.split("\n")[1:]) - - if not re.search( # Check if it's an expression - r"^(return|import|for|while|def|class|" - r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( - code.split("\n")) == 1: - code = "_ = " + code - - await self._eval(ctx, code) - - -def setup(bot: Bot) -> None: - """Load the CodeEval cog.""" - bot.add_cog(CodeEval(bot)) diff --git a/bot/cogs/utils/extensions.py b/bot/cogs/utils/extensions.py deleted file mode 100644 index 2cde07035..000000000 --- a/bot/cogs/utils/extensions.py +++ /dev/null @@ -1,289 +0,0 @@ -import functools -import importlib -import inspect -import logging -import pkgutil -import typing as t -from enum import Enum - -from discord import Colour, Embed -from discord.ext import commands -from discord.ext.commands import Context, group - -from bot import cogs -from bot.bot import Bot -from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -def walk_extensions() -> t.Iterator[str]: - """Yield extension names from the bot.cogs subpackage.""" - - def on_error(name: str) -> t.NoReturn: - raise ImportError(name=name) # pragma: no cover - - for module in pkgutil.walk_packages(cogs.__path__, f"{cogs.__name__}.", onerror=on_error): - if module.name.rsplit(".", maxsplit=1)[-1].startswith("_"): - # Ignore module/package names starting with an underscore. - continue - - if module.ispkg: - imported = importlib.import_module(module.name) - if not inspect.isfunction(getattr(imported, "setup", None)): - # If it lacks a setup function, it's not an extension. - continue - - yield module.name - - -UNLOAD_BLACKLIST = {f"{cogs.__name__}.utils.extensions", f"{cogs.__name__}.moderation.modlog"} -EXTENSIONS = frozenset(walk_extensions()) -COG_PATH_LEN = len(cogs.__name__.split(".")) - - -class Action(Enum): - """Represents an action to perform on an extension.""" - - # Need to be partial otherwise they are considered to be function definitions. - LOAD = functools.partial(Bot.load_extension) - UNLOAD = functools.partial(Bot.unload_extension) - RELOAD = functools.partial(Bot.reload_extension) - - -class Extension(commands.Converter): - """ - Fully qualify the name of an extension and ensure it exists. - - The * and ** values bypass this when used with the reload command. - """ - - async def convert(self, ctx: Context, argument: str) -> str: - """Fully qualify the name of an extension and ensure it exists.""" - # Special values to reload all extensions - if argument == "*" or argument == "**": - return argument - - argument = argument.lower() - - if argument in EXTENSIONS: - return argument - elif (qualified_arg := f"{cogs.__name__}.{argument}") in EXTENSIONS: - return qualified_arg - - matches = [] - for ext in EXTENSIONS: - name = ext.rsplit(".", maxsplit=1)[-1] - if argument == name: - matches.append(ext) - - if len(matches) > 1: - matches.sort() - names = "\n".join(matches) - raise commands.BadArgument( - f":x: `{argument}` is an ambiguous extension name. " - f"Please use one of the following fully-qualified names.```\n{names}```" - ) - elif matches: - return matches[0] - else: - raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - - -class Extensions(commands.Cog): - """Extension management commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) - async def extensions_group(self, ctx: Context) -> None: - """Load, unload, reload, and list loaded extensions.""" - await ctx.send_help(ctx.command) - - @extensions_group.command(name="load", aliases=("l",)) - async def load_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Load extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "*" in extensions or "**" in extensions: - extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) - - msg = self.batch_manage(Action.LOAD, *extensions) - await ctx.send(msg) - - @extensions_group.command(name="unload", aliases=("ul",)) - async def unload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Unload currently loaded extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) - - if blacklisted: - msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" - else: - if "*" in extensions or "**" in extensions: - extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST - - msg = self.batch_manage(Action.UNLOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="reload", aliases=("r",)) - async def reload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Reload extensions given their fully qualified or unqualified names. - - If an extension fails to be reloaded, it will be rolled-back to the prior working state. - - If '\*' is given as the name, all currently loaded extensions will be reloaded. - If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "**" in extensions: - extensions = EXTENSIONS - elif "*" in extensions: - extensions = set(self.bot.extensions.keys()) | set(extensions) - extensions.remove("*") - - msg = self.batch_manage(Action.RELOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="list", aliases=("all",)) - async def list_command(self, ctx: Context) -> None: - """ - Get a list of all extensions, including their loaded status. - - Grey indicates that the extension is unloaded. - Green indicates that the extension is currently loaded. - """ - embed = Embed(colour=Colour.blurple()) - embed.set_author( - name="Extensions List", - url=URLs.github_bot_repo, - icon_url=URLs.bot_avatar - ) - - lines = [] - categories = self.group_extension_statuses() - for category, extensions in sorted(categories.items()): - # Treat each category as a single line by concatenating everything. - # This ensures the paginator will not cut off a page in the middle of a category. - category = category.replace("_", " ").title() - extensions = "\n".join(sorted(extensions)) - lines.append(f"**{category}**\n{extensions}\n") - - log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") - await LinePaginator.paginate(lines, ctx, embed, scale_to_size=700, empty=False) - - def group_extension_statuses(self) -> t.Mapping[str, str]: - """Return a mapping of extension names and statuses to their categories.""" - categories = {} - - for ext in EXTENSIONS: - if ext in self.bot.extensions: - status = Emojis.status_online - else: - status = Emojis.status_offline - - path = ext.split(".") - if len(path) > COG_PATH_LEN + 1: - category = " - ".join(path[COG_PATH_LEN:-1]) - else: - category = "uncategorised" - - categories.setdefault(category, []).append(f"{status} {path[-1]}") - - return categories - - def batch_manage(self, action: Action, *extensions: str) -> str: - """ - Apply an action to multiple extensions and return a message with the results. - - If only one extension is given, it is deferred to `manage()`. - """ - if len(extensions) == 1: - msg, _ = self.manage(action, extensions[0]) - return msg - - verb = action.name.lower() - failures = {} - - for extension in extensions: - _, error = self.manage(action, extension) - if error: - failures[extension] = error - - emoji = ":x:" if failures else ":ok_hand:" - msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." - - if failures: - failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) - msg += f"\nFailures:```{failures}```" - - log.debug(f"Batch {verb}ed extensions.") - - return msg - - def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: - """Apply an action to an extension and return the status message and any error message.""" - verb = action.name.lower() - error_msg = None - - try: - action.value(self.bot, ext) - except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): - if action is Action.RELOAD: - # When reloading, just load the extension if it was not loaded. - return self.manage(Action.LOAD, ext) - - msg = f":x: Extension `{ext}` is already {verb}ed." - log.debug(msg[4:]) - except Exception as e: - if hasattr(e, "original"): - e = e.original - - log.exception(f"Extension '{ext}' failed to {verb}.") - - error_msg = f"{e.__class__.__name__}: {e}" - msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" - else: - msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." - log.debug(msg[10:]) - - return msg, error_msg - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators and core developers to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Handle BadArgument errors locally to prevent the help command from showing.""" - if isinstance(error, commands.BadArgument): - await ctx.send(str(error)) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the Extensions cog.""" - bot.add_cog(Extensions(bot)) diff --git a/bot/cogs/utils/jams.py b/bot/cogs/utils/jams.py deleted file mode 100644 index b3102db2f..000000000 --- a/bot/cogs/utils/jams.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import typing as t - -from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role -from discord.ext import commands -from more_itertools import unique_everseen - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -MAX_CHANNELS = 50 -CATEGORY_NAME = "Code Jam" - - -class CodeJams(commands.Cog): - """Manages the code-jam related parts of our server.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command() - @with_role(Roles.admins) - async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: - """ - Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. - - The first user passed will always be the team leader. - """ - # Ignore duplicate members - members = list(unique_everseen(members)) - - # We had a little issue during Code Jam 4 here, the greedy converter did it's job - # and ignored anything which wasn't a valid argument which left us with teams of - # two members or at some times even 1 member. This fixes that by checking that there - # are always 3 members in the members list. - if len(members) < 3: - await ctx.send( - ":no_entry_sign: One of your arguments was invalid\n" - f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" - " members" - ) - return - - team_channel = await self.create_channels(ctx.guild, team_name, members) - await self.add_roles(ctx.guild, members) - - await ctx.send( - f":ok_hand: Team created: {team_channel}\n" - f"**Team Leader:** {members[0].mention}\n" - f"**Team Members:** {' '.join(member.mention for member in members[1:])}" - ) - - async def get_category(self, guild: Guild) -> CategoryChannel: - """ - Return a code jam category. - - If all categories are full or none exist, create a new category. - """ - for category in guild.categories: - # Need 2 available spaces: one for the text channel and one for voice. - if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: - return category - - return await self.create_category(guild) - - @staticmethod - async def create_category(guild: Guild) -> CategoryChannel: - """Create a new code jam category and return it.""" - log.info("Creating a new code jam category.") - - category_overwrites = { - guild.default_role: PermissionOverwrite(read_messages=False), - guild.me: PermissionOverwrite(read_messages=True) - } - - return await guild.create_category_channel( - CATEGORY_NAME, - overwrites=category_overwrites, - reason="It's code jam time!" - ) - - @staticmethod - def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: - """Get code jam team channels permission overwrites.""" - # First member is always the team leader - team_channel_overwrites = { - members[0]: PermissionOverwrite( - manage_messages=True, - read_messages=True, - manage_webhooks=True, - connect=True - ), - guild.default_role: PermissionOverwrite(read_messages=False, connect=False), - guild.get_role(Roles.verified): PermissionOverwrite( - read_messages=False, - connect=False - ) - } - - # Rest of members should just have read_messages - for member in members[1:]: - team_channel_overwrites[member] = PermissionOverwrite( - read_messages=True, - connect=True - ) - - return team_channel_overwrites - - async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: - """Create team text and voice channels. Return the mention for the text channel.""" - # Get permission overwrites and category - team_channel_overwrites = self.get_overwrites(members, guild) - code_jam_category = await self.get_category(guild) - - # Create a text channel for the team - team_channel = await guild.create_text_channel( - team_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - # Create a voice channel for the team - team_voice_name = " ".join(team_name.split("-")).title() - - await guild.create_voice_channel( - team_voice_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - return team_channel.mention - - @staticmethod - async def add_roles(guild: Guild, members: t.List[Member]) -> None: - """Assign team leader and jammer roles.""" - # Assign team leader role - await members[0].add_roles(guild.get_role(Roles.team_leaders)) - - # Assign rest of roles - jammer_role = guild.get_role(Roles.jammers) - for member in members: - await member.add_roles(jammer_role) - - -def setup(bot: Bot) -> None: - """Load the CodeJams cog.""" - bot.add_cog(CodeJams(bot)) diff --git a/bot/cogs/utils/reminders.py b/bot/cogs/utils/reminders.py deleted file mode 100644 index 670493bcf..000000000 --- a/bot/cogs/utils/reminders.py +++ /dev/null @@ -1,427 +0,0 @@ -import asyncio -import logging -import random -import textwrap -import typing as t -from datetime import datetime, timedelta -from operator import itemgetter - -import discord -from dateutil.parser import isoparse -from dateutil.relativedelta import relativedelta -from discord.ext.commands import Cog, Context, Greedy, group - -from bot.bot import Bot -from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES -from bot.converters import Duration -from bot.pagination import LinePaginator -from bot.utils.checks import without_role_check -from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -WHITELISTED_CHANNELS = Guild.reminder_whitelist -MAXIMUM_REMINDERS = 5 - -Mentionable = t.Union[discord.Member, discord.Role] - - -class Reminders(Cog): - """Provide in-channel reminder functionality.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_reminders()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - async def reschedule_reminders(self) -> None: - """Get all current reminders from the API and reschedule them.""" - await self.bot.wait_until_guild_available() - response = await self.bot.api_client.get( - 'bot/reminders', - params={'active': 'true'} - ) - - now = datetime.utcnow() - - for reminder in response: - is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) - if not is_valid: - continue - - remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) - - # If the reminder is already overdue ... - if remind_at < now: - late = relativedelta(now, remind_at) - await self.send_reminder(reminder, late) - else: - self.schedule_reminder(reminder) - - def ensure_valid_reminder( - self, - reminder: dict, - cancel_task: bool = True - ) -> t.Tuple[bool, discord.User, discord.TextChannel]: - """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" - user = self.bot.get_user(reminder['author']) - channel = self.bot.get_channel(reminder['channel_id']) - is_valid = True - if not user or not channel: - is_valid = False - log.info( - f"Reminder {reminder['id']} invalid: " - f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." - ) - asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) - - return is_valid, user, channel - - @staticmethod - async def _send_confirmation( - ctx: Context, - on_success: str, - reminder_id: str, - delivery_dt: t.Optional[datetime], - ) -> None: - """Send an embed confirming the reminder change was made successfully.""" - embed = discord.Embed() - embed.colour = discord.Colour.green() - embed.title = random.choice(POSITIVE_REPLIES) - embed.description = on_success - - footer_str = f"ID: {reminder_id}" - if delivery_dt: - # Reminder deletion will have a `None` `delivery_dt` - footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" - - embed.set_footer(text=footer_str) - - await ctx.send(embed=embed) - - @staticmethod - async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: - """ - Returns whether or not the list of mentions is allowed. - - Conditions: - - Role reminders are Mods+ - - Reminders for other users are Helpers+ - - If mentions aren't allowed, also return the type of mention(s) disallowed. - """ - if without_role_check(ctx, *STAFF_ROLES): - return False, "members/roles" - elif without_role_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, discord.Member) for mention in mentions), "roles" - else: - return True, "" - - @staticmethod - async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: - """ - Filter mentions to see if the user can mention, and sends a denial if not allowed. - - Returns whether or not the validation is successful. - """ - mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) - - if not mentions or mentions_allowed: - return True - else: - await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") - return False - - def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: - """Converts Role and Member ids to their corresponding objects if possible.""" - guild = self.bot.get_guild(Guild.id) - for mention_id in mention_ids: - if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): - yield mentionable - - def schedule_reminder(self, reminder: dict) -> None: - """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" - reminder_id = reminder["id"] - reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) - - async def _remind() -> None: - await self.send_reminder(reminder) - - log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder_id) - - self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) - - async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: - """Delete a reminder from the database, given its ID, and cancel the running task.""" - await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - - if cancel_task: - # Now we can remove it from the schedule list - self.scheduler.cancel(reminder_id) - - async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: - """ - Edits a reminder in the database given the ID and payload. - - Returns the edited reminder. - """ - # Send the request to update the reminder in the database - reminder = await self.bot.api_client.patch( - 'bot/reminders/' + str(reminder_id), - json=payload - ) - return reminder - - async def _reschedule_reminder(self, reminder: dict) -> None: - """Reschedule a reminder object.""" - log.trace(f"Cancelling old task #{reminder['id']}") - self.scheduler.cancel(reminder["id"]) - - log.trace(f"Scheduling new task #{reminder['id']}") - self.schedule_reminder(reminder) - - async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: - """Send the reminder.""" - is_valid, user, channel = self.ensure_valid_reminder(reminder) - if not is_valid: - return - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.set_author( - icon_url=Icons.remind_blurple, - name="It has arrived!" - ) - - embed.description = f"Here's your reminder: `{reminder['content']}`." - - if reminder.get("jump_url"): # keep backward compatibility - embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" - - if late: - embed.colour = discord.Colour.red() - embed.set_author( - icon_url=Icons.remind_red, - name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" - ) - - additional_mentions = ' '.join( - mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) - ) - - await channel.send( - content=f"{user.mention} {additional_mentions}", - embed=embed - ) - await self._delete_reminder(reminder["id"]) - - @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) - async def remind_group( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """Commands for managing your reminders.""" - await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) - - @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """ - Set yourself a simple reminder. - - Expiration is parsed per: http://strftime.org/ - """ - # If the user is not staff, we need to verify whether or not to make a reminder at all. - if without_role_check(ctx, *STAFF_ROLES): - - # If they don't have permission to set a reminder in this channel - if ctx.channel.id not in WHITELISTED_CHANNELS: - await send_denial(ctx, "Sorry, you can't do that here!") - return - - # Get their current active reminders - active_reminders = await self.bot.api_client.get( - 'bot/reminders', - params={ - 'author__id': str(ctx.author.id) - } - ) - - # Let's limit this, so we don't get 10 000 - # reminders from kip or something like that :P - if len(active_reminders) > MAXIMUM_REMINDERS: - await send_denial(ctx, "You have too many active reminders!") - return - - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - - # Now we can attempt to actually set the reminder. - reminder = await self.bot.api_client.post( - 'bot/reminders', - json={ - 'author': ctx.author.id, - 'channel_id': ctx.message.channel.id, - 'jump_url': ctx.message.jump_url, - 'content': content, - 'expiration': expiration.isoformat(), - 'mentions': mention_ids, - } - ) - - now = datetime.utcnow() - timedelta(seconds=1) - humanized_delta = humanize_delta(relativedelta(expiration, now)) - mention_string = ( - f"Your reminder will arrive in {humanized_delta} " - f"and will mention {len(mentions)} other(s)!" - ) - - # Confirm to the user that it worked. - await self._send_confirmation( - ctx, - on_success=mention_string, - reminder_id=reminder["id"], - delivery_dt=expiration, - ) - - self.schedule_reminder(reminder) - - @remind_group.command(name="list") - async def list_reminders(self, ctx: Context) -> None: - """View a paginated embed of all reminders for your user.""" - # Get all the user's reminders from the database. - data = await self.bot.api_client.get( - 'bot/reminders', - params={'author__id': str(ctx.author.id)} - ) - - now = datetime.utcnow() - - # Make a list of tuples so it can be sorted by time. - reminders = sorted( - ( - (rem['content'], rem['expiration'], rem['id'], rem['mentions']) - for rem in data - ), - key=itemgetter(1) - ) - - lines = [] - - for content, remind_at, id_, mentions in reminders: - # Parse and humanize the time, make it pretty :D - remind_datetime = isoparse(remind_at).replace(tzinfo=None) - time = humanize_delta(relativedelta(remind_datetime, now)) - - mentions = ", ".join( - # Both Role and User objects have the `name` attribute - mention.name for mention in self.get_mentionables(mentions) - ) - mention_string = f"\n**Mentions:** {mentions}" if mentions else "" - - text = textwrap.dedent(f""" - **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} - {content} - """).strip() - - lines.append(text) - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.title = f"Reminders for {ctx.author}" - - # Remind the user that they have no reminders :^) - if not lines: - embed.description = "No active reminders could be found." - await ctx.send(embed=embed) - return - - # Construct the embed and paginate it. - embed.colour = discord.Colour.blurple() - - await LinePaginator.paginate( - lines, - ctx, embed, - max_lines=3, - empty=True - ) - - @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) - async def edit_reminder_group(self, ctx: Context) -> None: - """Commands for modifying your current reminders.""" - await ctx.send_help(ctx.command) - - @edit_reminder_group.command(name="duration", aliases=("time",)) - async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: - """ - Edit one of your reminder's expiration. - - Expiration is parsed per: http://strftime.org/ - """ - await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) - - @edit_reminder_group.command(name="content", aliases=("reason",)) - async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: - """Edit one of your reminder's content.""" - await self.edit_reminder(ctx, id_, {"content": content}) - - @edit_reminder_group.command(name="mentions", aliases=("pings",)) - async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: - """Edit one of your reminder's mentions.""" - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) - - async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: - """Edits a reminder with the given payload, then sends a confirmation message.""" - reminder = await self._edit_reminder(id_, payload) - - # Parse the reminder expiration back into a datetime - expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) - - # Send a confirmation message to the channel - await self._send_confirmation( - ctx, - on_success="That reminder has been edited successfully!", - reminder_id=id_, - delivery_dt=expiration, - ) - await self._reschedule_reminder(reminder) - - @remind_group.command("delete", aliases=("remove", "cancel")) - async def delete_reminder(self, ctx: Context, id_: int) -> None: - """Delete one of your active reminders.""" - await self._delete_reminder(id_) - await self._send_confirmation( - ctx, - on_success="That reminder has been deleted successfully!", - reminder_id=id_, - delivery_dt=None, - ) - - -def setup(bot: Bot) -> None: - """Load the Reminders cog.""" - bot.add_cog(Reminders(bot)) diff --git a/bot/cogs/utils/snekbox.py b/bot/cogs/utils/snekbox.py deleted file mode 100644 index 52c8b6f88..000000000 --- a/bot/cogs/utils/snekbox.py +++ /dev/null @@ -1,349 +0,0 @@ -import asyncio -import contextlib -import datetime -import logging -import re -import textwrap -from functools import partial -from signal import Signals -from typing import Optional, Tuple - -from discord import HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only - -from bot.bot import Bot -from bot.constants import Categories, Channels, Roles, URLs -from bot.decorators import in_whitelist -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") -FORMATTED_CODE_REGEX = re.compile( - r"^\s*" # any leading whitespace from the beginning of the string - r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block - r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) - r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all code inside the markup - r"\s*" # any more whitespace before the end of the code markup - r"(?P=delim)" # match the exact same delimiter from the start again - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive -) -RAW_CODE_REGEX = re.compile( - r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all the rest as code - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL # "." also matches newlines -) - -MAX_PASTE_LEN = 1000 - -# `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) - -SIGKILL = 9 - -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 - - -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_eval(self, code: str) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - - url = URLs.paste_service.format(key="documents") - try: - async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: - data = await resp.json() - - if "key" in data: - return URLs.paste_service.format(key=data["key"]) - except Exception: - # 400 (Bad Request) means there are too many characters - log.exception("Failed to upload full output to paste service!") - - @staticmethod - def prepare_input(code: str) -> str: - """Extract code from the Markdown, format it, and insert it into the code template.""" - match = FORMATTED_CODE_REGEX.fullmatch(code) - if match: - code, block, lang, delim = match.group("code", "block", "lang", "delim") - code = textwrap.dedent(code) - if block: - info = (f"'{lang}' highlighted" if lang else "plain") + " code block" - else: - info = f"{delim}-enclosed inline code" - log.trace(f"Extracted {info} for evaluation:\n{code}") - else: - code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) - log.trace( - f"Eval message contains unformatted or badly formatted code, " - f"stripping whitespace only:\n{code}" - ) - - return code - - @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: - """Return a user-friendly message and error corresponding to the process's return code.""" - stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" - error = "" - - if returncode is None: - msg = "Your eval job has failed" - error = stdout.strip() - elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" - elif returncode == 255: - msg = "Your eval job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - try: - name = Signals(returncode - 128).name - msg = f"{msg} ({name})" - except ValueError: - pass - - return msg, error - - @staticmethod - def get_status_emoji(results: dict) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - if not results["stdout"].strip(): # No output - return ":warning:" - elif results["returncode"] == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: - """ - Format the output and return a tuple of the formatted output and a URL to the full output. - - Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters - and upload the full output to a paste service. - """ - log.trace("Formatting output...") - - output = output.rstrip("\n") - original_output = output # To be uploaded to a pasting service if needed - paste_link = None - - if "<@" in output: - output = output.replace("<@", "<@\u200B") # Zero-width space - - if " 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] - output = output[:11] # Limiting to only 11 lines - output = "\n".join(output) - - if lines > 10: - truncated = True - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" - else: - output = f"{output}\n... (truncated - too many lines)" - elif len(output) >= 1000: - truncated = True - output = f"{output[:1000]}\n... (truncated - too long)" - - if truncated: - paste_link = await self.upload_output(original_output) - - output = output or "[No output]" - - return output, paste_link - - async def send_eval(self, ctx: Context, code: str) -> Message: - """ - Evaluate code, format it, and send the output to the corresponding channel. - - Return the bot response. - """ - async with ctx.typing(): - results = await self.post_eval(code) - msg, error = self.get_results_message(results) - - if error: - output, paste_link = error, None - else: - output, paste_link = await self.format_output(results["stdout"]) - - icon = self.get_status_emoji(results) - msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" - if paste_link: - msg = f"{msg}\nFull output: {paste_link}" - - # Collect stats of eval fails + successes - if icon == ":x:": - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - - filter_cog = self.bot.get_cog("Filtering") - filter_triggered = False - if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) - if filter_triggered: - response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") - else: - response = await ctx.send(msg) - self.bot.loop.create_task( - wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) - ) - - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") - return response - - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: - """ - Check if the eval session should continue. - - Return the new code to evaluate or None if the eval session should be terminated. - """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) - - with contextlib.suppress(NotFound): - try: - _, new_message = await self.bot.wait_for( - 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT - ) - await ctx.message.add_reaction(REEVAL_EMOJI) - await self.bot.wait_for( - 'reaction_add', - check=_predicate_emoji_reaction, - timeout=10 - ) - - code = await self.get_code(new_message) - await ctx.message.clear_reactions() - with contextlib.suppress(HTTPException): - await response.delete() - - except asyncio.TimeoutError: - await ctx.message.clear_reactions() - return None - - return code - - async def get_code(self, message: Message) -> Optional[str]: - """ - Return the code from `message` to be evaluated. - - If the message is an invocation of the eval command, return the first argument or None if it - doesn't exist. Otherwise, return the full content of the message. - """ - log.trace(f"Getting context for message {message.id}.") - new_ctx = await self.bot.get_context(message) - - if new_ctx.command is self.eval_command: - log.trace(f"Message {message.id} invokes eval command.") - split = message.content.split(maxsplit=1) - code = split[1] if len(split) > 1 else None - else: - log.trace(f"Message {message.id} does not invoke eval command.") - code = message.content - - return code - - @command(name="eval", aliases=("e",)) - @guild_only() - @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: - """ - Run Python code and get the results. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ - if ctx.author.id in self.jobs: - await ctx.send( - f"{ctx.author.mention} You've already got a job running - " - "please wait for it to finish!" - ) - return - - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - - if Roles.helpers in (role.id for role in ctx.author.roles): - self.bot.stats.incr("snekbox_usages.roles.helpers") - else: - self.bot.stats.incr("snekbox_usages.roles.developers") - - if ctx.channel.category_id == Categories.help_in_use: - self.bot.stats.incr("snekbox_usages.channels.help") - elif ctx.channel.id == Channels.bot_commands: - self.bot.stats.incr("snekbox_usages.channels.bot_commands") - else: - self.bot.stats.incr("snekbox_usages.channels.topical") - - log.info(f"Received code from {ctx.author} for evaluation:\n{code}") - - while True: - self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) - try: - response = await self.send_eval(ctx, code) - finally: - del self.jobs[ctx.author.id] - - code = await self.continue_eval(ctx, response) - if not code: - break - log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") - - -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: - """Return True if the edited message is the context message and the content was indeed modified.""" - return new_msg.id == ctx.message.id and old_msg.content != new_msg.content - - -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI - - -def setup(bot: Bot) -> None: - """Load the Snekbox cog.""" - bot.add_cog(Snekbox(bot)) diff --git a/bot/cogs/utils/utils.py b/bot/cogs/utils/utils.py deleted file mode 100644 index d96abbd5a..000000000 --- a/bot/cogs/utils/utils.py +++ /dev/null @@ -1,265 +0,0 @@ -import difflib -import logging -import re -import unicodedata -from email.parser import HeaderParser -from io import StringIO -from typing import Tuple, Union - -from discord import Colour, Embed, utils -from discord.ext.commands import BadArgument, Cog, Context, clean_content, command - -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils import messages - -log = logging.getLogger(__name__) - -ZEN_OF_PYTHON = """\ -Beautiful is better than ugly. -Explicit is better than implicit. -Simple is better than complex. -Complex is better than complicated. -Flat is better than nested. -Sparse is better than dense. -Readability counts. -Special cases aren't special enough to break the rules. -Although practicality beats purity. -Errors should never pass silently. -Unless explicitly silenced. -In the face of ambiguity, refuse the temptation to guess. -There should be one-- and preferably only one --obvious way to do it. -Although that way may not be obvious at first unless you're Dutch. -Now is better than never. -Although never is often better than *right* now. -If the implementation is hard to explain, it's a bad idea. -If the implementation is easy to explain, it may be a good idea. -Namespaces are one honking great idea -- let's do more of those! -""" - -ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - - -class Utils(Cog): - """A selection of utilities which don't have a clear category.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.base_pep_url = "http://www.python.org/dev/peps/pep-" - self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" - - @command(name='pep', aliases=('get_pep', 'p')) - async def pep_command(self, ctx: Context, pep_number: str) -> None: - """Fetches information about a PEP and sends it to the channel.""" - if pep_number.isdigit(): - pep_number = int(pep_number) - else: - await ctx.send_help(ctx.command) - return - - # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. - if pep_number == 0: - return await self.send_pep_zero(ctx) - - possible_extensions = ['.txt', '.rst'] - found_pep = False - for extension in possible_extensions: - # Attempt to fetch the PEP - pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" - log.trace(f"Requesting PEP {pep_number} with {pep_url}") - response = await self.bot.http_session.get(pep_url) - - if response.status == 200: - log.trace("PEP found") - found_pep = True - - pep_content = await response.text() - - # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 - pep_header = HeaderParser().parse(StringIO(pep_content)) - - # Assemble the embed - pep_embed = Embed( - title=f"**PEP {pep_number} - {pep_header['Title']}**", - description=f"[Link]({self.base_pep_url}{pep_number:04})", - ) - - pep_embed.set_thumbnail(url=ICON_URL) - - # Add the interesting information - fields_to_check = ("Status", "Python-Version", "Created", "Type") - for field in fields_to_check: - # Check for a PEP metadata field that is present but has an empty value - # embed field values can't contain an empty string - if pep_header.get(field, ""): - pep_embed.add_field(name=field, value=pep_header[field]) - - elif response.status != 404: - # any response except 200 and 404 is expected - found_pep = True # actually not, but it's easier to display this way - log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " - f"{response.status}.\n{response.text}") - - error_message = "Unexpected HTTP error during PEP search. Please let us know." - pep_embed = Embed(title="Unexpected error", description=error_message) - pep_embed.colour = Colour.red() - break - - if not found_pep: - log.trace("PEP was not found") - not_found = f"PEP {pep_number} does not exist." - pep_embed = Embed(title="PEP not found", description=not_found) - pep_embed.colour = Colour.red() - - await ctx.message.channel.send(embed=pep_embed) - - @command() - @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) - async def charinfo(self, ctx: Context, *, characters: str) -> None: - """Shows you information on up to 50 unicode characters.""" - match = re.match(r"<(a?):(\w+):(\d+)>", characters) - if match: - return await messages.send_denial( - ctx, - "**Non-Character Detected**\n" - "Only unicode characters can be processed, but a custom Discord emoji " - "was found. Please remove it and try again." - ) - - if len(characters) > 50: - return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") - - def get_info(char: str) -> Tuple[str, str]: - digit = f"{ord(char):x}" - if len(digit) <= 4: - u_code = f"\\u{digit:>04}" - else: - u_code = f"\\U{digit:>08}" - url = f"https://www.compart.com/en/unicode/U+{digit:>04}" - name = f"[{unicodedata.name(char, '')}]({url})" - info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" - return info, u_code - - char_list, raw_list = zip(*(get_info(c) for c in characters)) - embed = Embed().set_author(name="Character Info") - - if len(characters) > 1: - # Maximum length possible is 502 out of 1024, so there's no need to truncate. - embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) - - await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) - - @command() - async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: - """ - Show the Zen of Python. - - Without any arguments, the full Zen will be produced. - If an integer is provided, the line with that index will be produced. - If a string is provided, the line which matches best will be produced. - """ - embed = Embed( - colour=Colour.blurple(), - title="The Zen of Python", - description=ZEN_OF_PYTHON - ) - - if search_value is None: - embed.title += ", by Tim Peters" - await ctx.send(embed=embed) - return - - zen_lines = ZEN_OF_PYTHON.splitlines() - - # handle if it's an index int - if isinstance(search_value, int): - upper_bound = len(zen_lines) - 1 - lower_bound = -1 * upper_bound - if not (lower_bound <= search_value <= upper_bound): - raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") - - embed.title += f" (line {search_value % len(zen_lines)}):" - embed.description = zen_lines[search_value] - await ctx.send(embed=embed) - return - - # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead - # exact word. - for i, line in enumerate(zen_lines): - for word in line.split(): - if word.lower() == search_value.lower(): - embed.title += f" (line {i}):" - embed.description = line - await ctx.send(embed=embed) - return - - # handle if it's a search string and not exact word - matcher = difflib.SequenceMatcher(None, search_value.lower()) - - best_match = "" - match_index = 0 - best_ratio = 0 - - for index, line in enumerate(zen_lines): - matcher.set_seq2(line.lower()) - - # the match ratio needs to be adjusted because, naturally, - # longer lines will have worse ratios than shorter lines when - # fuzzy searching for keywords. this seems to work okay. - adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() - - if adjusted_ratio > best_ratio: - best_ratio = adjusted_ratio - best_match = line - match_index = index - - if not best_match: - raise BadArgument("I didn't get a match! Please try again with a different search term.") - - embed.title += f" (line {match_index}):" - embed.description = best_match - await ctx.send(embed=embed) - - @command(aliases=("poll",)) - @with_role(*MODERATION_ROLES) - async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: - """ - Build a quick voting poll with matching reactions with the provided options. - - A maximum of 20 options can be provided, as Discord supports a max of 20 - reactions on a single message. - """ - if len(title) > 256: - raise BadArgument("The title cannot be longer than 256 characters.") - if len(options) < 2: - raise BadArgument("Please provide at least 2 options.") - if len(options) > 20: - raise BadArgument("I can only handle 20 options!") - - codepoint_start = 127462 # represents "regional_indicator_a" unicode value - options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} - embed = Embed(title=title, description="\n".join(options.values())) - message = await ctx.send(embed=embed) - for reaction in options: - await message.add_reaction(reaction) - - async def send_pep_zero(self, ctx: Context) -> None: - """Send information about PEP 0.""" - pep_embed = Embed( - title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", - description="[Link](https://www.python.org/dev/peps/)" - ) - pep_embed.set_thumbnail(url=ICON_URL) - pep_embed.add_field(name="Status", value="Active") - pep_embed.add_field(name="Created", value="13-Jul-2000") - pep_embed.add_field(name="Type", value="Informational") - - await ctx.send(embed=pep_embed) - - -def setup(bot: Bot) -> None: - """Load the Utils cog.""" - bot.add_cog(Utils(bot)) diff --git a/bot/exts/__init__.py b/bot/exts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/alias.py b/bot/exts/alias.py new file mode 100644 index 000000000..77867b933 --- /dev/null +++ b/bot/exts/alias.py @@ -0,0 +1,153 @@ +import inspect +import logging + +from discord import Colour, Embed +from discord.ext.commands import ( + Cog, Command, Context, Greedy, + clean_content, command, group, +) + +from bot.bot import Bot +from bot.converters import FetchedMember, TagNameConverter +from bot.exts.utils.extensions import Extension +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +class Alias (Cog): + """Aliases for commonly used commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: + """Invokes a command with args and kwargs.""" + log.debug(f"{cmd_name} was invoked through an alias") + cmd = self.bot.get_command(cmd_name) + if not cmd: + return log.info(f'Did not find command "{cmd_name}" to invoke.') + elif not await cmd.can_run(ctx): + return log.info( + f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' + ) + + await ctx.invoke(cmd, *args, **kwargs) + + @command(name='aliases') + async def aliases_command(self, ctx: Context) -> None: + """Show configured aliases on the bot.""" + embed = Embed( + title='Configured aliases', + colour=Colour.blue() + ) + await LinePaginator.paginate( + ( + f"• `{ctx.prefix}{value.name}` " + f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" + for name, value in inspect.getmembers(self) + if isinstance(value, Command) and name.endswith('_alias') + ), + ctx, embed, empty=False, max_lines=20 + ) + + @command(name="resources", aliases=("resource",), hidden=True) + async def site_resources_alias(self, ctx: Context) -> None: + """Alias for invoking site resources.""" + await self.invoke(ctx, "site resources") + + @command(name="tools", hidden=True) + async def site_tools_alias(self, ctx: Context) -> None: + """Alias for invoking site tools.""" + await self.invoke(ctx, "site tools") + + @command(name="watch", hidden=True) + async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother watch [user] [reason].""" + await self.invoke(ctx, "bigbrother watch", user, reason=reason) + + @command(name="unwatch", hidden=True) + async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother unwatch [user] [reason].""" + await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) + + @command(name="home", hidden=True) + async def site_home_alias(self, ctx: Context) -> None: + """Alias for invoking site home.""" + await self.invoke(ctx, "site home") + + @command(name="faq", hidden=True) + async def site_faq_alias(self, ctx: Context) -> None: + """Alias for invoking site faq.""" + await self.invoke(ctx, "site faq") + + @command(name="rules", aliases=("rule",), hidden=True) + async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: + """Alias for invoking site rules.""" + await self.invoke(ctx, "site rules", *rules) + + @command(name="reload", hidden=True) + async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: + """Alias for invoking extensions reload [extensions...].""" + await self.invoke(ctx, "extensions reload", *extensions) + + @command(name="defon", hidden=True) + async def defcon_enable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon enable.""" + await self.invoke(ctx, "defcon enable") + + @command(name="defoff", hidden=True) + async def defcon_disable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon disable.""" + await self.invoke(ctx, "defcon disable") + + @command(name="exception", hidden=True) + async def tags_get_traceback_alias(self, ctx: Context) -> None: + """Alias for invoking tags get traceback.""" + await self.invoke(ctx, "tags get", tag_name="traceback") + + @group(name="get", + aliases=("show", "g"), + hidden=True, + invoke_without_command=True) + async def get_group_alias(self, ctx: Context) -> None: + """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" + pass + + @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) + async def tags_get_alias( + self, ctx: Context, *, tag_name: TagNameConverter = None + ) -> None: + """ + Alias for invoking tags get [tag_name]. + + tag_name: str - tag to be viewed. + """ + await self.invoke(ctx, "tags get", tag_name=tag_name) + + @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) + async def docs_get_alias( + self, ctx: Context, symbol: clean_content = None + ) -> None: + """Alias for invoking docs get [symbol].""" + await self.invoke(ctx, "docs get", symbol) + + @command(name="nominate", hidden=True) + async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking talentpool add [user] [reason].""" + await self.invoke(ctx, "talentpool add", user, reason=reason) + + @command(name="unnominate", hidden=True) + async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking nomination end [user] [reason].""" + await self.invoke(ctx, "nomination end", user, reason=reason) + + @command(name="nominees", hidden=True) + async def nominees_alias(self, ctx: Context) -> None: + """Alias for invoking tp watched.""" + await self.invoke(ctx, "talentpool watched") + + +def setup(bot: Bot) -> None: + """Load the Alias cog.""" + bot.add_cog(Alias(bot)) diff --git a/bot/exts/backend/__init__.py b/bot/exts/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/backend/config_verifier.py b/bot/exts/backend/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/exts/backend/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py new file mode 100644 index 000000000..f9d4de638 --- /dev/null +++ b/bot/exts/backend/error_handler.py @@ -0,0 +1,287 @@ +import contextlib +import logging +import typing as t + +from discord import Embed +from discord.ext.commands import Cog, Context, errors +from sentry_sdk import push_scope + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Colours +from bot.converters import TagNameConverter +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + + +class ErrorHandler(Cog): + """Handles errors emitted from commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_error_embed(self, title: str, body: str) -> Embed: + """Return an embed that contains the exception.""" + return Embed( + title=title, + colour=Colours.soft_red, + description=body + ) + + @Cog.listener() + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: + """ + Provide generic command error handling. + + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command: + * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. + Otherwise if it matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` + """ + command = ctx.command + + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if await self.try_silence(ctx): + return + if ctx.channel.id != Channels.verification: + # Try to look for a tag with the command's name + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + @staticmethod + def get_help_command(ctx: Context) -> t.Coroutine: + """Return a prepared `help` command invocation coroutine.""" + if ctx.command: + return ctx.send_help(ctx.command) + + return ctx.send_help() + + async def try_silence(self, ctx: Context) -> bool: + """ + Attempt to invoke the silence or unsilence command if invoke with matches a pattern. + + Respecting the checks if: + * invoked with `shh+` silence channel for amount of h's*2 with max of 15. + * invoked with `unshh+` unsilence channel + Return bool depending on success of command. + """ + command = ctx.invoked_with.lower() + silence_command = self.bot.get_command("silence") + ctx.invoked_from_error_handler = True + try: + if not await silence_command.can_run(ctx): + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + except errors.CommandError: + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + if command.startswith("shh"): + await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) + return True + elif command.startswith("unshh"): + await ctx.invoke(self.bot.get_command("unsilence")) + return True + return False + + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + + try: + tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) + except errors.BadArgument: + log.debug( + f"{ctx.author} tried to use an invalid command " + f"and the fallback tag failed validation in TagNameConverter." + ) + else: + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=tag_name) + # Return to not raise the exception + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + prepared_help_command = self.get_help_command(ctx) + + if isinstance(e, errors.MissingRequiredArgument): + embed = self._get_error_embed("Missing required argument", e.param.name) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.missing_required_argument") + elif isinstance(e, errors.TooManyArguments): + embed = self._get_error_embed("Too many arguments", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.too_many_arguments") + elif isinstance(e, errors.BadArgument): + embed = self._get_error_embed("Bad argument", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.bad_argument") + elif isinstance(e, errors.BadUnionArgument): + embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") + await ctx.send(embed=embed) + self.bot.stats.incr("errors.bad_union_argument") + elif isinstance(e, errors.ArgumentParsingError): + embed = self._get_error_embed("Argument parsing error", str(e)) + await ctx.send(embed=embed) + self.bot.stats.incr("errors.argument_parsing_error") + else: + embed = self._get_error_embed( + "Input error", + "Something about your input seems off. Check the arguments and try again." + ) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.other_user_input_error") + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InWhitelistCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + ctx.bot.stats.incr("errors.bot_permission_error") + await ctx.send( + "Sorry, it looks like I don't have the permissions or roles I need to do that." + ) + elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): + ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") + await ctx.send(e) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + ctx.bot.stats.incr("errors.api_error_404") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + ctx.bot.stats.incr("errors.api_error_400") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") + ctx.bot.stats.incr("errors.api_internal_server_error") + else: + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") + ctx.bot.stats.incr(f"errors.api_error_{e.status}") + + @staticmethod + async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n" + f"```{e.__class__.__name__}: {e}```" + ) + + ctx.bot.stats.incr("errors.unexpected") + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) + + +def setup(bot: Bot) -> None: + """Load the ErrorHandler cog.""" + bot.add_cog(ErrorHandler(bot)) diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py new file mode 100644 index 000000000..94fa2b139 --- /dev/null +++ b/bot/exts/backend/logging.py @@ -0,0 +1,42 @@ +import logging + +from discord import Embed +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, DEBUG_MODE + + +log = logging.getLogger(__name__) + + +class Logging(Cog): + """Debug logging module.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.bot.loop.create_task(self.startup_greeting()) + + async def startup_greeting(self) -> None: + """Announce our presence to the configured devlog channel.""" + await self.bot.wait_until_guild_available() + log.info("Bot connected!") + + embed = Embed(description="Connected!") + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=( + "https://raw.githubusercontent.com/" + "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" + ) + ) + + if not DEBUG_MODE: + await self.bot.get_channel(Channels.dev_log).send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the Logging cog.""" + bot.add_cog(Logging(bot)) diff --git a/bot/exts/backend/sync/__init__.py b/bot/exts/backend/sync/__init__.py new file mode 100644 index 000000000..2541beaa8 --- /dev/null +++ b/bot/exts/backend/sync/__init__.py @@ -0,0 +1,7 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: + """Load the Sync cog.""" + from ._cog import Sync + bot.add_cog(Sync(bot)) diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py new file mode 100644 index 000000000..b6068f328 --- /dev/null +++ b/bot/exts/backend/sync/_cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from . import _syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = _syncers.RoleSyncer(self.bot) + self.user_syncer = _syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/exts/backend/sync/_syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/exts/dm_relay.py b/bot/exts/dm_relay.py new file mode 100644 index 000000000..0d8f340b4 --- /dev/null +++ b/bot/exts/dm_relay.py @@ -0,0 +1,124 @@ +import logging +from typing import Optional + +import discord +from discord import Color +from discord.ext import commands +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.converters import UserMentionOrID +from bot.utils import RedisCache +from bot.utils.checks import in_whitelist_check, with_role_check +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DMRelay(Cog): + """Relay direct messages to and from the bot.""" + + # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] + dm_cache = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.dm_log + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + @commands.command(aliases=("reply",)) + async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: + """ + Allows you to send a DM to a user from the bot. + + If `member` is not provided, it will send to the last user who DM'd the bot. + + This feature should be used extremely sparingly. Use ModMail if you need to have a serious + conversation with a user. This is just for responding to extraordinary DMs, having a little + fun with users, and telling people they are DMing the wrong bot. + + NOTE: This feature will be removed if it is overused. + """ + if not member: + user_id = await self.dm_cache.get("last_user") + member = ctx.guild.get_member(user_id) if user_id else None + + # If we still don't have a Member at this point, give up + if not member: + log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") + await ctx.message.add_reaction("❌") + return + + try: + await member.send(message) + except discord.errors.Forbidden: + log.debug("User has disabled DMs.") + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("✅") + self.bot.stats.incr("dm_relay.dm_sent") + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Relays the message's content and attachments to the dm_log channel.""" + # Only relay DMs from humans + if message.author.bot or message.guild or self.webhook is None: + return + + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + await self.dm_cache.set("last_user", message.author.id) + self.bot.stats.incr("dm_relay.dm_received") + + # Handle any attachments + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (discord.errors.Forbidden, discord.errors.NotFound): + e = discord.Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + def cog_check(self, ctx: commands.Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=[constants.Channels.dm_log], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + +def setup(bot: Bot) -> None: + """Load the DMRelay cog.""" + bot.add_cog(DMRelay(bot)) diff --git a/bot/exts/duck_pond.py b/bot/exts/duck_pond.py new file mode 100644 index 000000000..7021069fa --- /dev/null +++ b/bot/exts/duck_pond.py @@ -0,0 +1,166 @@ +import logging +from typing import Union + +import discord +from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DuckPond(Cog): + """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.duck_pond + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @staticmethod + def is_staff(member: Union[User, Member]) -> bool: + """Check if a specific member or user is staff.""" + if hasattr(member, "roles"): + for role in member.roles: + if role.id in constants.STAFF_ROLES: + return True + return False + + async def has_green_checkmark(self, message: Message) -> bool: + """Check if the message has a green checkmark reaction.""" + for reaction in message.reactions: + if reaction.emoji == "✅": + async for user in reaction.users(): + if user == self.bot.user: + return True + return False + + async def count_ducks(self, message: Message) -> int: + """ + Count the number of ducks in the reactions of a specific message. + + Only counts ducks added by staff members. + """ + duck_count = 0 + duck_reactors = [] + + for reaction in message.reactions: + async for user in reaction.users(): + + # Is the user a staff member and not already counted as reactor? + if not self.is_staff(user) or user.id in duck_reactors: + continue + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + return duck_count + + async def relay_message(self, message: Message) -> None: + """Relays the message's content and attachments to the duck pond channel.""" + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + await message.add_reaction("✅") + + @staticmethod + def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" + if payload.emoji.is_custom_emoji(): + if payload.emoji.id in constants.DuckPond.custom_emojis: + return True + elif payload.emoji.name == "🦆": + return True + + return False + + @Cog.listener() + async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: + """ + Determine if a message should be sent to the duck pond. + + This will count the number of duck reactions on the message, and if this amount meets the + amount of ducks specified in the config under duck_pond/threshold, it will + send the message off to the duck pond. + """ + # Is the emoji in the reaction a duck? + if not self._payload_has_duckpond_emoji(payload): + return + + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) + + # Is the member a human and a staff member? + if not self.is_staff(member) or member.bot: + return + + # Does the message already have a green checkmark? + if await self.has_green_checkmark(message): + return + + # Time to count our ducks! + duck_count = await self.count_ducks(message) + + # If we've got more than the required amount of ducks, send the message to the duck_pond. + if duck_count >= constants.DuckPond.threshold: + await self.relay_message(message) + + @Cog.listener() + async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: + """Ensure that people don't remove the green checkmark from duck ponded messages.""" + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + + # Prevent the green checkmark from being removed + if payload.emoji.name == "✅": + message = await channel.fetch_message(payload.message_id) + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.threshold: + await message.add_reaction("✅") + + +def setup(bot: Bot) -> None: + """Load the DuckPond cog.""" + bot.add_cog(DuckPond(bot)) diff --git a/bot/exts/filters/__init__.py b/bot/exts/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py new file mode 100644 index 000000000..c76bd2c60 --- /dev/null +++ b/bot/exts/filters/antimalware.py @@ -0,0 +1,98 @@ +import logging +import typing as t +from os.path import splitext + +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, STAFF_ROLES, URLs + +log = logging.getLogger(__name__) + +PY_EMBED_DESCRIPTION = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + "{cmd_channel_mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( + "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + "We currently allow the following file types: **{joined_whitelist}**.\n\n" + "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + + +class AntiMalware(Cog): + """Delete messages which contain attachments with non-whitelisted file extensions.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_whitelisted_file_formats(self) -> list: + """Get the file formats currently on the whitelist.""" + return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() + + def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) + return extensions_blocked + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Identify messages with prohibited attachments.""" + # Return when message don't have attachment and don't moderate DMs + if not message.attachments or not message.guild: + return + + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): + return + + embed = Embed() + extensions_blocked = self._get_disallowed_extensions(message) + blocked_extensions_str = ', '.join(extensions_blocked) + if ".py" in extensions_blocked: + # Short-circuit on *.py files to provide a pastebin link + embed.description = PY_EMBED_DESCRIPTION + elif ".txt" in extensions_blocked: + # Work around Discord AutoConversion of messages longer than 2000 chars to .txt + cmd_channel = self.bot.get_channel(Channels.bot_commands) + embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) + elif extensions_blocked: + meta_channel = self.bot.get_channel(Channels.meta) + embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=', '.join(self._get_whitelisted_file_formats()), + blocked_extensions_str=blocked_extensions_str, + meta_channel_mention=meta_channel.mention, + ) + + if embed.description: + log.info( + f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", + extra={"attachment_list": [attachment.filename for attachment in message.attachments]} + ) + + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) + + # Delete the offending message: + try: + await message.delete() + except NotFound: + log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + + +def setup(bot: Bot) -> None: + """Load the AntiMalware cog.""" + bot.add_cog(AntiMalware(bot)) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py new file mode 100644 index 000000000..3c5f13ebf --- /dev/null +++ b/bot/exts/filters/antispam.py @@ -0,0 +1,288 @@ +import asyncio +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from operator import itemgetter +from typing import Dict, Iterable, List, Set + +from discord import Colour, Member, Message, NotFound, Object, TextChannel +from discord.ext.commands import Cog + +from bot import rules +from bot.bot import Bot +from bot.constants import ( + AntiSpam as AntiSpamConfig, Channels, + Colours, DEBUG_MODE, Event, Filter, + Guild as GuildConfig, Icons, + STAFF_ROLES, +) +from bot.converters import Duration +from bot.exts.moderation.modlog import ModLog +from bot.utils.messages import send_attachments + + +log = logging.getLogger(__name__) + +RULE_FUNCTION_MAPPING = { + 'attachments': rules.apply_attachments, + 'burst': rules.apply_burst, + 'burst_shared': rules.apply_burst_shared, + 'chars': rules.apply_chars, + 'discord_emojis': rules.apply_discord_emojis, + 'duplicates': rules.apply_duplicates, + 'links': rules.apply_links, + 'mentions': rules.apply_mentions, + 'newlines': rules.apply_newlines, + 'role_mentions': rules.apply_role_mentions +} + + +@dataclass +class DeletionContext: + """Represents a Deletion Context for a single spam event.""" + + channel: TextChannel + members: Dict[int, Member] = field(default_factory=dict) + rules: Set[str] = field(default_factory=set) + messages: Dict[int, Message] = field(default_factory=dict) + attachments: List[List[str]] = field(default_factory=list) + + async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: + """Adds new rule violation events to the deletion context.""" + self.rules.add(rule_name) + + for member in members: + if member.id not in self.members: + self.members[member.id] = member + + for message in messages: + if message.id not in self.messages: + self.messages[message.id] = message + + # Re-upload attachments + destination = message.guild.get_channel(Channels.attachment_log) + urls = await send_attachments(message, destination, link_large=False) + self.attachments.append(urls) + + async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: + """Method that takes care of uploading the queue and posting modlog alert.""" + triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) + + mod_alert_message = ( + f"**Triggered by:** {triggered_by_users}\n" + f"**Channel:** {self.channel.mention}\n" + f"**Rules:** {', '.join(rule for rule in self.rules)}\n" + ) + + # For multiple messages or those with excessive newlines, use the logs API + if len(self.messages) > 1 or 'newlines' in self.rules: + url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) + mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" + else: + mod_alert_message += "Message:\n" + [message] = self.messages.values() + content = message.clean_content + remaining_chars = 2040 - len(mod_alert_message) + + if len(content) > remaining_chars: + content = content[:remaining_chars] + "..." + + mod_alert_message += f"{content}" + + *_, last_message = self.messages.values() + await modlog.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title="Spam detected!", + text=mod_alert_message, + thumbnail=last_message.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=AntiSpamConfig.ping_everyone + ) + + +class AntiSpam(Cog): + """Cog that controls our anti-spam measures.""" + + def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: + self.bot = bot + self.validation_errors = validation_errors + role_id = AntiSpamConfig.punishment['role_id'] + self.muted_role = Object(role_id) + self.expiration_date_converter = Duration() + + self.message_deletion_queue = dict() + + self.bot.loop.create_task(self.alert_on_validation_error()) + + @property + def mod_log(self) -> ModLog: + """Allows for easy access of the ModLog cog.""" + return self.bot.get_cog("ModLog") + + async def alert_on_validation_error(self) -> None: + """Unloads the cog and alerts admins if configuration validation failed.""" + await self.bot.wait_until_guild_available() + if self.validation_errors: + body = "**The following errors were encountered:**\n" + body += "\n".join(f"- {error}" for error in self.validation_errors.values()) + body += "\n\n**The cog has been unloaded.**" + + await self.mod_log.send_log_message( + title="Error: AntiSpam configuration validation failed!", + text=body, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Colour.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Applies the antispam rules to each received message.""" + if ( + not message.guild + or message.guild.id != GuildConfig.id + or message.author.bot + or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) + or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) + ): + return + + # Fetch the rule configuration with the highest rule interval. + max_interval_config = max( + AntiSpamConfig.rules.values(), + key=itemgetter('interval') + ) + max_interval = max_interval_config['interval'] + + # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. + earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) + relevant_messages = [ + msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) + if not msg.author.bot + ] + + for rule_name in AntiSpamConfig.rules: + rule_config = AntiSpamConfig.rules[rule_name] + rule_function = RULE_FUNCTION_MAPPING[rule_name] + + # Create a list of messages that were sent in the interval that the rule cares about. + latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) + messages_for_rule = [ + msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp + ] + result = await rule_function(message, messages_for_rule, rule_config) + + # If the rule returns `None`, that means the message didn't violate it. + # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` + # which contains the reason for why the message violated the rule and + # an iterable of all members that violated the rule. + if result is not None: + self.bot.stats.incr(f"mod_alerts.{rule_name}") + reason, members, relevant_messages = result + full_reason = f"`{rule_name}` rule: {reason}" + + # If there's no spam event going on for this channel, start a new Message Deletion Context + channel = message.channel + if channel.id not in self.message_deletion_queue: + log.trace(f"Creating queue for channel `{channel.id}`") + self.message_deletion_queue[message.channel.id] = DeletionContext(channel) + self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) + + # Add the relevant of this trigger to the Deletion Context + await self.message_deletion_queue[message.channel.id].add( + rule_name=rule_name, + members=members, + messages=relevant_messages + ) + + for member in members: + + # Fire it off as a background task to ensure + # that the sleep doesn't block further tasks + self.bot.loop.create_task( + self.punish(message, member, full_reason) + ) + + await self.maybe_delete_messages(channel, relevant_messages) + break + + async def punish(self, msg: Message, member: Member, reason: str) -> None: + """Punishes the given member for triggering an antispam rule.""" + if not any(role.id == self.muted_role.id for role in member.roles): + remove_role_after = AntiSpamConfig.punishment['remove_after'] + + # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes + context = await self.bot.get_context(msg) + context.author = self.bot.user + context.message.author = self.bot.user + + # Since we're going to invoke the tempmute command directly, we need to manually call the converter. + dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") + await context.invoke( + self.bot.get_command('tempmute'), + member, + dt_remove_role_after, + reason=reason + ) + + async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: + """Cleans the messages if cleaning is configured.""" + if AntiSpamConfig.clean_offending: + # If we have more than one message, we can use bulk delete. + if len(messages) > 1: + message_ids = [message.id for message in messages] + self.mod_log.ignore(Event.message_delete, *message_ids) + await channel.delete_messages(messages) + + # Otherwise, the bulk delete endpoint will throw up. + # Delete the message directly instead. + else: + self.mod_log.ignore(Event.message_delete, messages[0].id) + try: + await messages[0].delete() + except NotFound: + log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") + + async def _process_deletion_context(self, context_id: int) -> None: + """Processes the Deletion Context queue.""" + log.trace("Sleeping before processing message deletion queue.") + await asyncio.sleep(10) + + if context_id not in self.message_deletion_queue: + log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") + return + + deletion_context = self.message_deletion_queue.pop(context_id) + await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + + +def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: + """Validates the antispam configs.""" + validation_errors = {} + for name, config in rules_.items(): + if name not in RULE_FUNCTION_MAPPING: + log.error( + f"Unrecognized antispam rule `{name}`. " + f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" + ) + validation_errors[name] = f"`{name}` is not recognized as an antispam rule." + continue + for required_key in ('interval', 'max'): + if required_key not in config: + log.error( + f"`{required_key}` is required but was not " + f"set in rule `{name}`'s configuration." + ) + validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" + return validation_errors + + +def setup(bot: Bot) -> None: + """Validate the AntiSpam configs and load the AntiSpam cog.""" + validation_errors = validate_config() + bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py new file mode 100644 index 000000000..c15adc461 --- /dev/null +++ b/bot/exts/filters/filter_lists.py @@ -0,0 +1,273 @@ +import logging +from typing import Optional + +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.converters import ValidDiscordServerInvite, ValidFilterListType +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +class FilterLists(Cog): + """Commands for blacklisting and whitelisting things.""" + + methods_with_filterlist_types = [ + "allow_add", + "allow_delete", + "allow_get", + "deny_add", + "deny_delete", + "deny_get", + ] + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.bot.loop.create_task(self._amend_docstrings()) + + async def _amend_docstrings(self) -> None: + """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" + await self.bot.wait_until_guild_available() + + # Add valid filterlist types to the docstrings + valid_types = await ValidFilterListType.get_valid_types(self.bot) + valid_types = [f"`{type_.lower()}`" for type_ in valid_types] + + for method_name in self.methods_with_filterlist_types: + command = getattr(self, method_name) + command.help = ( + f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." + ) + + async def _add_data( + self, + ctx: Context, + allowed: bool, + list_type: ValidFilterListType, + content: str, + comment: Optional[str] = None, + ) -> None: + """Add an item to a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we gotta validate it. + if list_type == "GUILD_INVITE": + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # Unless the user has specified another comment, let's + # use the server name as the comment so that the list + # of guild IDs will be more easily readable when we + # display it. + if not comment: + comment = guild_data.get("name") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Try to add the item to the database + log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") + payload = { + "allowed": allowed, + "type": list_type, + "content": content, + "comment": comment, + } + + try: + item = await self.bot.api_client.post( + "bot/filter-lists", + json=payload + ) + except ResponseCodeError as e: + if e.status == 400: + await ctx.message.add_reaction("❌") + log.debug( + f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " + "probably because the request violated the UniqueConstraint." + ) + raise BadArgument( + f"Unable to add the item to the {allow_type}. " + "The item probably already exists. Keep in mind that a " + "blacklist and a whitelist for the same item cannot co-exist, " + "and we do not permit any duplicates." + ) + raise + + # Insert the item into the cache + self.bot.insert_item_into_filter_list_cache(item) + await ctx.message.add_reaction("✅") + + async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we need to convert it. + if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Find the content and delete it. + log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") + item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) + + if item is not None: + try: + await self.bot.api_client.delete( + f"bot/filter-lists/{item['id']}" + ) + del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to delete an item with the id {item['id']}, but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("❌") + + async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: + """Paginate and display all items in a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] + + # Build a list of lines we want to show in the paginator + lines = [] + for content, metadata in result.items(): + line = f"• `{content}`" + + if comment := metadata.get("comment"): + line += f" - {comment}" + + lines.append(line) + lines = sorted(lines) + + # Build the embed + list_type_plural = list_type.lower().replace("_", " ").title() + "s" + embed = Embed( + title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", + colour=Colour.blue() + ) + log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") + + if result: + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + await ctx.message.add_reaction("❌") + + async def _sync_data(self, ctx: Context) -> None: + """Syncs the filterlists with the API.""" + try: + log.trace("Attempting to sync FilterList cache with data from the API.") + await self.bot.cache_filter_list_data() + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to sync FilterList cache data but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + + @staticmethod + async def _validate_guild_invite(ctx: Context, invite: str) -> dict: + """ + Validates a guild invite, and returns the guild info as a dict. + + Will raise a BadArgument if the guild invite is invalid. + """ + log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, invite) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's return a dict of guild information. + log.trace(f"{invite} validated as server invite. Converting to ID.") + return guild_data + + @group(aliases=("allowlist", "allow", "al", "wl")) + async def whitelist(self, ctx: Context) -> None: + """Group for whitelisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @group(aliases=("denylist", "deny", "bl", "dl")) + async def blacklist(self, ctx: Context) -> None: + """Group for blacklisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @whitelist.command(name="add", aliases=("a", "set")) + async def allow_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified allowlist.""" + await self._add_data(ctx, True, list_type, content, comment) + + @blacklist.command(name="add", aliases=("a", "set")) + async def deny_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified denylist.""" + await self._add_data(ctx, False, list_type, content, comment) + + @whitelist.command(name="remove", aliases=("delete", "rm",)) + async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified allowlist.""" + await self._delete_data(ctx, True, list_type, content) + + @blacklist.command(name="remove", aliases=("delete", "rm",)) + async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified denylist.""" + await self._delete_data(ctx, False, list_type, content) + + @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified allowlist.""" + await self._list_all_data(ctx, True, list_type) + + @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified denylist.""" + await self._list_all_data(ctx, False, list_type) + + @whitelist.command(name="sync", aliases=("s",)) + async def allow_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + @blacklist.command(name="sync", aliases=("s",)) + async def deny_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the FilterLists cog.""" + bot.add_cog(FilterLists(bot)) diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py new file mode 100644 index 000000000..2ae476d8a --- /dev/null +++ b/bot/exts/filters/filtering.py @@ -0,0 +1,575 @@ +import asyncio +import logging +import re +from datetime import datetime, timedelta +from typing import List, Mapping, Optional, Tuple, Union + +import dateutil +import discord.errors +from dateutil.relativedelta import relativedelta +from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel +from discord.ext.commands import Cog +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.constants import ( + Channels, Colours, + Filter, Icons, URLs +) +from bot.exts.moderation.modlog import ModLog +from bot.utils.redis_cache import RedisCache +from bot.utils.regex import INVITE_RE +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + +# Regular expressions +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) +URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) +ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") + +# Other constants. +DAYS_BETWEEN_ALERTS = 3 +OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) + + +class Filtering(Cog): + """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" + + # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent + name_alerts = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.name_lock = asyncio.Lock() + + staff_mistake_str = "If you believe this was a mistake, please let staff know!" + self.filters = { + "filter_zalgo": { + "enabled": Filter.filter_zalgo, + "function": self._has_zalgo, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_zalgo, + "notification_msg": ( + "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " + f"{staff_mistake_str}" + ), + "schedule_deletion": False + }, + "filter_invites": { + "enabled": Filter.filter_invites, + "function": self._has_invites, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_invites, + "notification_msg": ( + f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" + r"Our server rules can be found here: " + ), + "schedule_deletion": False + }, + "filter_domains": { + "enabled": Filter.filter_domains, + "function": self._has_urls, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_domains, + "notification_msg": ( + f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" + ), + "schedule_deletion": False + }, + "watch_regex": { + "enabled": Filter.watch_regex, + "function": self._has_watch_regex_match, + "type": "watchlist", + "content_only": True, + "schedule_deletion": True + }, + "watch_rich_embeds": { + "enabled": Filter.watch_rich_embeds, + "function": self._has_rich_embed, + "type": "watchlist", + "content_only": False, + "schedule_deletion": False + } + } + + self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: + """Fetch items from the filter_list_cache.""" + return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() + + @staticmethod + def _expand_spoilers(text: str) -> str: + """Return a string containing all interpretations of a spoilered message.""" + split_text = SPOILER_RE.split(text) + return ''.join( + split_text[0::2] + split_text[1::2] + split_text + ) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Invoke message filter for new messages.""" + await self._filter_message(msg) + + # Ignore webhook messages. + if msg.webhook_id is None: + await self.check_bad_words_in_name(msg.author) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Invoke message filter for message edits. + + If there have been multiple edits, calculate the time delta from the previous edit. + """ + if not before.edited_at: + delta = relativedelta(after.edited_at, before.created_at).microseconds + else: + delta = relativedelta(after.edited_at, before.edited_at).microseconds + await self._filter_message(after, delta) + + def get_name_matches(self, name: str) -> List[re.Match]: + """Check bad words from passed string (name). Return list of matches.""" + matches = [] + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + if match := re.search(pattern, name, flags=re.IGNORECASE): + matches.append(match) + return matches + + async def check_send_alert(self, member: Member) -> bool: + """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" + if last_alert := await self.name_alerts.get(member.id): + last_alert = datetime.utcfromtimestamp(last_alert) + if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: + log.trace(f"Last alert was too recent for {member}'s nickname.") + return False + + return True + + async def check_bad_words_in_name(self, member: Member) -> None: + """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" + # Use lock to avoid race conditions + async with self.name_lock: + # Check whether the users display name contains any words in our blacklist + matches = self.get_name_matches(member.display_name) + + if not matches or not await self.check_send_alert(member): + return + + log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") + + log_string = ( + f"**User:** {member.mention} (`{member.id}`)\n" + f"**Display Name:** {member.display_name}\n" + f"**Bad Matches:** {', '.join(match.group() for match in matches)}" + ) + + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colours.soft_red, + title="Username filtering alert", + text=log_string, + channel_id=Channels.mod_alerts, + thumbnail=member.avatar_url + ) + + # Update time when alert sent + await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) + + async def filter_eval(self, result: str, msg: Message) -> bool: + """ + Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + + Also requires the original message, to check whether to filter and for mod logs. + Returns whether a filter was triggered or not. + """ + filter_triggered = False + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + # We also do not need to worry about filters that take the full message, + # since all we have is an arbitrary string. + if _filter["enabled"] and _filter["content_only"]: + match = await _filter["function"](result) + + if match: + # If this is a filter (not a watchlist), we set the variable so we know + # that it has been triggered + if _filter["type"] == "filter": + filter_triggered = True + + # We do not have to check against DM channels since !eval cannot be used there. + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, result + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} using !eval with " + f"[the following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + return filter_triggered + + async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: + """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + if _filter["enabled"]: + # Double trigger check for the embeds filter + if filter_name == "watch_rich_embeds": + # If the edit delta is less than 0.001 seconds, then we're probably dealing + # with a double filter trigger. + if delta is not None and delta < 100: + continue + + # Does the filter only need the message content or the full message? + if _filter["content_only"]: + match = await _filter["function"](msg.content) + else: + match = await _filter["function"](msg) + + if match: + is_private = msg.channel.type is discord.ChannelType.private + + # If this is a filter (not a watchlist) and not in a DM, delete the message. + if _filter["type"] == "filter" and not is_private: + try: + # Embeds (can?) trigger both the `on_message` and `on_message_edit` + # event handlers, triggering filtering twice for the same message. + # + # If `on_message`-triggered filtering already deleted the message + # then `on_message_edit`-triggered filtering will raise exception + # since the message no longer exists. + # + # In addition, to avoid sending two notifications to the user, the + # logs, and mod_alert, we return if the message no longer exists. + await msg.delete() + except discord.errors.NotFound: + return + + # Notify the user if the filter specifies + if _filter["user_notification"]: + await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) + + # If the message is classed as offensive, we store it in the site db and + # it will be deleted it after one week. + if _filter["schedule_deletion"] and not is_private: + delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() + data = { + 'id': msg.id, + 'channel_id': msg.channel.id, + 'delete_date': delete_date + } + + await self.bot.api_client.post('bot/offensive-messages', json=data) + self.schedule_msg_delete(data) + log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") + + if is_private: + channel_str = "via DM" + else: + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, msg.content + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} with [the " + f"following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone if not is_private else False, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ + str, Optional[List[discord.Embed]], Optional[str] + ]: + """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" + # Word and match stats for watch_regex + if name == "watch_regex": + surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] + message_content = ( + f"**Match:** '{match[0]}'\n" + f"**Location:** '...{escape_markdown(surroundings)}...'\n" + f"\n**Original Message:**\n{escape_markdown(content)}" + ) + else: # Use original content + message_content = content + + additional_embeds = None + additional_embeds_msg = None + + self.bot.stats.incr(f"filters.{name}") + + # The function returns True for invalid invites. + # They have no data so additional embeds can't be created for them. + if name == "filter_invites" and match is not True: + additional_embeds = [] + for _, data in match.items(): + embed = discord.Embed(description=( + f"**Members:**\n{data['members']}\n" + f"**Active:**\n{data['active']}" + )) + embed.set_author(name=data["name"]) + embed.set_thumbnail(url=data["icon"]) + embed.set_footer(text=f"Guild ID: {data['id']}") + additional_embeds.append(embed) + additional_embeds_msg = "For the following guild(s):" + + elif name == "watch_rich_embeds": + additional_embeds = match + additional_embeds_msg = "With the following embed(s):" + + return message_content, additional_embeds, additional_embeds_msg + + @staticmethod + def _check_filter(msg: Message) -> bool: + """Check whitelists to see if we should filter this message.""" + role_whitelisted = False + + if type(msg.author) is Member: # Only Member has roles, not User. + for role in msg.author.roles: + if role.id in Filter.role_whitelist: + role_whitelisted = True + + return ( + msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist + and not role_whitelisted # Role not in whitelist + and not msg.author.bot # Author not a bot + ) + + async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: + """ + Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. + + `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is + matched as-is. Spoilers are expanded, if any, and URLs are ignored. + """ + if SPOILER_RE.search(text): + text = self._expand_spoilers(text) + + # Make sure it's not a URL + if URL_RE.search(text): + return False + + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + match = re.search(pattern, text, flags=re.IGNORECASE) + if match: + return match + + async def _has_urls(self, text: str) -> bool: + """Returns True if the text contains one of the blacklisted URLs from the config file.""" + if not URL_RE.search(text): + return False + + text = text.lower() + domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) + + for url in domain_blacklist: + if url.lower() in text: + return True + + return False + + @staticmethod + async def _has_zalgo(text: str) -> bool: + """ + Returns True if the text contains zalgo characters. + + Zalgo range is \u0300 – \u036F and \u0489. + """ + return bool(ZALGO_RE.search(text)) + + async def _has_invites(self, text: str) -> Union[dict, bool]: + """ + Checks if there's any invites in the text content that aren't in the guild whitelist. + + If any are detected, a dictionary of invite data is returned, with a key per invite. + If none are detected, False is returned. + + Attempts to catch some of common ways to try to cheat the system. + """ + # Remove backslashes to prevent escape character aroundfuckery like + # discord\.gg/gdudes-pony-farm + text = text.replace("\\", "") + + invites = INVITE_RE.findall(text) + invite_data = dict() + for invite in invites: + if invite in invite_data: + continue + + response = await self.bot.http_session.get( + f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} + ) + response = await response.json() + guild = response.get("guild") + if guild is None: + # Lack of a "guild" key in the JSON response indicates either an group DM invite, an + # expired invite, or an invalid invite. The API does not currently differentiate + # between invalid and expired invites + return True + + guild_id = guild.get("id") + guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) + guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) + + # Is this invite allowed? + guild_partnered_or_verified = ( + 'PARTNERED' in guild.get("features", []) + or 'VERIFIED' in guild.get("features", []) + ) + invite_not_allowed = ( + guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. + or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. + and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. + ) + + if invite_not_allowed: + guild_icon_hash = guild["icon"] + guild_icon = ( + "https://cdn.discordapp.com/icons/" + f"{guild_id}/{guild_icon_hash}.png?size=512" + ) + + invite_data[invite] = { + "name": guild["name"], + "id": guild['id'], + "icon": guild_icon, + "members": response["approximate_member_count"], + "active": response["approximate_presence_count"] + } + + return invite_data if invite_data else False + + @staticmethod + async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: + """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" + if msg.embeds: + for embed in msg.embeds: + if embed.type == "rich": + urls = URL_RE.findall(msg.content) + if not embed.url or embed.url not in urls: + # If `embed.url` does not exist or if `embed.url` is not part of the content + # of the message, it's unlikely to be an auto-generated embed by Discord. + return msg.embeds + else: + log.trace( + "Found a rich embed sent by a regular user account, " + "but it was likely just an automatic URL embed." + ) + return False + return False + + async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: + """ + Notify filtered_member about a moderation action with the reason str. + + First attempts to DM the user, fall back to in-channel notification if user has DMs disabled + """ + try: + await filtered_member.send(reason) + except discord.errors.Forbidden: + await channel.send(f"{filtered_member.mention} {reason}") + + def schedule_msg_delete(self, msg: dict) -> None: + """Delete an offensive message once its deletion date is reached.""" + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) + + async def reschedule_offensive_msg_deletion(self) -> None: + """Get all the pending message deletion from the API and reschedule them.""" + await self.bot.wait_until_ready() + response = await self.bot.api_client.get('bot/offensive-messages',) + + now = datetime.utcnow() + + for msg in response: + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + + if delete_at < now: + await self.delete_offensive_msg(msg) + else: + self.schedule_msg_delete(msg) + + async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: + """Delete an offensive message, and then delete it from the db.""" + try: + channel = self.bot.get_channel(msg['channel_id']) + if channel: + msg_obj = await channel.fetch_message(msg['id']) + await msg_obj.delete() + except NotFound: + log.info( + f"Tried to delete message {msg['id']}, but the message can't be found " + f"(it has been probably already deleted)." + ) + except HTTPException as e: + log.warning(f"Failed to delete message {msg['id']}: status {e.status}") + + await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') + log.info(f"Deleted the offensive message with id {msg['id']}.") + + +def setup(bot: Bot) -> None: + """Load the Filtering cog.""" + bot.add_cog(Filtering(bot)) diff --git a/bot/exts/filters/security.py b/bot/exts/filters/security.py new file mode 100644 index 000000000..c680c5e27 --- /dev/null +++ b/bot/exts/filters/security.py @@ -0,0 +1,31 @@ +import logging + +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot + +log = logging.getLogger(__name__) + + +class Security(Cog): + """Security-related helpers.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all + self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM + + def check_not_bot(self, ctx: Context) -> bool: + """Check if the context is a bot user.""" + return not ctx.author.bot + + def check_on_guild(self, ctx: Context) -> bool: + """Check if the context is in a guild.""" + if ctx.guild is None: + raise NoPrivateMessage("This command cannot be used in private messages.") + return True + + +def setup(bot: Bot) -> None: + """Load the Security cog.""" + bot.add_cog(Security(bot)) diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py new file mode 100644 index 000000000..0eda3dc6a --- /dev/null +++ b/bot/exts/filters/token_remover.py @@ -0,0 +1,182 @@ +import base64 +import binascii +import logging +import re +import typing as t + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot import utils +from bot.bot import Bot +from bot.constants import Channels, Colours, Event, Icons +from bot.exts.moderation.modlog import ModLog + +log = logging.getLogger(__name__) + +LOG_MESSAGE = ( + "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " + "token was `{user_id}.{timestamp}.{hmac}`" +) +DELETION_MESSAGE_TEMPLATE = ( + "Hey {mention}! I noticed you posted a seemingly valid Discord API " + "token in your message and have removed your message. " + "This means that your token has been **compromised**. " + "Please change your token **immediately** at: " + "\n\n" + "Feel free to re-post it with the token removed. " + "If you believe this was a mistake, please let us know!" +) +DISCORD_EPOCH = 1_420_070_400 +TOKEN_EPOCH = 1_293_840_000 + +# Three parts delimited by dots: user ID, creation timestamp, HMAC. +# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. +# Each part only matches base64 URL-safe characters. +# Padding has never been observed, but the padding character '=' is matched just in case. +TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) + + +class Token(t.NamedTuple): + """A Discord Bot token.""" + + user_id: str + timestamp: str + hmac: str + + +class TokenRemover(Cog): + """Scans messages for potential discord.py bot tokens and removes them.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Check each message for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + found_token = self.find_token_in_message(msg) + if found_token: + await self.take_action(msg, found_token) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + await self.on_message(after) + + async def take_action(self, msg: Message, found_token: Token) -> None: + """Remove the `msg` containing the `found_token` and send a mod log message.""" + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") + return + + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + log_message = self.format_log_message(msg, found_token) + log.debug(log_message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=log_message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + self.bot.stats.incr("tokens.removed_tokens") + + @staticmethod + def format_log_message(msg: Message, token: Token) -> str: + """Return the log message to send for `token` being censored in `msg`.""" + return LOG_MESSAGE.format( + author=msg.author, + author_id=msg.author.id, + channel=msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac='x' * len(token.hmac), + ) + + @classmethod + def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: + """Return a seemingly valid token found in `msg` or `None` if no token is found.""" + # Use finditer rather than search to guard against method calls prematurely returning the + # token check (e.g. `message.channel.send` also matches our token pattern) + for match in TOKEN_RE.finditer(msg.content): + token = Token(*match.groups()) + if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): + # Short-circuit on first match + return token + + # No matching substring + return + + @staticmethod + def is_valid_user_id(b64_content: str) -> bool: + """ + Check potential token to see if it contains a valid Discord user ID. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + string = decoded_bytes.decode('utf-8') + + # isdigit on its own would match a lot of other Unicode characters, hence the isascii. + return string.isascii() and string.isdigit() + except (binascii.Error, ValueError): + return False + + @staticmethod + def is_valid_timestamp(b64_content: str) -> bool: + """ + Return True if `b64_content` decodes to a valid timestamp. + + If the timestamp is greater than the Discord epoch, it's probably valid. + See: https://i.imgur.com/7WdehGn.png + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + timestamp = int.from_bytes(decoded_bytes, byteorder="big") + except (binascii.Error, ValueError) as e: + log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") + return False + + # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound + # is not checked. + if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: + return True + else: + log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") + return False + + +def setup(bot: Bot) -> None: + """Load the TokenRemover cog.""" + bot.add_cog(TokenRemover(bot)) diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py new file mode 100644 index 000000000..ca126ebf5 --- /dev/null +++ b/bot/exts/filters/webhook_remover.py @@ -0,0 +1,84 @@ +import logging +import re + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, Colours, Event, Icons +from bot.exts.moderation.modlog import ModLog + +WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) + +ALERT_MESSAGE_TEMPLATE = ( + "{user}, looks like you posted a Discord webhook URL. Therefore, your " + "message has been removed. Your webhook may have been **compromised** so " + "please re-create the webhook **immediately**. If you believe this was " + "mistake, please let us know." +) + +log = logging.getLogger(__name__) + + +class WebhookRemover(Cog): + """Scan messages to detect Discord webhooks links.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get current instance of `ModLog`.""" + return self.bot.get_cog("ModLog") + + async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: + """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" + # Don't log this, due internal delete, not by user. Will make different entry. + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") + return + + await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) + + message = ( + f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " + f"to #{msg.channel}. Webhook URL was `{redacted_url}`" + ) + log.debug(message) + + # Send entry to moderation alerts. + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Discord webhook URL removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts + ) + + self.bot.stats.incr("tokens.removed_webhooks") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Check if a Discord webhook URL is in `message`.""" + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + matches = WEBHOOK_URL_RE.search(msg.content) + if matches: + await self.delete_and_respond(msg, matches[1] + "xxx") + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Check if a Discord webhook URL is in the edited message `after`.""" + await self.on_message(after) + + +def setup(bot: Bot) -> None: + """Load `WebhookRemover` cog.""" + bot.add_cog(WebhookRemover(bot)) diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py new file mode 100644 index 000000000..57094751e --- /dev/null +++ b/bot/exts/help_channels.py @@ -0,0 +1,944 @@ +import asyncio +import json +import logging +import random +import typing as t +from collections import deque +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import discord +import discord.abc +from discord.ext import commands + +from bot import constants +from bot.bot import Bot +from bot.utils import RedisCache +from bot.utils.checks import with_role_check +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + +ASKING_GUIDE_URL = "https://pythondiscord.com/pages/asking-good-questions/" +MAX_CHANNELS_PER_CATEGORY = 50 +EXCLUDED_CHANNELS = (constants.Channels.how_to_get_help, constants.Channels.cooldown) + +HELP_CHANNEL_TOPIC = """ +This is a Python help channel. You can claim your own help channel in the Python Help: Available category. +""" + +AVAILABLE_MSG = f""" +This help channel is now **available**, which means that you can claim it by simply typing your \ +question into it. Once claimed, the channel will move into the **Python Help: Occupied** category, \ +and will be yours until it has been inactive for {constants.HelpChannels.idle_minutes} minutes or \ +is closed manually with `!close`. When that happens, it will be set to **dormant** and moved into \ +the **Help: Dormant** category. + +Try to write the best question you can by providing a detailed description and telling us what \ +you've tried already. For more information on asking a good question, \ +check out our guide on [asking good questions]({ASKING_GUIDE_URL}). +""" + +DORMANT_MSG = f""" +This help channel has been marked as **dormant**, and has been moved into the **Help: Dormant** \ +category at the bottom of the channel list. It is no longer possible to send messages in this \ +channel until it becomes available again. + +If your question wasn't answered yet, you can claim a new help channel from the \ +**Help: Available** category by simply asking your question again. Consider rephrasing the \ +question to maximize your chance of getting a good answer. If you're not sure how, have a look \ +through our guide for [asking a good question]({ASKING_GUIDE_URL}). +""" + +CoroutineFunc = t.Callable[..., t.Coroutine] + + +class HelpChannels(commands.Cog): + """ + Manage the help channel system of the guild. + + The system is based on a 3-category system: + + Available Category + + * Contains channels which are ready to be occupied by someone who needs help + * Will always contain `constants.HelpChannels.max_available` channels; refilled automatically + from the pool of dormant channels + * Prioritise using the channels which have been dormant for the longest amount of time + * If there are no more dormant channels, the bot will automatically create a new one + * If there are no dormant channels to move, helpers will be notified (see `notify()`) + * When a channel becomes available, the dormant embed will be edited to show `AVAILABLE_MSG` + * User can only claim a channel at an interval `constants.HelpChannels.claim_minutes` + * To keep track of cooldowns, user which claimed a channel will have a temporary role + + In Use Category + + * Contains all channels which are occupied by someone needing help + * Channel moves to dormant category after `constants.HelpChannels.idle_minutes` of being idle + * Command can prematurely mark a channel as dormant + * Channel claimant is allowed to use the command + * Allowed roles for the command are configurable with `constants.HelpChannels.cmd_whitelist` + * When a channel becomes dormant, an embed with `DORMANT_MSG` will be sent + + Dormant Category + + * Contains channels which aren't in use + * Channels are used to refill the Available category + + Help channels are named after the chemical elements in `bot/resources/elements.json`. + """ + + # This cache tracks which channels are claimed by which members. + # RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] + help_channel_claimants = RedisCache() + + # This cache maps a help channel to whether it has had any + # activity other than the original claimant. True being no other + # activity and False being other activity. + # RedisCache[discord.TextChannel.id, bool] + unanswered = RedisCache() + + # This dictionary maps a help channel to the time it was claimed + # RedisCache[discord.TextChannel.id, UtcPosixTimestamp] + claim_times = RedisCache() + + # This cache maps a help channel to original question message in same channel. + # RedisCache[discord.TextChannel.id, discord.Message.id] + question_messages = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + # Categories + self.available_category: discord.CategoryChannel = None + self.in_use_category: discord.CategoryChannel = None + self.dormant_category: discord.CategoryChannel = None + + # Queues + self.channel_queue: asyncio.Queue[discord.TextChannel] = None + self.name_queue: t.Deque[str] = None + + self.name_positions = self.get_names() + self.last_notification: t.Optional[datetime] = None + + # Asyncio stuff + self.queue_tasks: t.List[asyncio.Task] = [] + self.ready = asyncio.Event() + self.on_message_lock = asyncio.Lock() + self.init_task = self.bot.loop.create_task(self.init_cog()) + + def cog_unload(self) -> None: + """Cancel the init task and scheduled tasks when the cog unloads.""" + log.trace("Cog unload: cancelling the init_cog task") + self.init_task.cancel() + + log.trace("Cog unload: cancelling the channel queue tasks") + for task in self.queue_tasks: + task.cancel() + + self.scheduler.cancel_all() + + def create_channel_queue(self) -> asyncio.Queue: + """ + Return a queue of dormant channels to use for getting the next available channel. + + The channels are added to the queue in a random order. + """ + log.trace("Creating the channel queue.") + + channels = list(self.get_category_channels(self.dormant_category)) + random.shuffle(channels) + + log.trace("Populating the channel queue with channels.") + queue = asyncio.Queue() + for channel in channels: + queue.put_nowait(channel) + + return queue + + async def create_dormant(self) -> t.Optional[discord.TextChannel]: + """ + Create and return a new channel in the Dormant category. + + The new channel will sync its permission overwrites with the category. + + Return None if no more channel names are available. + """ + log.trace("Getting a name for a new dormant channel.") + + try: + name = self.name_queue.popleft() + except IndexError: + log.debug("No more names available for new dormant channels.") + return None + + log.debug(f"Creating a new dormant channel named {name}.") + return await self.dormant_category.create_text_channel(name, topic=HELP_CHANNEL_TOPIC) + + def create_name_queue(self) -> deque: + """Return a queue of element names to use for creating new channels.""" + log.trace("Creating the chemical element name queue.") + + used_names = self.get_used_names() + + log.trace("Determining the available names.") + available_names = (name for name in self.name_positions if name not in used_names) + + log.trace("Populating the name queue with names.") + return deque(available_names) + + async def dormant_check(self, ctx: commands.Context) -> bool: + """Return True if the user is the help channel claimant or passes the role check.""" + if await self.help_channel_claimants.get(ctx.channel.id) == ctx.author.id: + log.trace(f"{ctx.author} is the help channel claimant, passing the check for dormant.") + self.bot.stats.incr("help.dormant_invoke.claimant") + return True + + log.trace(f"{ctx.author} is not the help channel claimant, checking roles.") + role_check = with_role_check(ctx, *constants.HelpChannels.cmd_whitelist) + + if role_check: + self.bot.stats.incr("help.dormant_invoke.staff") + + return role_check + + @commands.command(name="close", aliases=["dormant", "solved"], enabled=False) + async def close_command(self, ctx: commands.Context) -> None: + """ + Make the current in-use help channel dormant. + + Make the channel dormant if the user passes the `dormant_check`, + delete the message that invoked this, + and reset the send permissions cooldown for the user who started the session. + """ + log.trace("close command invoked; checking if the channel is in-use.") + if ctx.channel.category == self.in_use_category: + if await self.dormant_check(ctx): + await self.remove_cooldown_role(ctx.author) + + # Ignore missing task when cooldown has passed but the channel still isn't dormant. + if ctx.author.id in self.scheduler: + self.scheduler.cancel(ctx.author.id) + + await self.move_to_dormant(ctx.channel, "command") + self.scheduler.cancel(ctx.channel.id) + else: + log.debug(f"{ctx.author} invoked command 'dormant' outside an in-use help channel") + + async def get_available_candidate(self) -> discord.TextChannel: + """ + Return a dormant channel to turn into an available channel. + + If no channel is available, wait indefinitely until one becomes available. + """ + log.trace("Getting an available channel candidate.") + + try: + channel = self.channel_queue.get_nowait() + except asyncio.QueueEmpty: + log.info("No candidate channels in the queue; creating a new channel.") + channel = await self.create_dormant() + + if not channel: + log.info("Couldn't create a candidate channel; waiting to get one from the queue.") + await self.notify() + channel = await self.wait_for_dormant_channel() + + return channel + + @staticmethod + def get_clean_channel_name(channel: discord.TextChannel) -> str: + """Return a clean channel name without status emojis prefix.""" + prefix = constants.HelpChannels.name_prefix + try: + # Try to remove the status prefix using the index of the channel prefix + name = channel.name[channel.name.index(prefix):] + log.trace(f"The clean name for `{channel}` is `{name}`") + except ValueError: + # If, for some reason, the channel name does not contain "help-" fall back gracefully + log.info(f"Can't get clean name because `{channel}` isn't prefixed by `{prefix}`.") + name = channel.name + + return name + + @staticmethod + def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: + """Check if a channel should be excluded from the help channel system.""" + return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS + + def get_category_channels(self, category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: + """Yield the text channels of the `category` in an unsorted manner.""" + log.trace(f"Getting text channels in the category '{category}' ({category.id}).") + + # This is faster than using category.channels because the latter sorts them. + for channel in self.bot.get_guild(constants.Guild.id).channels: + if channel.category_id == category.id and not self.is_excluded_channel(channel): + yield channel + + async def get_in_use_time(self, channel_id: int) -> t.Optional[timedelta]: + """Return the duration `channel_id` has been in use. Return None if it's not in use.""" + log.trace(f"Calculating in use time for channel {channel_id}.") + + claimed_timestamp = await self.claim_times.get(channel_id) + if claimed_timestamp: + claimed = datetime.utcfromtimestamp(claimed_timestamp) + return datetime.utcnow() - claimed + + @staticmethod + def get_names() -> t.List[str]: + """ + Return a truncated list of prefixed element names. + + The amount of names is configured with `HelpChannels.max_total_channels`. + The prefix is configured with `HelpChannels.name_prefix`. + """ + count = constants.HelpChannels.max_total_channels + prefix = constants.HelpChannels.name_prefix + + log.trace(f"Getting the first {count} element names from JSON.") + + with Path("bot/resources/elements.json").open(encoding="utf-8") as elements_file: + all_names = json.load(elements_file) + + if prefix: + return [prefix + name for name in all_names[:count]] + else: + return all_names[:count] + + def get_used_names(self) -> t.Set[str]: + """Return channel names which are already being used.""" + log.trace("Getting channel names which are already being used.") + + names = set() + for cat in (self.available_category, self.in_use_category, self.dormant_category): + for channel in self.get_category_channels(cat): + names.add(self.get_clean_channel_name(channel)) + + if len(names) > MAX_CHANNELS_PER_CATEGORY: + log.warning( + f"Too many help channels ({len(names)}) already exist! " + f"Discord only supports {MAX_CHANNELS_PER_CATEGORY} in a category." + ) + + log.trace(f"Got {len(names)} used names: {names}") + return names + + @classmethod + async def get_idle_time(cls, channel: discord.TextChannel) -> t.Optional[int]: + """ + Return the time elapsed, in seconds, since the last message sent in the `channel`. + + Return None if the channel has no messages. + """ + log.trace(f"Getting the idle time for #{channel} ({channel.id}).") + + msg = await cls.get_last_message(channel) + if not msg: + log.debug(f"No idle time available; #{channel} ({channel.id}) has no messages.") + return None + + idle_time = (datetime.utcnow() - msg.created_at).seconds + + log.trace(f"#{channel} ({channel.id}) has been idle for {idle_time} seconds.") + return idle_time + + @staticmethod + async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: + """Return the last message sent in the channel or None if no messages exist.""" + log.trace(f"Getting the last message in #{channel} ({channel.id}).") + + try: + return await channel.history(limit=1).next() # noqa: B305 + except discord.NoMoreItems: + log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") + return None + + async def init_available(self) -> None: + """Initialise the Available category with channels.""" + log.trace("Initialising the Available category with channels.") + + channels = list(self.get_category_channels(self.available_category)) + missing = constants.HelpChannels.max_available - len(channels) + + # If we've got less than `max_available` channel available, we should add some. + if missing > 0: + log.trace(f"Moving {missing} missing channels to the Available category.") + for _ in range(missing): + await self.move_to_available() + + # If for some reason we have more than `max_available` channels available, + # we should move the superfluous ones over to dormant. + elif missing < 0: + log.trace(f"Moving {abs(missing)} superfluous available channels over to the Dormant category.") + for channel in channels[:abs(missing)]: + await self.move_to_dormant(channel, "auto") + + async def init_categories(self) -> None: + """Get the help category objects. Remove the cog if retrieval fails.""" + log.trace("Getting the CategoryChannel objects for the help categories.") + + try: + self.available_category = await self.try_get_channel( + constants.Categories.help_available + ) + self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use) + self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant) + except discord.HTTPException: + log.exception("Failed to get a category; cog will be removed") + self.bot.remove_cog(self.qualified_name) + + async def init_cog(self) -> None: + """Initialise the help channel system.""" + log.trace("Waiting for the guild to be available before initialisation.") + await self.bot.wait_until_guild_available() + + log.trace("Initialising the cog.") + await self.init_categories() + await self.check_cooldowns() + + self.channel_queue = self.create_channel_queue() + self.name_queue = self.create_name_queue() + + log.trace("Moving or rescheduling in-use channels.") + for channel in self.get_category_channels(self.in_use_category): + await self.move_idle_channel(channel, has_task=False) + + # Prevent the command from being used until ready. + # The ready event wasn't used because channels could change categories between the time + # the command is invoked and the cog is ready (e.g. if move_idle_channel wasn't called yet). + # This may confuse users. So would potentially long delays for the cog to become ready. + self.close_command.enabled = True + + await self.init_available() + + log.info("Cog is ready!") + self.ready.set() + + self.report_stats() + + def report_stats(self) -> None: + """Report the channel count stats.""" + total_in_use = sum(1 for _ in self.get_category_channels(self.in_use_category)) + total_available = sum(1 for _ in self.get_category_channels(self.available_category)) + total_dormant = sum(1 for _ in self.get_category_channels(self.dormant_category)) + + self.bot.stats.gauge("help.total.in_use", total_in_use) + self.bot.stats.gauge("help.total.available", total_available) + self.bot.stats.gauge("help.total.dormant", total_dormant) + + @staticmethod + def is_claimant(member: discord.Member) -> bool: + """Return True if `member` has the 'Help Cooldown' role.""" + return any(constants.Roles.help_cooldown == role.id for role in member.roles) + + def match_bot_embed(self, message: t.Optional[discord.Message], description: str) -> bool: + """Return `True` if the bot's `message`'s embed description matches `description`.""" + if not message or not message.embeds: + return False + + bot_msg_desc = message.embeds[0].description + if bot_msg_desc is discord.Embed.Empty: + log.trace("Last message was a bot embed but it was empty.") + return False + return message.author == self.bot.user and bot_msg_desc.strip() == description.strip() + + @staticmethod + def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: + """Return True if `channel` is within a category with `category_id`.""" + actual_category = getattr(channel, "category", None) + return actual_category is not None and actual_category.id == category_id + + async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: + """ + Make the `channel` dormant if idle or schedule the move if still active. + + If `has_task` is True and rescheduling is required, the extant task to make the channel + dormant will first be cancelled. + """ + log.trace(f"Handling in-use channel #{channel} ({channel.id}).") + + if not await self.is_empty(channel): + idle_seconds = constants.HelpChannels.idle_minutes * 60 + else: + idle_seconds = constants.HelpChannels.deleted_idle_minutes * 60 + + time_elapsed = await self.get_idle_time(channel) + + if time_elapsed is None or time_elapsed >= idle_seconds: + log.info( + f"#{channel} ({channel.id}) is idle longer than {idle_seconds} seconds " + f"and will be made dormant." + ) + + await self.move_to_dormant(channel, "auto") + else: + # Cancel the existing task, if any. + if has_task: + self.scheduler.cancel(channel.id) + + delay = idle_seconds - time_elapsed + log.info( + f"#{channel} ({channel.id}) is still active; " + f"scheduling it to be moved after {delay} seconds." + ) + + self.scheduler.schedule_later(delay, channel.id, self.move_idle_channel(channel)) + + async def move_to_bottom_position(self, channel: discord.TextChannel, category_id: int, **options) -> None: + """ + Move the `channel` to the bottom position of `category` and edit channel attributes. + + To ensure "stable sorting", we use the `bulk_channel_update` endpoint and provide the current + positions of the other channels in the category as-is. This should make sure that the channel + really ends up at the bottom of the category. + + If `options` are provided, the channel will be edited after the move is completed. This is the + same order of operations that `discord.TextChannel.edit` uses. For information on available + options, see the documention on `discord.TextChannel.edit`. While possible, position-related + options should be avoided, as it may interfere with the category move we perform. + """ + # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. + category = await self.try_get_channel(category_id) + + payload = [{"id": c.id, "position": c.position} for c in category.channels] + + # Calculate the bottom position based on the current highest position in the category. If the + # category is currently empty, we simply use the current position of the channel to avoid making + # unnecessary changes to positions in the guild. + bottom_position = payload[-1]["position"] + 1 if payload else channel.position + + payload.append( + { + "id": channel.id, + "position": bottom_position, + "parent_id": category.id, + "lock_permissions": True, + } + ) + + # We use d.py's method to ensure our request is processed by d.py's rate limit manager + await self.bot.http.bulk_channel_update(category.guild.id, payload) + + # Now that the channel is moved, we can edit the other attributes + if options: + await channel.edit(**options) + + async def move_to_available(self) -> None: + """Make a channel available.""" + log.trace("Making a channel available.") + + channel = await self.get_available_candidate() + log.info(f"Making #{channel} ({channel.id}) available.") + + await self.send_available_message(channel) + + log.trace(f"Moving #{channel} ({channel.id}) to the Available category.") + + await self.move_to_bottom_position( + channel=channel, + category_id=constants.Categories.help_available, + ) + + self.report_stats() + + async def move_to_dormant(self, channel: discord.TextChannel, caller: str) -> None: + """ + Make the `channel` dormant. + + A caller argument is provided for metrics. + """ + log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") + + await self.help_channel_claimants.delete(channel.id) + await self.move_to_bottom_position( + channel=channel, + category_id=constants.Categories.help_dormant, + ) + + self.bot.stats.incr(f"help.dormant_calls.{caller}") + + in_use_time = await self.get_in_use_time(channel.id) + if in_use_time: + self.bot.stats.timing("help.in_use_time", in_use_time) + + unanswered = await self.unanswered.get(channel.id) + if unanswered: + self.bot.stats.incr("help.sessions.unanswered") + elif unanswered is not None: + self.bot.stats.incr("help.sessions.answered") + + log.trace(f"Position of #{channel} ({channel.id}) is actually {channel.position}.") + log.trace(f"Sending dormant message for #{channel} ({channel.id}).") + embed = discord.Embed(description=DORMANT_MSG) + await channel.send(embed=embed) + + await self.unpin(channel) + + log.trace(f"Pushing #{channel} ({channel.id}) into the channel queue.") + self.channel_queue.put_nowait(channel) + self.report_stats() + + async def move_to_in_use(self, channel: discord.TextChannel) -> None: + """Make a channel in-use and schedule it to be made dormant.""" + log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") + + await self.move_to_bottom_position( + channel=channel, + category_id=constants.Categories.help_in_use, + ) + + timeout = constants.HelpChannels.idle_minutes * 60 + + log.trace(f"Scheduling #{channel} ({channel.id}) to become dormant in {timeout} sec.") + self.scheduler.schedule_later(timeout, channel.id, self.move_idle_channel(channel)) + self.report_stats() + + async def notify(self) -> None: + """ + Send a message notifying about a lack of available help channels. + + Configuration: + + * `HelpChannels.notify` - toggle notifications + * `HelpChannels.notify_channel` - destination channel for notifications + * `HelpChannels.notify_minutes` - minimum interval between notifications + * `HelpChannels.notify_roles` - roles mentioned in notifications + """ + if not constants.HelpChannels.notify: + return + + log.trace("Notifying about lack of channels.") + + if self.last_notification: + elapsed = (datetime.utcnow() - self.last_notification).seconds + minimum_interval = constants.HelpChannels.notify_minutes * 60 + should_send = elapsed >= minimum_interval + else: + should_send = True + + if not should_send: + log.trace("Notification not sent because it's too recent since the previous one.") + return + + try: + log.trace("Sending notification message.") + + channel = self.bot.get_channel(constants.HelpChannels.notify_channel) + mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_roles) + allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_roles] + + message = await channel.send( + f"{mentions} A new available help channel is needed but there " + f"are no more dormant ones. Consider freeing up some in-use channels manually by " + f"using the `{constants.Bot.prefix}dormant` command within the channels.", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + + self.bot.stats.incr("help.out_of_channel_alerts") + + self.last_notification = message.created_at + except Exception: + # Handle it here cause this feature isn't critical for the functionality of the system. + log.exception("Failed to send notification about lack of dormant channels!") + + async def check_for_answer(self, message: discord.Message) -> None: + """Checks for whether new content in a help channel comes from non-claimants.""" + channel = message.channel + + # Confirm the channel is an in use help channel + if self.is_in_category(channel, constants.Categories.help_in_use): + log.trace(f"Checking if #{channel} ({channel.id}) has been answered.") + + # Check if there is an entry in unanswered + if await self.unanswered.contains(channel.id): + claimant_id = await self.help_channel_claimants.get(channel.id) + if not claimant_id: + # The mapping for this channel doesn't exist, we can't do anything. + return + + # Check the message did not come from the claimant + if claimant_id != message.author.id: + # Mark the channel as answered + await self.unanswered.set(channel.id, False) + + @commands.Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Move an available channel to the In Use category and replace it with a dormant one.""" + if message.author.bot: + return # Ignore messages sent by bots. + + channel = message.channel + + await self.check_for_answer(message) + + if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel): + return # Ignore messages outside the Available category or in excluded channels. + + log.trace("Waiting for the cog to be ready before processing messages.") + await self.ready.wait() + + log.trace("Acquiring lock to prevent a channel from being processed twice...") + async with self.on_message_lock: + log.trace(f"on_message lock acquired for {message.id}.") + + if not self.is_in_category(channel, constants.Categories.help_available): + log.debug( + f"Message {message.id} will not make #{channel} ({channel.id}) in-use " + f"because another message in the channel already triggered that." + ) + return + + log.info(f"Channel #{channel} was claimed by `{message.author.id}`.") + await self.move_to_in_use(channel) + await self.revoke_send_permissions(message.author) + + await self.pin(message) + + # Add user with channel for dormant check. + await self.help_channel_claimants.set(channel.id, message.author.id) + + self.bot.stats.incr("help.claimed") + + # Must use a timezone-aware datetime to ensure a correct POSIX timestamp. + timestamp = datetime.now(timezone.utc).timestamp() + await self.claim_times.set(channel.id, timestamp) + + await self.unanswered.set(channel.id, True) + + log.trace(f"Releasing on_message lock for {message.id}.") + + # Move a dormant channel to the Available category to fill in the gap. + # This is done last and outside the lock because it may wait indefinitely for a channel to + # be put in the queue. + await self.move_to_available() + + @commands.Cog.listener() + async def on_message_delete(self, msg: discord.Message) -> None: + """ + Reschedule an in-use channel to become dormant sooner if the channel is empty. + + The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`. + """ + if not self.is_in_category(msg.channel, constants.Categories.help_in_use): + return + + if not await self.is_empty(msg.channel): + return + + log.info(f"Claimant of #{msg.channel} ({msg.author}) deleted message, channel is empty now. Rescheduling task.") + + # Cancel existing dormant task before scheduling new. + self.scheduler.cancel(msg.channel.id) + + delay = constants.HelpChannels.deleted_idle_minutes * 60 + self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) + + async def is_empty(self, channel: discord.TextChannel) -> bool: + """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" + log.trace(f"Checking if #{channel} ({channel.id}) is empty.") + + # A limit of 100 results in a single API call. + # If AVAILABLE_MSG isn't found within 100 messages, then assume the channel is not empty. + # Not gonna do an extensive search for it cause it's too expensive. + async for msg in channel.history(limit=100): + if not msg.author.bot: + log.trace(f"#{channel} ({channel.id}) has a non-bot message.") + return False + + if self.match_bot_embed(msg, AVAILABLE_MSG): + log.trace(f"#{channel} ({channel.id}) has the available message embed.") + return True + + return False + + async def check_cooldowns(self) -> None: + """Remove expired cooldowns and re-schedule active ones.""" + log.trace("Checking all cooldowns to remove or re-schedule them.") + guild = self.bot.get_guild(constants.Guild.id) + cooldown = constants.HelpChannels.claim_minutes * 60 + + for channel_id, member_id in await self.help_channel_claimants.items(): + member = guild.get_member(member_id) + if not member: + continue # Member probably left the guild. + + in_use_time = await self.get_in_use_time(channel_id) + + if not in_use_time or in_use_time.seconds > cooldown: + # Remove the role if no claim time could be retrieved or if the cooldown expired. + # Since the channel is in the claimants cache, it is definitely strange for a time + # to not exist. However, it isn't a reason to keep the user stuck with a cooldown. + await self.remove_cooldown_role(member) + else: + # The member is still on a cooldown; re-schedule it for the remaining time. + delay = cooldown - in_use_time.seconds + self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) + + async def add_cooldown_role(self, member: discord.Member) -> None: + """Add the help cooldown role to `member`.""" + log.trace(f"Adding cooldown role for {member} ({member.id}).") + await self._change_cooldown_role(member, member.add_roles) + + async def remove_cooldown_role(self, member: discord.Member) -> None: + """Remove the help cooldown role from `member`.""" + log.trace(f"Removing cooldown role for {member} ({member.id}).") + await self._change_cooldown_role(member, member.remove_roles) + + async def _change_cooldown_role(self, member: discord.Member, coro_func: CoroutineFunc) -> None: + """ + Change `member`'s cooldown role via awaiting `coro_func` and handle errors. + + `coro_func` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. + """ + guild = self.bot.get_guild(constants.Guild.id) + role = guild.get_role(constants.Roles.help_cooldown) + if role is None: + log.warning(f"Help cooldown role ({constants.Roles.help_cooldown}) could not be found!") + return + + try: + await coro_func(role) + except discord.NotFound: + log.debug(f"Failed to change role for {member} ({member.id}): member not found") + except discord.Forbidden: + log.debug( + f"Forbidden to change role for {member} ({member.id}); " + f"possibly due to role hierarchy" + ) + except discord.HTTPException as e: + log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") + + async def revoke_send_permissions(self, member: discord.Member) -> None: + """ + Disallow `member` to send messages in the Available category for a certain time. + + The time until permissions are reinstated can be configured with + `HelpChannels.claim_minutes`. + """ + log.trace( + f"Revoking {member}'s ({member.id}) send message permissions in the Available category." + ) + + await self.add_cooldown_role(member) + + # Cancel the existing task, if any. + # Would mean the user somehow bypassed the lack of permissions (e.g. user is guild owner). + if member.id in self.scheduler: + self.scheduler.cancel(member.id) + + delay = constants.HelpChannels.claim_minutes * 60 + self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) + + async def send_available_message(self, channel: discord.TextChannel) -> None: + """Send the available message by editing a dormant message or sending a new message.""" + channel_info = f"#{channel} ({channel.id})" + log.trace(f"Sending available message in {channel_info}.") + + embed = discord.Embed(description=AVAILABLE_MSG) + + msg = await self.get_last_message(channel) + if self.match_bot_embed(msg, DORMANT_MSG): + log.trace(f"Found dormant message {msg.id} in {channel_info}; editing it.") + await msg.edit(embed=embed) + else: + log.trace(f"Dormant message not found in {channel_info}; sending a new message.") + await channel.send(embed=embed) + + async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel: + """Attempt to get or fetch a channel and return it.""" + log.trace(f"Getting the channel {channel_id}.") + + channel = self.bot.get_channel(channel_id) + if not channel: + log.debug(f"Channel {channel_id} is not in cache; fetching from API.") + channel = await self.bot.fetch_channel(channel_id) + + log.trace(f"Channel #{channel} ({channel_id}) retrieved.") + return channel + + async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: + """ + Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. + + Return True if successful and False otherwise. + """ + channel_str = f"#{channel} ({channel.id})" + if pin: + func = self.bot.http.pin_message + verb = "pin" + else: + func = self.bot.http.unpin_message + verb = "unpin" + + try: + await func(channel.id, msg_id) + except discord.HTTPException as e: + if e.code == 10008: + log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") + else: + log.exception( + f"Error {verb}ning message {msg_id} in {channel_str}: {e.status} ({e.code})" + ) + return False + else: + log.trace(f"{verb.capitalize()}ned message {msg_id} in {channel_str}.") + return True + + async def pin(self, message: discord.Message) -> None: + """Pin an initial question `message` and store it in a cache.""" + if await self.pin_wrapper(message.id, message.channel, pin=True): + await self.question_messages.set(message.channel.id, message.id) + + async def unpin(self, channel: discord.TextChannel) -> None: + """Unpin the initial question message sent in `channel`.""" + msg_id = await self.question_messages.pop(channel.id) + if msg_id is None: + log.debug(f"#{channel} ({channel.id}) doesn't have a message pinned.") + else: + await self.pin_wrapper(msg_id, channel, pin=False) + + async def wait_for_dormant_channel(self) -> discord.TextChannel: + """Wait for a dormant channel to become available in the queue and return it.""" + log.trace("Waiting for a dormant channel.") + + task = asyncio.create_task(self.channel_queue.get()) + self.queue_tasks.append(task) + channel = await task + + log.trace(f"Channel #{channel} ({channel.id}) finally retrieved from the queue.") + self.queue_tasks.remove(task) + + return channel + + +def validate_config() -> None: + """Raise a ValueError if the cog's config is invalid.""" + log.trace("Validating config.") + total = constants.HelpChannels.max_total_channels + available = constants.HelpChannels.max_available + + if total == 0 or available == 0: + raise ValueError("max_total_channels and max_available and must be greater than 0.") + + if total < available: + raise ValueError( + f"max_total_channels ({total}) must be greater than or equal to max_available " + f"({available})." + ) + + if total > MAX_CHANNELS_PER_CATEGORY: + raise ValueError( + f"max_total_channels ({total}) must be less than or equal to " + f"{MAX_CHANNELS_PER_CATEGORY} due to Discord's limit on channels per category." + ) + + +def setup(bot: Bot) -> None: + """Load the HelpChannels cog.""" + try: + validate_config() + except ValueError as e: + log.error(f"HelpChannels cog will not be loaded due to misconfiguration: {e}") + else: + bot.add_cog(HelpChannels(bot)) diff --git a/bot/exts/info/__init__.py b/bot/exts/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py new file mode 100644 index 000000000..204cffb37 --- /dev/null +++ b/bot/exts/info/doc.py @@ -0,0 +1,511 @@ +import asyncio +import functools +import logging +import re +import textwrap +from collections import OrderedDict +from contextlib import suppress +from types import SimpleNamespace +from typing import Any, Callable, Optional, Tuple + +import discord +from bs4 import BeautifulSoup +from bs4.element import PageElement, Tag +from discord.errors import NotFound +from discord.ext import commands +from markdownify import MarkdownConverter +from requests import ConnectTimeout, ConnectionError, HTTPError +from sphinx.ext import intersphinx +from urllib3.exceptions import ProtocolError + +from bot.bot import Bot +from bot.constants import MODERATION_ROLES, RedirectOutput +from bot.converters import ValidPythonIdentifier, ValidURL +from bot.decorators import with_role +from bot.pagination import LinePaginator + + +log = logging.getLogger(__name__) +logging.getLogger('urllib3').setLevel(logging.WARNING) + +# Since Intersphinx is intended to be used with Sphinx, +# we need to mock its configuration. +SPHINX_MOCK_APP = SimpleNamespace( + config=SimpleNamespace( + intersphinx_timeout=3, + tls_verify=True, + user_agent="python3:python-discord/bot:1.0.0" + ) +) + +NO_OVERRIDE_GROUPS = ( + "2to3fixer", + "token", + "label", + "pdbcommand", + "term", +) +NO_OVERRIDE_PACKAGES = ( + "python", +) + +SEARCH_END_TAG_ATTRS = ( + "data", + "function", + "class", + "exception", + "seealso", + "section", + "rubric", + "sphinxsidebar", +) +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") +WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") + +FAILED_REQUEST_RETRY_AMOUNT = 3 +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay + + +def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + # Assign the cache to the function itself so we can clear it from outside. + async_cache.cache = OrderedDict() + + def decorator(function: Callable) -> Callable: + """Define the async_cache decorator.""" + @functools.wraps(function) + async def wrapper(*args) -> Any: + """Decorator wrapper for the caching logic.""" + key = ':'.join(args[arg_offset:]) + + value = async_cache.cache.get(key) + if value is None: + if len(async_cache.cache) > max_size: + async_cache.cache.popitem(last=False) + + async_cache.cache[key] = await function(*args) + return async_cache.cache[key] + return wrapper + return decorator + + +class DocMarkdownConverter(MarkdownConverter): + """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" + + def convert_code(self, el: PageElement, text: str) -> str: + """Undo `markdownify`s underscore escaping.""" + return f"`{text}`".replace('\\', '') + + def convert_pre(self, el: PageElement, text: str) -> str: + """Wrap any codeblocks in `py` for syntax highlighting.""" + code = ''.join(el.strings) + return f"```py\n{code}```" + + +def markdownify(html: str) -> DocMarkdownConverter: + """Create a DocMarkdownConverter object from the input html.""" + return DocMarkdownConverter(bullets='•').convert(html) + + +class InventoryURL(commands.Converter): + """ + Represents an Intersphinx inventory URL. + + This converter checks whether intersphinx accepts the given inventory URL, and raises + `BadArgument` if that is not the case. + + Otherwise, it simply passes through the given URL. + """ + + @staticmethod + async def convert(ctx: commands.Context, url: str) -> str: + """Convert url to Intersphinx inventory URL.""" + try: + intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) + except AttributeError: + raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") + except ConnectionError: + if url.startswith('https'): + raise commands.BadArgument( + f"Cannot establish a connection to `{url}`. Does it support HTTPS?" + ) + raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") + except ValueError: + raise commands.BadArgument( + f"Failed to read Intersphinx inventory from URL `{url}`. " + "Are you sure that it's a valid inventory file?" + ) + return url + + +class Doc(commands.Cog): + """A set of commands for querying & displaying documentation.""" + + def __init__(self, bot: Bot): + self.base_urls = {} + self.bot = bot + self.inventories = {} + self.renamed_symbols = set() + + self.bot.loop.create_task(self.init_refresh_inventory()) + + async def init_refresh_inventory(self) -> None: + """Refresh documentation inventory on cog initialization.""" + await self.bot.wait_until_guild_available() + await self.refresh_inventory() + + async def update_single( + self, package_name: str, base_url: str, inventory_url: str + ) -> None: + """ + Rebuild the inventory for a single package. + + Where: + * `package_name` is the package name to use, appears in the log + * `base_url` is the root documentation URL for the specified package, used to build + absolute paths that link to specific symbols + * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running + `intersphinx.fetch_inventory` in an executor on the bot's event loop + """ + self.base_urls[package_name] = base_url + + package = await self._fetch_inventory(inventory_url) + if not package: + return None + + for group, value in package.items(): + for symbol, (package_name, _version, relative_doc_url, _) in value.items(): + absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + group_name = group.split(":")[1] + symbol_base_url = self.inventories[symbol].split("/", 3)[2] + if ( + group_name in NO_OVERRIDE_GROUPS + or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) + ): + + symbol = f"{group_name}.{symbol}" + # If renamed `symbol` already exists, add library name in front to differentiate between them. + if symbol in self.renamed_symbols: + # Split `package_name` because of packages like Pillow that have spaces in them. + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + + self.inventories[symbol] = absolute_doc_url + + log.trace(f"Fetched inventory for {package_name}.") + + async def refresh_inventory(self) -> None: + """Refresh internal documentation inventory.""" + log.debug("Refreshing documentation inventory...") + + # Clear the old base URLS and inventories to ensure + # that we start from a fresh local dataset. + # Also, reset the cache used for fetching documentation. + self.base_urls.clear() + self.inventories.clear() + self.renamed_symbols.clear() + async_cache.cache = OrderedDict() + + # Run all coroutines concurrently - since each of them performs a HTTP + # request, this speeds up fetching the inventory data heavily. + coros = [ + self.update_single( + package["package"], package["base_url"], package["inventory_url"] + ) for package in await self.bot.api_client.get('bot/documentation-links') + ] + await asyncio.gather(*coros) + + async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: + """ + Given a Python symbol, return its signature and description. + + The first tuple element is the signature of the given symbol as a markup-free string, and + the second tuple element is the description of the given symbol with HTML markup included. + + If the given symbol is a module, returns a tuple `(None, str)` + else if the symbol could not be found, returns `None`. + """ + url = self.inventories.get(symbol) + if url is None: + return None + + async with self.bot.http_session.get(url) as response: + html = await response.text(encoding='utf-8') + + # Find the signature header and parse the relevant parts. + symbol_id = url.split('#')[-1] + soup = BeautifulSoup(html, 'lxml') + symbol_heading = soup.find(id=symbol_id) + search_html = str(soup) + + if symbol_heading is None: + return None + + if symbol_id == f"module-{symbol}": + # Get page content from the module headerlink to the + # first tag that has its class in `SEARCH_END_TAG_ATTRS` + start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) + if start_tag is None: + return [], "" + + end_tag = start_tag.find_next(self._match_end_tag) + if end_tag is None: + return [], "" + + description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) + description_end_index = search_html.find(str(end_tag)) + description = search_html[description_start_index:description_end_index] + signatures = None + + else: + signatures = [] + description = str(symbol_heading.find_next_sibling("dd")) + description_pos = search_html.find(description) + # Get text of up to 3 signatures, remove unwanted symbols + for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): + signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + if signature and search_html.find(str(element)) < description_pos: + signatures.append(signature) + + return signatures, description.replace('¶', '') + + @async_cache(arg_offset=1) + async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: + """ + Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. + + If the symbol is known, an Embed with documentation about it is returned. + """ + scraped_html = await self.get_symbol_html(symbol) + if scraped_html is None: + return None + + signatures = scraped_html[0] + permalink = self.inventories[symbol] + description = markdownify(scraped_html[1]) + + # Truncate the description of the embed to the last occurrence + # of a double newline (interpreted as a paragraph) before index 1000. + if len(description) > 1000: + shortened = description[:1000] + description_cutoff = shortened.rfind('\n\n', 100) + if description_cutoff == -1: + # Search the shortened version for cutoff points in decreasing desirability, + # cutoff at 1000 if none are found. + for string in (". ", ", ", ",", " "): + description_cutoff = shortened.rfind(string) + if description_cutoff != -1: + break + else: + description_cutoff = 1000 + description = description[:description_cutoff] + + # If there is an incomplete code block, cut it out + if description.count("```") % 2: + codeblock_start = description.rfind('```py') + description = description[:codeblock_start].rstrip() + description += f"... [read more]({permalink})" + + description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) + if signatures is None: + # If symbol is a module, don't show signature. + embed_description = description + + elif not signatures: + # It's some "meta-page", for example: + # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views + embed_description = "This appears to be a generic page not tied to a specific symbol." + + else: + embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) + embed_description += f"\n{description}" + + embed = discord.Embed( + title=f'`{symbol}`', + url=permalink, + description=embed_description + ) + # Show all symbols with the same name that were renamed in the footer. + embed.set_footer( + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) + ) + return embed + + @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) + async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """Lookup documentation for Python symbols.""" + await ctx.invoke(self.get_command, symbol) + + @docs_group.command(name='get', aliases=('g',)) + async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """ + Return a documentation embed for a given symbol. + + If no symbol is given, return a list of all available inventories. + + Examples: + !docs + !docs aiohttp + !docs aiohttp.ClientSession + !docs get aiohttp.ClientSession + """ + if symbol is None: + inventory_embed = discord.Embed( + title=f"All inventories (`{len(self.base_urls)}` total)", + colour=discord.Colour.blue() + ) + + lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) + if self.base_urls: + await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) + + else: + inventory_embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=inventory_embed) + + else: + # Fetching documentation for a symbol (at least for the first time, since + # caching is used) takes quite some time, so let's send typing to indicate + # that we got the command, but are still working on it. + async with ctx.typing(): + doc_embed = await self.get_symbol_embed(symbol) + + if doc_embed is None: + error_embed = discord.Embed( + description=f"Sorry, I could not find any documentation for `{symbol}`.", + colour=discord.Colour.red() + ) + error_message = await ctx.send(embed=error_embed) + with suppress(NotFound): + await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) + await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) + else: + await ctx.send(embed=doc_embed) + + @docs_group.command(name='set', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def set_command( + self, ctx: commands.Context, package_name: ValidPythonIdentifier, + base_url: ValidURL, inventory_url: InventoryURL + ) -> None: + """ + Adds a new documentation metadata object to the site's database. + + The database will update the object, should an existing item with the specified `package_name` already exist. + + Example: + !docs set \ + python \ + https://docs.python.org/3/ \ + https://docs.python.org/3/objects.inv + """ + body = { + 'package': package_name, + 'base_url': base_url, + 'inventory_url': inventory_url + } + await self.bot.api_client.post('bot/documentation-links', json=body) + + log.info( + f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" + f"Package name: {package_name}\n" + f"Base url: {base_url}\n" + f"Inventory URL: {inventory_url}" + ) + + # Rebuilding the inventory can take some time, so lets send out a + # typing event to show that the Bot is still working. + async with ctx.typing(): + await self.refresh_inventory() + await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") + + @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: + """ + Removes the specified package from the database. + + Examples: + !docs delete aiohttp + """ + await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') + + async with ctx.typing(): + # Rebuild the inventory to ensure that everything + # that was from this package is properly deleted. + await self.refresh_inventory() + await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + + @docs_group.command(name="refresh", aliases=("rfsh", "r")) + @with_role(*MODERATION_ROLES) + async def refresh_command(self, ctx: commands.Context) -> None: + """Refresh inventories and send differences to channel.""" + old_inventories = set(self.base_urls) + with ctx.typing(): + await self.refresh_inventory() + # Get differences of added and removed inventories + added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) + if added: + added = f"+ {added}" + + removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) + if removed: + removed = f"- {removed}" + + embed = discord.Embed( + title="Inventories refreshed", + description=f"```diff\n{added}\n{removed}```" if added or removed else "" + ) + await ctx.send(embed=embed) + + async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: + """Get and return inventory from `inventory_url`. If fetching fails, return None.""" + fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) + for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): + try: + package = await self.bot.loop.run_in_executor(None, fetch_func) + except ConnectTimeout: + log.error( + f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except ProtocolError: + log.error( + f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except HTTPError as e: + log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") + return None + except ConnectionError: + log.error(f"Couldn't establish connection to inventory {inventory_url}.") + return None + else: + return package + log.error(f"Fetching of inventory {inventory_url} failed.") + return None + + @staticmethod + def _match_end_tag(tag: Tag) -> bool: + """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" + for attr in SEARCH_END_TAG_ATTRS: + if attr in tag.get("class", ()): + return True + + return tag.name == "table" + + +def setup(bot: Bot) -> None: + """Load the Doc cog.""" + bot.add_cog(Doc(bot)) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py new file mode 100644 index 000000000..3d1d6fd10 --- /dev/null +++ b/bot/exts/info/help.py @@ -0,0 +1,375 @@ +import itertools +import logging +from asyncio import TimeoutError +from collections import namedtuple +from contextlib import suppress +from typing import List, Union + +from discord import Colour, Embed, Member, Message, NotFound, Reaction, User +from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand +from fuzzywuzzy import fuzz, process +from fuzzywuzzy.utils import full_process + +from bot import constants +from bot.constants import Channels, Emojis, STAFF_ROLES +from bot.decorators import redirect_output +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +COMMANDS_PER_PAGE = 8 +DELETE_EMOJI = Emojis.trashcan +PREFIX = constants.Bot.prefix + +Category = namedtuple("Category", ["name", "description", "cogs"]) + + +async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: + """ + Runs the cleanup for the help command. + + Adds the :trashcan: reaction that, when clicked, will delete the help message. + After a 300 second timeout, the reaction will be removed. + """ + def check(reaction: Reaction, user: User) -> bool: + """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" + return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id + + await message.add_reaction(DELETE_EMOJI) + + with suppress(NotFound): + try: + await bot.wait_for("reaction_add", check=check, timeout=300) + await message.delete() + except TimeoutError: + await message.remove_reaction(DELETE_EMOJI, bot.user) + + +class HelpQueryNotFound(ValueError): + """ + Raised when a HelpSession Query doesn't match a command or cog. + + Contains the custom attribute of ``possible_matches``. + + Instances of this object contain a dictionary of any command(s) that were close to matching the + query, where keys are the possible matched command names and values are the likeness match scores. + """ + + def __init__(self, arg: str, possible_matches: dict = None): + super().__init__(arg) + self.possible_matches = possible_matches + + +class CustomHelpCommand(HelpCommand): + """ + An interactive instance for the bot help command. + + Cogs can be grouped into custom categories. All cogs with the same category will be displayed + under a single category name in the help output. Custom categories are defined inside the cogs + as a class attribute named `category`. A description can also be specified with the attribute + `category_description`. If a description is not found in at least one cog, the default will be + the regular description (class docstring) of the first cog found in the category. + """ + + def __init__(self): + super().__init__(command_attrs={"help": "Shows help for bot commands"}) + + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) + async def command_callback(self, ctx: Context, *, command: str = None) -> None: + """Attempts to match the provided query with a valid command or cog.""" + # the only reason we need to tamper with this is because d.py does not support "categories", + # so we need to deal with them ourselves. + + bot = ctx.bot + + if command is None: + # quick and easy, send bot help if command is none + mapping = self.get_bot_mapping() + await self.send_bot_help(mapping) + return + + cog_matches = [] + description = None + for cog in bot.cogs.values(): + if hasattr(cog, "category") and cog.category == command: + cog_matches.append(cog) + if hasattr(cog, "category_description"): + description = cog.category_description + + if cog_matches: + category = Category(name=command, description=description, cogs=cog_matches) + await self.send_category_help(category) + return + + # it's either a cog, group, command or subcommand; let the parent class deal with it + await super().command_callback(ctx, command=command) + + async def get_all_help_choices(self) -> set: + """ + Get all the possible options for getting help in the bot. + + This will only display commands the author has permission to run. + + These include: + - Category names + - Cog names + - Group command names (and aliases) + - Command names (and aliases) + - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) + + Options and choices are case sensitive. + """ + # first get all commands including subcommands and full command name aliases + choices = set() + for command in await self.filter_commands(self.context.bot.walk_commands()): + # the the command or group name + choices.add(str(command)) + + if isinstance(command, Command): + # all aliases if it's just a command + choices.update(command.aliases) + else: + # otherwise we need to add the parent name in + choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) + + # all cog names + choices.update(self.context.bot.cogs) + + # all category names + choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) + return choices + + async def command_not_found(self, string: str) -> "HelpQueryNotFound": + """ + Handles when a query does not match a valid command, group, cog or category. + + Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. + """ + choices = await self.get_all_help_choices() + + # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty + # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters + if (processed := full_process(string)): + result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) + else: + result = [] + + return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) + + async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": + """ + Redirects the error to `command_not_found`. + + `command_not_found` deals with searching and getting best choices for both commands and subcommands. + """ + return await self.command_not_found(f"{command.qualified_name} {string}") + + async def send_error_message(self, error: HelpQueryNotFound) -> None: + """Send the error message to the channel.""" + embed = Embed(colour=Colour.red(), title=str(error)) + + if getattr(error, "possible_matches", None): + matches = "\n".join(f"`{match}`" for match in error.possible_matches) + embed.description = f"**Did you mean:**\n{matches}" + + await self.context.send(embed=embed) + + async def command_formatting(self, command: Command) -> Embed: + """ + Takes a command and turns it into an embed. + + It will add an author, command signature + help, aliases and a note if the user can't run the command. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + parent = command.full_parent_name + + name = str(command) if not parent else f"{parent} {command.name}" + command_details = f"**```{PREFIX}{name} {command.signature}```**\n" + + # show command aliases + aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) + if aliases: + command_details += f"**Can also use:** {aliases}\n\n" + + # check if the user is allowed to run this command + if not await command.can_run(self.context): + command_details += "***You cannot run this command.***\n\n" + + command_details += f"*{command.help or 'No details provided.'}*\n" + embed.description = command_details + + return embed + + async def send_command_help(self, command: Command) -> None: + """Send help for a single command.""" + embed = await self.command_formatting(command) + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: + """ + Formats the prefix, command name and signature, and short doc for an iterable of commands. + + return_as_list is helpful for passing these command details into the paginator as a list of command details. + """ + details = [] + for command in commands_: + signature = f" {command.signature}" if command.signature else "" + details.append( + f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" + ) + if return_as_list: + return details + else: + return "".join(details) + + async def send_group_help(self, group: Group) -> None: + """Sends help for a group command.""" + subcommands = group.commands + + if len(subcommands) == 0: + # no subcommands, just treat it like a regular command + await self.send_command_help(group) + return + + # remove commands that the user can't run and are hidden, and sort by name + commands_ = await self.filter_commands(subcommands, sort=True) + + embed = await self.command_formatting(group) + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n**Subcommands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + async def send_cog_help(self, cog: Cog) -> None: + """Send help for a cog.""" + # sort commands by name, and remove any the user cant run or are hidden. + commands_ = await self.filter_commands(cog.get_commands(), sort=True) + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n\n**Commands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def _category_key(command: Command) -> str: + """ + Returns a cog name of a given command for use as a key for `sorted` and `groupby`. + + A zero width space is used as a prefix for results with no cogs to force them last in ordering. + """ + if command.cog: + with suppress(AttributeError): + if command.cog.category: + return f"**{command.cog.category}**" + return f"**{command.cog_name}**" + else: + return "**\u200bNo Category:**" + + async def send_category_help(self, category: Category) -> None: + """ + Sends help for a bot category. + + This sends a brief help for all commands in all cogs registered to the category. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + all_commands = [] + for cog in category.cogs: + all_commands.extend(cog.get_commands()) + + filtered_commands = await self.filter_commands(all_commands, sort=True) + + command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) + description = f"**{category.name}**\n*{category.description}*" + + if command_detail_lines: + description += "\n\n**Commands:**" + + await LinePaginator.paginate( + command_detail_lines, + self.context, + embed, + prefix=description, + max_lines=COMMANDS_PER_PAGE, + max_size=2000, + ) + + async def send_bot_help(self, mapping: dict) -> None: + """Sends help for all bot commands and cogs.""" + bot = self.context.bot + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) + + cog_or_category_pages = [] + + for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): + sorted_commands = sorted(_commands, key=lambda c: c.name) + + if len(sorted_commands) == 0: + continue + + command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) + + # Split cogs or categories which have too many commands to fit in one page. + # The length of commands is included for later use when aggregating into pages for the paginator. + for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): + truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] + joined_lines = "".join(truncated_lines) + cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) + + pages = [] + counter = 0 + page = "" + for page_details, length in cog_or_category_pages: + counter += length + if counter > COMMANDS_PER_PAGE: + # force a new page on paginator even if it falls short of the max pages + # since we still want to group categories/cogs. + counter = length + pages.append(page) + page = f"{page_details}\n\n" + else: + page += f"{page_details}\n\n" + + if page: + # add any remaining command help that didn't get added in the last iteration above. + pages.append(page) + + await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) + + +class Help(Cog): + """Custom Embed Pagination Help feature.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.old_help_command = bot.help_command + bot.help_command = CustomHelpCommand() + bot.help_command.cog = self + + def cog_unload(self) -> None: + """Reset the help command when the cog is unloaded.""" + self.bot.help_command = self.old_help_command + + +def setup(bot: Bot) -> None: + """Load the Help cog.""" + bot.add_cog(Help(bot)) + log.info("Cog loaded: Help") diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py new file mode 100644 index 000000000..8982196d1 --- /dev/null +++ b/bot/exts/info/information.py @@ -0,0 +1,422 @@ +import colorsys +import logging +import pprint +import textwrap +from collections import Counter, defaultdict +from string import Template +from typing import Any, Mapping, Optional, Union + +from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils +from discord.abc import GuildChannel +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group +from discord.utils import escape_markdown + +from bot import constants +from bot.bot import Bot +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + + +class Information(Cog): + """A cog with commands for generating embeds with server info, such as server stats and user info.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @staticmethod + def role_can_read(channel: GuildChannel, role: Role) -> bool: + """Return True if `role` can read messages in `channel`.""" + overwrites = channel.overwrites_for(role) + return overwrites.read_messages is True + + def get_staff_channel_count(self, guild: Guild) -> int: + """ + Get the number of channels that are staff-only. + + We need to know two things about a channel: + - Does the @everyone role have explicit read deny permissions? + - Do staff roles have explicit read allow permissions? + + If the answer to both of these questions is yes, it's a staff channel. + """ + channel_ids = set() + for channel in guild.channels: + if channel.type is ChannelType.category: + continue + + everyone_can_read = self.role_can_read(channel, guild.default_role) + + for role in constants.STAFF_ROLES: + role_can_read = self.role_can_read(channel, guild.get_role(role)) + if role_can_read and not everyone_can_read: + channel_ids.add(channel.id) + break + + return len(channel_ids) + + @staticmethod + def get_channel_type_counts(guild: Guild) -> str: + """Return the total amounts of the various types of channels in `guild`.""" + channel_counter = Counter(c.type for c in guild.channels) + channel_type_list = [] + for channel, count in channel_counter.items(): + channel_type = str(channel).title() + channel_type_list.append(f"{channel_type} channels: {count}") + + channel_type_list = sorted(channel_type_list) + return "\n".join(channel_type_list) + + @with_role(*constants.MODERATION_ROLES) + @command(name="roles") + async def roles_info(self, ctx: Context) -> None: + """Returns a list of all roles and their corresponding IDs.""" + # Sort the roles alphabetically and remove the @everyone role + roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) + + # Build a list + role_list = [] + for role in roles: + role_list.append(f"`{role.id}` - {role.mention}") + + # Build an embed + embed = Embed( + title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", + colour=Colour.blurple() + ) + + await LinePaginator.paginate(role_list, ctx, embed, empty=False) + + @with_role(*constants.MODERATION_ROLES) + @command(name="role") + async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: + """ + Return information on a role or list of roles. + + To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. + """ + parsed_roles = [] + failed_roles = [] + + for role_name in roles: + if isinstance(role_name, Role): + # Role conversion has already succeeded + parsed_roles.append(role_name) + continue + + role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) + + if not role: + failed_roles.append(role_name) + continue + + parsed_roles.append(role) + + if failed_roles: + await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") + + for role in parsed_roles: + h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) + + embed = Embed( + title=f"{role.name} info", + colour=role.colour, + ) + embed.add_field(name="ID", value=role.id, inline=True) + embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) + embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) + embed.add_field(name="Member count", value=len(role.members), inline=True) + embed.add_field(name="Position", value=role.position) + embed.add_field(name="Permission code", value=role.permissions.value, inline=True) + + await ctx.send(embed=embed) + + @command(name="server", aliases=["server_info", "guild", "guild_info"]) + async def server_info(self, ctx: Context) -> None: + """Returns an embed full of server information.""" + created = time_since(ctx.guild.created_at, precision="days") + features = ", ".join(ctx.guild.features) + region = ctx.guild.region + + roles = len(ctx.guild.roles) + member_count = ctx.guild.member_count + channel_counts = self.get_channel_type_counts(ctx.guild) + + # How many of each user status? + statuses = Counter(member.status for member in ctx.guild.members) + embed = Embed(colour=Colour.blurple()) + + # How many staff members and staff channels do we have? + staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) + staff_channel_count = self.get_staff_channel_count(ctx.guild) + + # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the + # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting + # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts + # after the dedent is made. + embed.description = Template( + textwrap.dedent(f""" + **Server information** + Created: {created} + Voice region: {region} + Features: {features} + + **Channel counts** + $channel_counts + Staff channels: {staff_channel_count} + + **Member counts** + Members: {member_count:,} + Staff members: {staff_member_count} + Roles: {roles} + + **Member statuses** + {constants.Emojis.status_online} {statuses[Status.online]:,} + {constants.Emojis.status_idle} {statuses[Status.idle]:,} + {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} + {constants.Emojis.status_offline} {statuses[Status.offline]:,} + """) + ).substitute({"channel_counts": channel_counts}) + embed.set_thumbnail(url=ctx.guild.icon_url) + + await ctx.send(embed=embed) + + @command(name="user", aliases=["user_info", "member", "member_info"]) + async def user_info(self, ctx: Context, user: Member = None) -> None: + """Returns info about a user.""" + if user is None: + user = ctx.author + + # Do a role check if this is being executed on someone other than the caller + elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): + await ctx.send("You may not use this command on users other than yourself.") + return + + # Non-staff may only do this in #bot-commands + if not with_role_check(ctx, *constants.STAFF_ROLES): + if not ctx.channel.id == constants.Channels.bot_commands: + raise InWhitelistCheckFailure(constants.Channels.bot_commands) + + embed = await self.create_user_embed(ctx, user) + + await ctx.send(embed=embed) + + async def create_user_embed(self, ctx: Context, user: Member) -> Embed: + """Creates an embed containing information on the `user`.""" + created = time_since(user.created_at, max_units=3) + + # Custom status + custom_status = '' + for activity in user.activities: + # Check activity.state for None value if user has a custom status set + # This guards against a custom status with an emoji but no text, which will cause + # escape_markdown to raise an exception + # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class + if activity.name == 'Custom Status' and activity.state: + state = escape_markdown(activity.state) + custom_status = f'Status: {state}\n' + + name = str(user) + if user.nick: + name = f"{user.nick} ({name})" + + joined = time_since(user.joined_at, max_units=3) + roles = ", ".join(role.mention for role in user.roles[1:]) + + description = [ + textwrap.dedent(f""" + **User Information** + Created: {created} + Profile: {user.mention} + ID: {user.id} + {custom_status} + **Member Information** + Joined: {joined} + Roles: {roles or None} + """).strip() + ] + + # Show more verbose output in moderation channels for infractions and nominations + if ctx.channel.id in constants.MODERATION_CHANNELS: + description.append(await self.expanded_user_infraction_counts(user)) + description.append(await self.user_nomination_counts(user)) + else: + description.append(await self.basic_user_infraction_counts(user)) + + # Let's build the embed now + embed = Embed( + title=name, + description="\n\n".join(description) + ) + + embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) + embed.colour = user.top_role.colour if roles else Colour.blurple() + + return embed + + async def basic_user_infraction_counts(self, member: Member) -> str: + """Gets the total and active infraction counts for the given `member`.""" + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'hidden': 'False', + 'user__id': str(member.id) + } + ) + + total_infractions = len(infractions) + active_infractions = sum(infraction['active'] for infraction in infractions) + + infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" + + return infraction_output + + async def expanded_user_infraction_counts(self, member: Member) -> str: + """ + Gets expanded infraction counts for the given `member`. + + The counts will be split by infraction type and the number of active infractions for each type will indicated + in the output as well. + """ + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'user__id': str(member.id) + } + ) + + infraction_output = ["**Infractions**"] + if not infractions: + infraction_output.append("This user has never received an infraction.") + else: + # Count infractions split by `type` and `active` status for this user + infraction_types = set() + infraction_counter = defaultdict(int) + for infraction in infractions: + infraction_type = infraction["type"] + infraction_active = 'active' if infraction["active"] else 'inactive' + + infraction_types.add(infraction_type) + infraction_counter[f"{infraction_active} {infraction_type}"] += 1 + + # Format the output of the infraction counts + for infraction_type in sorted(infraction_types): + active_count = infraction_counter[f"active {infraction_type}"] + total_count = active_count + infraction_counter[f"inactive {infraction_type}"] + + line = f"{infraction_type.capitalize()}s: {total_count}" + if active_count: + line += f" ({active_count} active)" + + infraction_output.append(line) + + return "\n".join(infraction_output) + + async def user_nomination_counts(self, member: Member) -> str: + """Gets the active and historical nomination counts for the given `member`.""" + nominations = await self.bot.api_client.get( + 'bot/nominations', + params={ + 'user__id': str(member.id) + } + ) + + output = ["**Nominations**"] + + if not nominations: + output.append("This user has never been nominated.") + else: + count = len(nominations) + is_currently_nominated = any(nomination["active"] for nomination in nominations) + nomination_noun = "nomination" if count == 1 else "nominations" + + if is_currently_nominated: + output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") + else: + output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") + + return "\n".join(output) + + def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: + """Format a mapping to be readable to a human.""" + # sorting is technically superfluous but nice if you want to look for a specific field + fields = sorted(mapping.items(), key=lambda item: item[0]) + + if field_width is None: + field_width = len(max(mapping.keys(), key=len)) + + out = '' + + for key, val in fields: + if isinstance(val, dict): + # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries + inner_width = int(field_width * 1.6) + val = '\n' + self.format_fields(val, field_width=inner_width) + + elif isinstance(val, str): + # split up text since it might be long + text = textwrap.fill(val, width=100, replace_whitespace=False) + + # indent it, I guess you could do this with `wrap` and `join` but this is nicer + val = textwrap.indent(text, ' ' * (field_width + len(': '))) + + # the first line is already indented so we `str.lstrip` it + val = val.lstrip() + + if key == 'color': + # makes the base 10 representation of a hex number readable to humans + val = hex(val) + + out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) + + # remove trailing whitespace + return out.rstrip() + + @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) + @group(invoke_without_command=True) + @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) + async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: + """Shows information about the raw API response.""" + # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling + # doing this extra request is also much easier than trying to convert everything back into a dictionary again + raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) + + paginator = Paginator() + + def add_content(title: str, content: str) -> None: + paginator.add_line(f'== {title} ==\n') + # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. + # we hope it's not close to 2000 + paginator.add_line(content.replace('```', '`` `')) + paginator.close_page() + + if message.content: + add_content('Raw message', message.content) + + transformer = pprint.pformat if json else self.format_fields + for field_name in ('embeds', 'attachments'): + data = raw_data[field_name] + + if not data: + continue + + total = len(data) + for current, item in enumerate(data, start=1): + title = f'Raw {field_name} ({current}/{total})' + add_content(title, transformer(item)) + + for page in paginator.pages: + await ctx.send(page) + + @raw.command() + async def json(self, ctx: Context, message: Message) -> None: + """Shows information about the raw API response in a copy-pasteable Python format.""" + await ctx.invoke(self.raw, message=message, json=True) + + +def setup(bot: Bot) -> None: + """Load the Information cog.""" + bot.add_cog(Information(bot)) diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py new file mode 100644 index 000000000..0ab5738a4 --- /dev/null +++ b/bot/exts/info/python_news.py @@ -0,0 +1,232 @@ +import logging +import typing as t +from datetime import date, datetime + +import discord +import feedparser +from bs4 import BeautifulSoup +from discord.ext.commands import Cog +from discord.ext.tasks import loop + +from bot import constants +from bot.bot import Bot +from bot.utils.webhooks import send_webhook + +PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" + +RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" +THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" +MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" +THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" + +AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + +log = logging.getLogger(__name__) + + +class PythonNews(Cog): + """Post new PEPs and Python News to `#python-news`.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_names = {} + self.webhook: t.Optional[discord.Webhook] = None + + self.bot.loop.create_task(self.get_webhook_names()) + self.bot.loop.create_task(self.get_webhook_and_channel()) + + async def start_tasks(self) -> None: + """Start the tasks for fetching new PEPs and mailing list messages.""" + self.fetch_new_media.start() + + @loop(minutes=20) + async def fetch_new_media(self) -> None: + """Fetch new mailing list messages and then new PEPs.""" + await self.post_maillist_news() + await self.post_pep_news() + + async def sync_maillists(self) -> None: + """Sync currently in-use maillists with API.""" + # Wait until guild is available to avoid running before everything is ready + await self.bot.wait_until_guild_available() + + response = await self.bot.api_client.get("bot/bot-settings/news") + for mail in constants.PythonNews.mail_lists: + if mail not in response["data"]: + response["data"][mail] = [] + + # Because we are handling PEPs differently, we don't include it to mail lists + if "pep" not in response["data"]: + response["data"]["pep"] = [] + + await self.bot.api_client.put("bot/bot-settings/news", json=response) + + async def get_webhook_names(self) -> None: + """Get webhook author names from maillist API.""" + await self.bot.wait_until_guild_available() + + async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: + lists = await resp.json() + + for mail in lists: + if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: + self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] + + async def post_pep_news(self) -> None: + """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" + # Wait until everything is ready and http_session available + await self.bot.wait_until_guild_available() + await self.sync_maillists() + + async with self.bot.http_session.get(PEPS_RSS_URL) as resp: + data = feedparser.parse(await resp.text("utf-8")) + + news_listing = await self.bot.api_client.get("bot/bot-settings/news") + payload = news_listing.copy() + pep_numbers = news_listing["data"]["pep"] + + # Reverse entries to send oldest first + data["entries"].reverse() + for new in data["entries"]: + try: + new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") + except ValueError: + log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") + continue + pep_nr = new["title"].split(":")[0].split()[1] + if ( + pep_nr in pep_numbers + or new_datetime.date() < date.today() + ): + continue + + # Build an embed and send a webhook + embed = discord.Embed( + title=new["title"], + description=new["summary"], + timestamp=new_datetime, + url=new["link"], + colour=constants.Colours.soft_green + ) + embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) + msg = await send_webhook( + webhook=self.webhook, + username=data["feed"]["title"], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"]["pep"].append(pep_nr) + + # Increase overall PEP new stat + self.bot.stats.incr("python_news.posted.pep") + + if msg.channel.is_news(): + log.trace("Publishing PEP annnouncement because it was in a news channel") + await msg.publish() + + # Apply new sent news to DB to avoid duplicate sending + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def post_maillist_news(self) -> None: + """Send new maillist threads to #python-news that is listed in configuration.""" + await self.bot.wait_until_guild_available() + await self.sync_maillists() + existing_news = await self.bot.api_client.get("bot/bot-settings/news") + payload = existing_news.copy() + + for maillist in constants.PythonNews.mail_lists: + async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: + recents = BeautifulSoup(await resp.text(), features="lxml") + + # When a

element is present in the response then the mailing list + # has not had any activity during the current month, so therefore it + # can be ignored. + if recents.p: + continue + + for thread in recents.html.body.div.find_all("a", href=True): + # We want only these threads that have identifiers + if "latest" in thread["href"]: + continue + + thread_information, email_information = await self.get_thread_and_first_mail( + maillist, thread["href"].split("/")[-2] + ) + + try: + new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") + except ValueError: + log.warning(f"Invalid datetime from Thread email: {email_information['date']}") + continue + + if ( + thread_information["thread_id"] in existing_news["data"][maillist] + or 'Re: ' in thread_information["subject"] + or new_date.date() < date.today() + ): + continue + + content = email_information["content"] + link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) + + # Build an embed and send a message to the webhook + embed = discord.Embed( + title=thread_information["subject"], + description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, + timestamp=new_date, + url=link, + colour=constants.Colours.soft_green + ) + embed.set_author( + name=f"{email_information['sender_name']} ({email_information['sender']['address']})", + url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), + ) + embed.set_footer( + text=f"Posted to {self.webhook_names[maillist]}", + icon_url=AVATAR_URL, + ) + msg = await send_webhook( + webhook=self.webhook, + username=self.webhook_names[maillist], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"][maillist].append(thread_information["thread_id"]) + + # Increase this specific maillist counter in stats + self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") + + if msg.channel.is_news(): + log.trace("Publishing mailing list message because it was in a news channel") + await msg.publish() + + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: + """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" + async with self.bot.http_session.get( + THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) + ) as resp: + thread_information = await resp.json() + + async with self.bot.http_session.get(thread_information["starting_email"]) as resp: + email_information = await resp.json() + return thread_information, email_information + + async def get_webhook_and_channel(self) -> None: + """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" + await self.bot.wait_until_guild_available() + self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) + + await self.start_tasks() + + def cog_unload(self) -> None: + """Stop news posting tasks on cog unload.""" + self.fetch_new_media.cancel() + + +def setup(bot: Bot) -> None: + """Add `News` cog.""" + bot.add_cog(PythonNews(bot)) diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py new file mode 100644 index 000000000..d853ab2ea --- /dev/null +++ b/bot/exts/info/reddit.py @@ -0,0 +1,304 @@ +import asyncio +import logging +import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List + +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks +from bot.converters import Subreddit +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils.messages import sub_clyde + +log = logging.getLogger(__name__) + +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + + +class Reddit(Cog): + """Track subreddit posts and show detailed statistics about them.""" + + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 + + def __init__(self, bot: Bot): + self.bot = bot + + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() + + def cog_unload(self) -> None: + """Stop the loop task and revoke the access token when the cog is unloaded.""" + self.auto_poster_loop.cancel() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) + + async def init_reddit_ready(self) -> None: + """Sets the reddit webhook when the cog is loaded.""" + await self.bot.wait_until_guild_available() + if not self.webhook: + self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) + + @property + def channel(self) -> TextChannel: + """Get the #reddit channel object from the bot's cache.""" + return self.bot.get_channel(Channels.reddit) + + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: + """A helper method to fetch a certain amount of Reddit posts at a given route.""" + # Reddit's JSON responses only provide 25 posts at most. + if not 25 >= amount > 0: + raise ValueError("Invalid amount of subreddit posts requested.") + + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.get( + url=url, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + params=params + ) + if response.status == 200 and response.content_type == 'application/json': + # Got appropriate response - process and return. + content = await response.json() + posts = content["data"]["children"] + return posts[:amount] + + await asyncio.sleep(3) + + log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") + return list() # Failed to get appropriate response within allowed number of retries. + + async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: + """ + Get the top amount of posts for a given subreddit within a specified timeframe. + + A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top + weekly posts. + + The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. + """ + embed = Embed(description="") + + posts = await self.fetch_posts( + route=f"{subreddit}/top", + amount=amount, + params={"t": time} + ) + + if not posts: + embed.title = random.choice(ERROR_REPLIES) + embed.colour = Colour.red() + embed.description = ( + "Sorry! We couldn't find any posts from that subreddit. " + "If this problem persists, please let us know." + ) + + return embed + + for post in posts: + data = post["data"] + + text = data["selftext"] + if text: + text = textwrap.shorten(text, width=128, placeholder="...") + text += "\n" # Add newline to separate embed info + + ups = data["ups"] + comments = data["num_comments"] + author = data["author"] + + title = textwrap.shorten(data["title"], width=64, placeholder="...") + link = self.URL + data["permalink"] + + embed.description += ( + f"**[{title}]({link})**\n" + f"{text}" + f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" + ) + + embed.colour = Colour.blurple() + return embed + + @loop() + async def auto_poster_loop(self) -> None: + """Post the top 5 posts daily, and the top 5 posts weekly.""" + # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter + now = datetime.utcnow() + tomorrow = now + timedelta(days=1) + midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + seconds_until = (midnight_tomorrow - now).total_seconds() + + await asyncio.sleep(seconds_until) + + await self.bot.wait_until_guild_available() + if not self.webhook: + await self.bot.fetch_webhook(Webhooks.reddit) + + if datetime.utcnow().weekday() == 0: + await self.top_weekly_posts() + # if it's a monday send the top weekly posts + + for subreddit in RedditConfig.subreddits: + top_posts = await self.get_top_posts(subreddit=subreddit, time="day") + username = sub_clyde(f"{subreddit} Top Daily Posts") + message = await self.webhook.send(username=username, embed=top_posts, wait=True) + + if message.channel.is_news(): + await message.publish() + + async def top_weekly_posts(self) -> None: + """Post a summary of the top posts.""" + for subreddit in RedditConfig.subreddits: + # Send and pin the new weekly posts. + top_posts = await self.get_top_posts(subreddit=subreddit, time="week") + username = sub_clyde(f"{subreddit} Top Weekly Posts") + message = await self.webhook.send(wait=True, username=username, embed=top_posts) + + if subreddit.lower() == "r/python": + if not self.channel: + log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") + return + + # Remove the oldest pins so that only 12 remain at most. + pins = await self.channel.pins() + + while len(pins) >= 12: + await pins[-1].unpin() + del pins[-1] + + await message.pin() + + if message.channel.is_news(): + await message.publish() + + @group(name="reddit", invoke_without_command=True) + async def reddit_group(self, ctx: Context) -> None: + """View the top posts from various subreddits.""" + await ctx.send_help(ctx.command) + + @reddit_group.command(name="top") + async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of all time from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="all") + + await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) + + @reddit_group.command(name="daily") + async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of today from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="day") + + await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) + + @reddit_group.command(name="weekly") + async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of this week from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="week") + + await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) + + @with_role(*STAFF_ROLES) + @reddit_group.command(name="subreddits", aliases=("subs",)) + async def subreddits_command(self, ctx: Context) -> None: + """Send a paginated embed of all the subreddits we're relaying.""" + embed = Embed() + embed.title = "Relayed subreddits." + embed.colour = Colour.blurple() + + await LinePaginator.paginate( + RedditConfig.subreddits, + ctx, embed, + footer_text="Use the reddit commands along with these to view their posts.", + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return + bot.add_cog(Reddit(bot)) diff --git a/bot/exts/info/site.py b/bot/exts/info/site.py new file mode 100644 index 000000000..ac29daa1d --- /dev/null +++ b/bot/exts/info/site.py @@ -0,0 +1,146 @@ +import logging + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import URLs +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" + + +class Site(Cog): + """Commands for linking to different parts of the site.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="site", aliases=("s",), invoke_without_command=True) + async def site_group(self, ctx: Context) -> None: + """Commands for getting info about our website.""" + await ctx.send_help(ctx.command) + + @site_group.command(name="home", aliases=("about",)) + async def site_main(self, ctx: Context) -> None: + """Info about the website itself.""" + url = f"{URLs.site_schema}{URLs.site}/" + + embed = Embed(title="Python Discord website") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + f"[Our official website]({url}) is an open-source community project " + "created with Python and Django. It contains information about the server " + "itself, lets you sign up for upcoming events, has its own wiki, contains " + "a list of valuable learning resources, and much more." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="resources") + async def site_resources(self, ctx: Context) -> None: + """Info about the site's Resources page.""" + learning_url = f"{PAGES_URL}/resources" + + embed = Embed(title="Resources") + embed.set_footer(text=f"{learning_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Resources page]({learning_url}) on our website contains a " + "list of hand-selected learning resources that we regularly recommend " + f"to both beginners and experts." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="tools") + async def site_tools(self, ctx: Context) -> None: + """Info about the site's Tools page.""" + tools_url = f"{PAGES_URL}/resources/tools" + + embed = Embed(title="Tools") + embed.set_footer(text=f"{tools_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Tools page]({tools_url}) on our website contains a " + f"couple of the most popular tools for programming in Python." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="help") + async def site_help(self, ctx: Context) -> None: + """Info about the site's Getting Help page.""" + url = f"{PAGES_URL}/resources/guides/asking-good-questions" + + embed = Embed(title="Asking Good Questions") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "Asking the right question about something that's new to you can sometimes be tricky. " + f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " + "It contains everything you need to get the very best help from our community." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="faq") + async def site_faq(self, ctx: Context) -> None: + """Info about the site's FAQ page.""" + url = f"{PAGES_URL}/frequently-asked-questions" + + embed = Embed(title="FAQ") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "As the largest Python community on Discord, we get hundreds of questions every day. " + "Many of these questions have been asked before. We've compiled a list of the most " + "frequently asked questions along with their answers, which can be found on " + f"our [FAQ page]({url})." + ) + + await ctx.send(embed=embed) + + @site_group.command(aliases=['r', 'rule'], name='rules') + async def site_rules(self, ctx: Context, *rules: int) -> None: + """Provides a link to all rules or, if specified, displays specific rule(s).""" + rules_embed = Embed(title='Rules', color=Colour.blurple()) + rules_embed.url = f"{PAGES_URL}/rules" + + if not rules: + # Rules were not submitted. Return the default description. + rules_embed.description = ( + "The rules and guidelines that apply to this community can be found on" + f" our [rules page]({PAGES_URL}/rules). We expect" + " all members of the community to have read and understood these." + ) + + await ctx.send(embed=rules_embed) + return + + full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) + invalid_indices = tuple( + pick + for pick in rules + if pick < 1 or pick > len(full_rules) + ) + + if invalid_indices: + indices = ', '.join(map(str, invalid_indices)) + await ctx.send(f":x: Invalid rule indices: {indices}") + return + + for rule in rules: + self.bot.stats.incr(f"rule_uses.{rule}") + + final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) + + await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) + + +def setup(bot: Bot) -> None: + """Load the Site cog.""" + bot.add_cog(Site(bot)) diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py new file mode 100644 index 000000000..205e0ba81 --- /dev/null +++ b/bot/exts/info/source.py @@ -0,0 +1,141 @@ +import inspect +from pathlib import Path +from typing import Optional, Tuple, Union + +from discord import Embed +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import URLs + +SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] + + +class SourceConverter(commands.Converter): + """Convert an argument into a help command, tag, command, or cog.""" + + async def convert(self, ctx: commands.Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower().startswith("help"): + return ctx.bot.help_command + + cog = ctx.bot.get_cog(argument) + if cog: + return cog + + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd + + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + elif argument.lower() in tags_cog._cache: + return argument.lower() + + raise commands.BadArgument( + f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." + ) + + +class BotSource(commands.Cog): + """Displays information about the bot's source code.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command(name="source", aliases=("src",)) + async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: + """Display information and a GitHub link to the source code of a command, tag, or cog.""" + if not source_item: + embed = Embed(title="Bot's GitHub Repository") + embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") + embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") + await ctx.send(embed=embed) + return + + embed = await self.build_embed(source_item) + await ctx.send(embed=embed) + + def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: + """ + Build GitHub link of source item, return this link, file location and first line number. + + Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). + """ + if isinstance(source_item, commands.Command): + if source_item.cog_name == "Alias": + cmd_name = source_item.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + src = cmd.callback.__code__ + filename = src.co_filename + else: + src = source_item.callback.__code__ + filename = src.co_filename + elif isinstance(source_item, str): + tags_cog = self.bot.get_cog("Tags") + filename = tags_cog._cache[source_item]["location"] + else: + src = type(source_item) + try: + filename = inspect.getsourcefile(src) + except TypeError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + if not isinstance(source_item, str): + try: + lines, first_line_no = inspect.getsourcelines(src) + except OSError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" + else: + first_line_no = None + lines_extension = "" + + # Handle tag file location differently than others to avoid errors in some cases + if not first_line_no: + file_location = Path(filename).relative_to("/bot/") + else: + file_location = Path(filename).relative_to(Path.cwd()).as_posix() + + url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" + + return url, file_location, first_line_no or None + + async def build_embed(self, source_object: SourceType) -> Optional[Embed]: + """Build embed based on source object.""" + url, location, first_line = self.get_source_link(source_object) + + if isinstance(source_object, commands.HelpCommand): + title = "Help Command" + description = source_object.__doc__.splitlines()[1] + elif isinstance(source_object, commands.Command): + if source_object.cog_name == "Alias": + cmd_name = source_object.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + description = cmd.short_doc + else: + description = source_object.short_doc + + title = f"Command: {source_object.qualified_name}" + elif isinstance(source_object, str): + title = f"Tag: {source_object}" + description = "" + else: + title = f"Cog: {source_object.qualified_name}" + description = source_object.description.splitlines()[0] + + embed = Embed(title=title, description=description) + embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") + line_text = f":{first_line}" if first_line else "" + embed.set_footer(text=f"{location}{line_text}") + + return embed + + +def setup(bot: Bot) -> None: + """Load the BotSource cog.""" + bot.add_cog(BotSource(bot)) diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py new file mode 100644 index 000000000..d42f55466 --- /dev/null +++ b/bot/exts/info/stats.py @@ -0,0 +1,129 @@ +import string +from datetime import datetime + +from discord import Member, Message, Status +from discord.ext.commands import Cog, Context +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Categories, Channels, Guild, Stats as StatConf + + +CHANNEL_NAME_OVERRIDES = { + Channels.off_topic_0: "off_topic_0", + Channels.off_topic_1: "off_topic_1", + Channels.off_topic_2: "off_topic_2", + Channels.staff_lounge: "staff_lounge" +} + +ALLOWED_CHARS = string.ascii_letters + string.digits + "_" + + +class Stats(Cog): + """A cog which provides a way to hook onto Discord events and forward to stats.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.last_presence_update = None + self.update_guild_boost.start() + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Report message events in the server to statsd.""" + if message.guild is None: + return + + if message.guild.id != Guild.id: + return + + cat = getattr(message.channel, "category", None) + if cat is not None and cat.id == Categories.modmail: + if message.channel.id != Channels.incidents: + # Do not report modmail channels to stats, there are too many + # of them for interesting statistics to be drawn out of this. + return + + reformatted_name = message.channel.name.replace('-', '_') + + if CHANNEL_NAME_OVERRIDES.get(message.channel.id): + reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) + + reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) + + stat_name = f"channels.{reformatted_name}" + self.bot.stats.incr(stat_name) + + # Increment the total message count + self.bot.stats.incr("messages") + + @Cog.listener() + async def on_command_completion(self, ctx: Context) -> None: + """Report completed commands to statsd.""" + command_name = ctx.command.qualified_name.replace(" ", "_") + + self.bot.stats.incr(f"commands.{command_name}") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Update member count stat on member join.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_leave(self, member: Member) -> None: + """Update member count stat on member leave.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_update(self, _before: Member, after: Member) -> None: + """Update presence estimates on member update.""" + if after.guild.id != Guild.id: + return + + if self.last_presence_update: + if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: + return + + self.last_presence_update = datetime.now() + + online = 0 + idle = 0 + dnd = 0 + offline = 0 + + for member in after.guild.members: + if member.status is Status.online: + online += 1 + elif member.status is Status.dnd: + dnd += 1 + elif member.status is Status.idle: + idle += 1 + elif member.status is Status.offline: + offline += 1 + + self.bot.stats.gauge("guild.status.online", online) + self.bot.stats.gauge("guild.status.idle", idle) + self.bot.stats.gauge("guild.status.do_not_disturb", dnd) + self.bot.stats.gauge("guild.status.offline", offline) + + @loop(hours=1) + async def update_guild_boost(self) -> None: + """Post the server boost level and tier every hour.""" + await self.bot.wait_until_guild_available() + g = self.bot.get_guild(Guild.id) + self.bot.stats.gauge("boost.amount", g.premium_subscription_count) + self.bot.stats.gauge("boost.tier", g.premium_tier) + + def cog_unload(self) -> None: + """Stop the boost statistic task on unload of the Cog.""" + self.update_guild_boost.stop() + + +def setup(bot: Bot) -> None: + """Load the stats cog.""" + bot.add_cog(Stats(bot)) diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py new file mode 100644 index 000000000..3d76c5c08 --- /dev/null +++ b/bot/exts/info/tags.py @@ -0,0 +1,277 @@ +import logging +import re +import time +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot import constants +from bot.bot import Bot +from bot.converters import TagNameConverter +from bot.pagination import LinePaginator +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +TEST_CHANNELS = ( + constants.Channels.bot_commands, + constants.Channels.helpers +) + +REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) +FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." + + +class Tags(Cog): + """Save new tags and fetch existing tags.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.tag_cooldowns = {} + self._cache = self.get_tags() + + @staticmethod + def get_tags() -> dict: + """Get all tags.""" + cache = {} + + base_path = Path("bot", "resources", "tags") + for file in base_path.glob("**/*"): + if file.is_file(): + tag_title = file.stem + tag = { + "title": tag_title, + "embed": { + "description": file.read_text(encoding="utf8"), + }, + "restricted_to": "developers", + "location": f"/bot/{file}" + } + + # Convert to a list to allow negative indexing. + parents = list(file.relative_to(base_path).parents) + if len(parents) > 1: + # -1 would be '.' hence -2 is used as the index. + tag["restricted_to"] = parents[-2].name + + cache[tag_title] = tag + + return cache + + @staticmethod + def check_accessibility(user: Member, tag: dict) -> bool: + """Check if user can access a tag.""" + return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] + + @staticmethod + def _fuzzy_search(search: str, target: str) -> float: + """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" + current, index = 0, 0 + _search = REGEX_NON_ALPHABET.sub('', search.lower()) + _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + _target = next(_targets) + try: + while True: + while index < len(_target) and _search[current] == _target[index]: + current += 1 + index += 1 + index, _target = 0, next(_targets) + except (StopIteration, IndexError): + pass + return current / len(_search) * 100 + + def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: + """Return a list of suggested tags.""" + scores: Dict[str, int] = { + tag_title: Tags._fuzzy_search(tag_name, tag['title']) + for tag_title, tag in self._cache.items() + } + + thresholds = thresholds or [100, 90, 80, 70, 60] + + for threshold in thresholds: + suggestions = [ + self._cache[tag_title] + for tag_title, matching_score in scores.items() + if matching_score >= threshold + ] + if suggestions: + return suggestions + + return [] + + def _get_tag(self, tag_name: str) -> list: + """Get a specific tag.""" + found = [self._cache.get(tag_name.lower(), None)] + if not found[0]: + return self._get_suggestions(tag_name) + return found + + def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: + """ + Search for tags via contents. + + `predicate` will be the built-in any, all, or a custom callable. Must return a bool. + """ + keywords_processed: List[str] = [] + for keyword in keywords.split(','): + keyword_sanitized = keyword.strip().casefold() + if not keyword_sanitized: + # this happens when there are leading / trailing / consecutive comma. + continue + keywords_processed.append(keyword_sanitized) + + if not keywords_processed: + # after sanitizing, we can end up with an empty list, for example when keywords is ',' + # in that case, we simply want to search for such keywords directly instead. + keywords_processed = [keywords] + + matching_tags = [] + for tag in self._cache.values(): + matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) + if self.check_accessibility(user, tag) and check(matches): + matching_tags.append(tag) + + return matching_tags + + async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: + """Send the result of matching tags to user.""" + if not matching_tags: + pass + elif len(matching_tags) == 1: + await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) + else: + is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 + embed = Embed( + title=f"Here are the tags containing the given keyword{'s' * is_plural}:", + description='\n'.join(tag['title'] for tag in matching_tags[:10]) + ) + await LinePaginator.paginate( + sorted(f"**»** {tag['title']}" for tag in matching_tags), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) + async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Show all known tags, a single tag, or run a subcommand.""" + await ctx.invoke(self.get_command, tag_name=tag_name) + + @tags_group.group(name='search', invoke_without_command=True) + async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Only search for tags that has ALL the keywords. + """ + matching_tags = self._get_tags_via_content(all, keywords, ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @search_tag_content.command(name='any') + async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Search for tags that has ANY of the keywords. + """ + matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @tags_group.command(name='get', aliases=('show', 'g')) + async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Get a specified tag, or a list of all tags if no tag is specified.""" + + def _command_on_cooldown(tag_name: str) -> bool: + """ + Check if the command is currently on cooldown, on a per-tag, per-channel basis. + + The cooldown duration is set in constants.py. + """ + now = time.time() + + cooldown_conditions = ( + tag_name + and tag_name in self.tag_cooldowns + and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags + and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id + ) + + if cooldown_conditions: + return True + return False + + if _command_on_cooldown(tag_name): + time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] + time_left = constants.Cooldowns.tags - time_elapsed + log.info( + f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " + f"Cooldown ends in {time_left:.1f} seconds." + ) + return + + if tag_name is not None: + temp_founds = self._get_tag(tag_name) + + founds = [] + + for found_tag in temp_founds: + if self.check_accessibility(ctx.author, found_tag): + founds.append(found_tag) + + if len(founds) == 1: + tag = founds[0] + if ctx.channel.id not in TEST_CHANNELS: + self.tag_cooldowns[tag_name] = { + "time": time.time(), + "channel": ctx.channel.id + } + + self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") + + await wait_for_deletion( + await ctx.send(embed=Embed.from_dict(tag['embed'])), + [ctx.author.id], + client=self.bot + ) + elif founds and len(tag_name) >= 3: + await wait_for_deletion( + await ctx.send( + embed=Embed( + title='Did you mean ...', + description='\n'.join(tag['title'] for tag in founds[:10]) + ) + ), + [ctx.author.id], + client=self.bot + ) + + else: + tags = self._cache.values() + if not tags: + await ctx.send(embed=Embed( + description="**There are no tags in the database!**", + colour=Colour.red() + )) + else: + embed: Embed = Embed(title="**Current tags**") + await LinePaginator.paginate( + sorted( + f"**»** {tag['title']}" for tag in tags + if self.check_accessibility(ctx.author, tag) + ), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Tags cog.""" + bot.add_cog(Tags(bot)) diff --git a/bot/exts/info/wolfram.py b/bot/exts/info/wolfram.py new file mode 100644 index 000000000..e6cae3bb8 --- /dev/null +++ b/bot/exts/info/wolfram.py @@ -0,0 +1,280 @@ +import logging +from io import BytesIO +from typing import Callable, List, Optional, Tuple +from urllib import parse + +import discord +from dateutil.relativedelta import relativedelta +from discord import Embed +from discord.ext import commands +from discord.ext.commands import BucketType, Cog, Context, check, group + +from bot.bot import Bot +from bot.constants import Colours, STAFF_ROLES, Wolfram +from bot.pagination import ImagePaginator +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +APPID = Wolfram.key +DEFAULT_OUTPUT_FORMAT = "JSON" +QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" +WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" + +MAX_PODS = 20 + +# Allows for 10 wolfram calls pr user pr day +usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) + +# Allows for max api requests / days in month per day for the entire guild (Temporary) +guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) + + +async def send_embed( + ctx: Context, + message_txt: str, + colour: int = Colours.soft_red, + footer: str = None, + img_url: str = None, + f: discord.File = None +) -> None: + """Generate & send a response embed with Wolfram as the author.""" + embed = Embed(colour=colour) + embed.description = message_txt + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + if footer: + embed.set_footer(text=footer) + + if img_url: + embed.set_image(url=img_url) + + await ctx.send(embed=embed, file=f) + + +def custom_cooldown(*ignore: List[int]) -> Callable: + """ + Implement per-user and per-guild cooldowns for requests to the Wolfram API. + + A list of roles may be provided to ignore the per-user cooldown + """ + async def predicate(ctx: Context) -> bool: + if ctx.invoked_with == 'help': + # if the invoked command is help we don't want to increase the ratelimits since it's not actually + # invoking the command/making a request, so instead just check if the user/guild are on cooldown. + guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown + if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored + return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 + return guild_cooldown + + user_bucket = usercd.get_bucket(ctx.message) + + if all(role.id not in ignore for role in ctx.author.roles): + user_rate = user_bucket.update_rate_limit() + + if user_rate: + # Can't use api; cause: member limit + delta = relativedelta(seconds=int(user_rate)) + cooldown = humanize_delta(delta) + message = ( + "You've used up your limit for Wolfram|Alpha requests.\n" + f"Cooldown: {cooldown}" + ) + await send_embed(ctx, message) + return False + + guild_bucket = guildcd.get_bucket(ctx.message) + guild_rate = guild_bucket.update_rate_limit() + + # Repr has a token attribute to read requests left + log.debug(guild_bucket) + + if guild_rate: + # Can't use api; cause: guild limit + message = ( + "The max limit of requests for the server has been reached for today.\n" + f"Cooldown: {int(guild_rate)}" + ) + await send_embed(ctx, message) + return False + + return True + return check(predicate) + + +async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: + """Get the Wolfram API pod pages for the provided query.""" + async with ctx.channel.typing(): + url_str = parse.urlencode({ + "input": query, + "appid": APPID, + "output": DEFAULT_OUTPUT_FORMAT, + "format": "image,plaintext" + }) + request_url = QUERY.format(request="query", data=url_str) + + async with bot.http_session.get(request_url) as response: + json = await response.json(content_type='text/plain') + + result = json["queryresult"] + + if result["error"]: + # API key not set up correctly + if result["error"]["msg"] == "Invalid appid": + message = "Wolfram API key is invalid or missing." + log.warning( + "API key seems to be missing, or invalid when " + f"processing a wolfram request: {url_str}, Response: {json}" + ) + await send_embed(ctx, message) + return + + message = "Something went wrong internally with your request, please notify staff!" + log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") + await send_embed(ctx, message) + return + + if not result["success"]: + message = f"I couldn't find anything for {query}." + await send_embed(ctx, message) + return + + if not result["numpods"]: + message = "Could not find any results." + await send_embed(ctx, message) + return + + pods = result["pods"] + pages = [] + for pod in pods[:MAX_PODS]: + subs = pod.get("subpods") + + for sub in subs: + title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") + img = sub["img"]["src"] + pages.append((title, img)) + return pages + + +class Wolfram(Cog): + """Commands for interacting with the Wolfram|Alpha API.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_command(self, ctx: Context, *, query: str) -> None: + """Requests all answers on a single image, sends an image of all related pods.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="simple", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + image_bytes = await response.read() + + f = discord.File(BytesIO(image_bytes), filename="image.png") + image_url = "attachment://image.png" + + if status == 501: + message = "Failed to get response" + footer = "" + color = Colours.soft_red + elif status == 400: + message = "No input found" + footer = "" + color = Colours.soft_red + elif status == 403: + message = "Wolfram API key is invalid or missing." + footer = "" + color = Colours.soft_red + else: + message = "" + footer = "View original for a bigger picture." + color = Colours.soft_orange + + # Sends a "blank" embed if no request is received, unsure how to fix + await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) + + @wolfram_command.command(name="page", aliases=("pa", "p")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + embed = Embed() + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + embed.colour = Colours.soft_orange + + await ImagePaginator.paginate(pages, ctx, embed) + + @wolfram_command.command(name="cut", aliases=("c",)) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + if len(pages) >= 2: + page = pages[1] + else: + page = pages[0] + + await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) + + @wolfram_command.command(name="short", aliases=("sh", "s")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: + """Requests an answer to a simple question.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="result", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + response_text = await response.text() + + if status == 501: + message = "Failed to get response" + color = Colours.soft_red + elif status == 400: + message = "No input found" + color = Colours.soft_red + elif response_text == "Error 1: Invalid appid": + message = "Wolfram API key is invalid or missing." + color = Colours.soft_red + else: + message = response_text + color = Colours.soft_orange + + await send_embed(ctx, message, color) + + +def setup(bot: Bot) -> None: + """Load the Wolfram cog.""" + bot.add_cog(Wolfram(bot)) diff --git a/bot/exts/moderation/__init__.py b/bot/exts/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py new file mode 100644 index 000000000..b75a4dcfe --- /dev/null +++ b/bot/exts/moderation/defcon.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import logging +from collections import namedtuple +from datetime import datetime, timedelta +from enum import Enum + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles +from bot.decorators import with_role +from bot.exts.moderation.modlog import ModLog + +log = logging.getLogger(__name__) + +REJECTION_MESSAGE = """ +Hi, {user} - Thanks for your interest in our server! + +Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since +your account is relatively new, we're unable to provide access to the server at this time. + +Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation +will be resolved soon. In the meantime, please feel free to peruse the resources on our site at +, and have a nice day! +""" + +BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" + + +class Action(Enum): + """Defcon Action.""" + + ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) + + ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") + DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") + UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") + + +class Defcon(Cog): + """Time-sensitive server defense mechanisms.""" + + days = None # type: timedelta + enabled = False # type: bool + + def __init__(self, bot: Bot): + self.bot = bot + self.channel = None + self.days = timedelta(days=0) + + self.bot.loop.create_task(self.sync_settings()) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def sync_settings(self) -> None: + """On cog load, try to synchronize DEFCON settings to the API.""" + await self.bot.wait_until_guild_available() + self.channel = await self.bot.fetch_channel(Channels.defcon) + + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + except Exception: # Yikes! + log.exception("Unable to get DEFCON settings!") + await self.bot.get_channel(Channels.dev_log).send( + f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" + ) + + else: + if data["enabled"]: + self.enabled = True + self.days = timedelta(days=data["days"]) + log.info(f"DEFCON enabled: {self.days.days} days") + + else: + self.enabled = False + self.days = timedelta(days=0) + log.info("DEFCON disabled") + + await self.update_channel_topic() + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" + if self.enabled and self.days.days > 0: + now = datetime.utcnow() + + if now - member.created_at < self.days: + log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") + + message_sent = False + + try: + await member.send(REJECTION_MESSAGE.format(user=member.mention)) + + message_sent = True + except Exception: + log.exception(f"Unable to send rejection message to user: {member}") + + await member.kick(reason="DEFCON active, user is too new") + self.bot.stats.incr("defcon.leaves") + + message = ( + f"{member} (`{member.id}`) was denied entry because their account is too new." + ) + + if not message_sent: + message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." + + await self.mod_log.send_log_message( + Icons.defcon_denied, Colours.soft_red, "Entry denied", + message, member.avatar_url_as(static_format="png") + ) + + @group(name='defcon', aliases=('dc',), invoke_without_command=True) + @with_role(Roles.admins, Roles.owners) + async def defcon_group(self, ctx: Context) -> None: + """Check the DEFCON status or run a subcommand.""" + await ctx.send_help(ctx.command) + + async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: + """Providing a structured way to do an defcon action.""" + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + if "enable_date" in data and action is Action.DISABLED: + enabled = datetime.fromisoformat(data["enable_date"]) + + delta = datetime.now() - enabled + + self.bot.stats.timing("defcon.enabled", delta) + except Exception: + pass + + error = None + try: + await self.bot.api_client.put( + 'bot/bot-settings/defcon', + json={ + 'name': 'defcon', + 'data': { + # TODO: retrieve old days count + 'days': days, + 'enabled': action is not Action.DISABLED, + 'enable_date': datetime.now().isoformat() + } + } + ) + except Exception as err: + log.exception("Unable to update DEFCON settings.") + error = err + finally: + await ctx.send(self.build_defcon_msg(action, error)) + await self.send_defcon_log(action, ctx.author, error) + + self.bot.stats.gauge("defcon.threshold", days) + + @defcon_group.command(name='enable', aliases=('on', 'e')) + @with_role(Roles.admins, Roles.owners) + async def enable_command(self, ctx: Context) -> None: + """ + Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! + + Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, + in days. + """ + self.enabled = True + await self._defcon_action(ctx, days=0, action=Action.ENABLED) + await self.update_channel_topic() + + @defcon_group.command(name='disable', aliases=('off', 'd')) + @with_role(Roles.admins, Roles.owners) + async def disable_command(self, ctx: Context) -> None: + """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" + self.enabled = False + await self._defcon_action(ctx, days=0, action=Action.DISABLED) + await self.update_channel_topic() + + @defcon_group.command(name='status', aliases=('s',)) + @with_role(Roles.admins, Roles.owners) + async def status_command(self, ctx: Context) -> None: + """Check the current status of DEFCON mode.""" + embed = Embed( + colour=Colour.blurple(), title="DEFCON Status", + description=f"**Enabled:** {self.enabled}\n" + f"**Days:** {self.days.days}" + ) + + await ctx.send(embed=embed) + + @defcon_group.command(name='days') + @with_role(Roles.admins, Roles.owners) + async def days_command(self, ctx: Context, days: int) -> None: + """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" + self.days = timedelta(days=days) + self.enabled = True + await self._defcon_action(ctx, days=days, action=Action.UPDATED) + await self.update_channel_topic() + + async def update_channel_topic(self) -> None: + """Update the #defcon channel topic with the current DEFCON status.""" + if self.enabled: + day_str = "days" if self.days.days > 1 else "day" + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" + else: + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" + + self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) + await self.channel.edit(topic=new_topic) + + def build_defcon_msg(self, action: Action, e: Exception = None) -> str: + """Build in-channel response string for DEFCON action.""" + if action is Action.ENABLED: + msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" + elif action is Action.DISABLED: + msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" + elif action is Action.UPDATED: + msg = ( + f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " + f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" + ) + + if e: + msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + return msg + + async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: + """Send log message for DEFCON action.""" + info = action.value + log_msg: str = ( + f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" + f"{info.template.format(days=self.days.days)}" + ) + status_msg = f"DEFCON {action.name.lower()}" + + if e: + log_msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) + + +def setup(bot: Bot) -> None: + """Load the Defcon cog.""" + bot.add_cog(Defcon(bot)) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py new file mode 100644 index 000000000..e49913552 --- /dev/null +++ b/bot/exts/moderation/incidents.py @@ -0,0 +1,412 @@ +import asyncio +import logging +import typing as t +from datetime import datetime +from enum import Enum + +import discord +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, Colours, Emojis, Guild, Webhooks +from bot.utils.messages import sub_clyde + +log = logging.getLogger(__name__) + +# Amount of messages for `crawl_task` to process at most on start-up - limited to 50 +# as in practice, there should never be this many messages, and if there are, +# something has likely gone very wrong +CRAWL_LIMIT = 50 + +# Seconds for `crawl_task` to sleep after adding reactions to a message +CRAWL_SLEEP = 2 + + +class Signal(Enum): + """ + Recognized incident status signals. + + This binds emoji to actions. The bot will only react to emoji linked here. + All other signals are seen as invalid. + """ + + ACTIONED = Emojis.incident_actioned + NOT_ACTIONED = Emojis.incident_unactioned + INVESTIGATING = Emojis.incident_investigating + + +# Reactions from non-mod roles will be removed +ALLOWED_ROLES: t.Set[int] = set(Guild.moderation_roles) + +# Message must have all of these emoji to pass the `has_signals` check +ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} + +# An embed coupled with an optional file to be dispatched +# If the file is not None, the embed attempts to show it in its body +FileEmbed = t.Tuple[discord.Embed, t.Optional[discord.File]] + + +async def download_file(attachment: discord.Attachment) -> t.Optional[discord.File]: + """ + Download & return `attachment` file. + + If the download fails, the reason is logged and None will be returned. + 404 and 403 errors are only logged at debug level. + """ + log.debug(f"Attempting to download attachment: {attachment.filename}") + try: + return await attachment.to_file() + except (discord.NotFound, discord.Forbidden) as exc: + log.debug(f"Failed to download attachment: {exc}") + except Exception: + log.exception("Failed to download attachment") + + +async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: + """ + Create an embed representation of `incident` for the #incidents-archive channel. + + The name & discriminator of `actioned_by` and `outcome` will be presented in the + embed footer. Additionally, the embed is coloured based on `outcome`. + + The author of `incident` is not shown in the embed. It is assumed that this piece + of information will be relayed in other ways, e.g. webhook username. + + As mentions in embeds do not ping, we do not need to use `incident.clean_content`. + + If `incident` contains attachments, the first attachment will be downloaded and + returned alongside the embed. The embed attempts to display the attachment. + Should the download fail, we fallback on linking the `proxy_url`, which should + remain functional for some time after the original message is deleted. + """ + log.trace(f"Creating embed for {incident.id=}") + + if outcome is Signal.ACTIONED: + colour = Colours.soft_green + footer = f"Actioned by {actioned_by}" + else: + colour = Colours.soft_red + footer = f"Rejected by {actioned_by}" + + embed = discord.Embed( + description=incident.content, + timestamp=datetime.utcnow(), + colour=colour, + ) + embed.set_footer(text=footer, icon_url=actioned_by.avatar_url) + + if incident.attachments: + attachment = incident.attachments[0] # User-sent messages can only contain one attachment + file = await download_file(attachment) + + if file is not None: + embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file + else: + embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file + else: + file = None + + return embed, file + + +def is_incident(message: discord.Message) -> bool: + """True if `message` qualifies as an incident, False otherwise.""" + conditions = ( + message.channel.id == Channels.incidents, # Message sent in #incidents + not message.author.bot, # Not by a bot + not message.content.startswith("#"), # Doesn't start with a hash + not message.pinned, # And isn't header + ) + return all(conditions) + + +def own_reactions(message: discord.Message) -> t.Set[str]: + """Get the set of reactions placed on `message` by the bot itself.""" + return {str(reaction.emoji) for reaction in message.reactions if reaction.me} + + +def has_signals(message: discord.Message) -> bool: + """True if `message` already has all `Signal` reactions, False otherwise.""" + return ALL_SIGNALS.issubset(own_reactions(message)) + + +async def add_signals(incident: discord.Message) -> None: + """ + Add `Signal` member emoji to `incident` as reactions. + + If the emoji has already been placed on `incident` by the bot, it will be skipped. + """ + existing_reacts = own_reactions(incident) + + for signal_emoji in Signal: + if signal_emoji.value in existing_reacts: # This would not raise, but it is a superfluous API call + log.trace(f"Skipping emoji as it's already been placed: {signal_emoji}") + else: + log.trace(f"Adding reaction: {signal_emoji}") + await incident.add_reaction(signal_emoji.value) + + +class Incidents(Cog): + """ + Automation for the #incidents channel. + + This cog does not provide a command API, it only reacts to the following events. + + On start-up: + * Crawl #incidents and add missing `Signal` emoji where appropriate + * This is to retro-actively add the available options for messages which + were sent while the bot wasn't listening + * Pinned messages and message starting with # do not qualify as incidents + * See: `crawl_incidents` + + On message: + * Add `Signal` member emoji if message qualifies as an incident + * Ignore messages starting with # + * Use this if verbal communication is necessary + * Each such message must be deleted manually once appropriate + * See: `on_message` + + On reaction: + * Remove reaction if not permitted + * User does not have any of the roles in `ALLOWED_ROLES` + * Used emoji is not a `Signal` member + * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to + relay the incident message to #incidents-archive + * If relay successful, delete original message + * See: `on_raw_reaction_add` + + Please refer to function docstrings for implementation details. + """ + + def __init__(self, bot: Bot) -> None: + """Prepare `event_lock` and schedule `crawl_task` on start-up.""" + self.bot = bot + + self.event_lock = asyncio.Lock() + self.crawl_task = self.bot.loop.create_task(self.crawl_incidents()) + + async def crawl_incidents(self) -> None: + """ + Crawl #incidents and add missing emoji where necessary. + + This is to catch-up should an incident be reported while the bot wasn't listening. + After adding each reaction, we take a short break to avoid drowning in ratelimits. + + Once this task is scheduled, listeners that change messages should await it. + The crawl assumes that the channel history doesn't change as we go over it. + + Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. + """ + await self.bot.wait_until_guild_available() + incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) + + log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") + async for message in incidents.history(limit=CRAWL_LIMIT): + + if not is_incident(message): + log.trace(f"Skipping message {message.id}: not an incident") + continue + + if has_signals(message): + log.trace(f"Skipping message {message.id}: already has all signals") + continue + + await add_signals(message) + await asyncio.sleep(CRAWL_SLEEP) + + log.debug("Crawl task finished!") + + async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: + """ + Relay an embed representation of `incident` to the #incidents-archive channel. + + The following pieces of information are relayed: + * Incident message content (as embed description) + * Incident attachment (if image, shown in archive embed) + * Incident author name (as webhook author) + * Incident author avatar (as webhook avatar) + * Resolution signal `outcome` (as embed colour & footer) + * Moderator `actioned_by` (name & discriminator shown in footer) + + If `incident` contains an attachment, we try to add it to the archive embed. There is + no handing of extensions / file types - we simply dispatch the attachment file with the + webhook, and try to display it in the embed. Testing indicates that if the attachment + cannot be displayed (e.g. a text file), it's invisible in the embed, with no error. + + Return True if the relay finishes successfully. If anything goes wrong, meaning + not all information was relayed, return False. This signals that the original + message is not safe to be deleted, as we will lose some information. + """ + log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") + embed, attachment_file = await make_embed(incident, outcome, actioned_by) + + try: + webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) + await webhook.send( + embed=embed, + username=sub_clyde(incident.author.name), + avatar_url=incident.author.avatar_url, + file=attachment_file, + ) + except Exception: + log.exception(f"Failed to archive incident {incident.id} to #incidents-archive") + return False + else: + log.trace("Message archived successfully!") + return True + + def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: + """ + Create a task to wait `timeout` seconds for `incident` to be deleted. + + If `timeout` passes, this will raise `asyncio.TimeoutError`, signaling that we haven't + been able to confirm that the message was deleted. + """ + log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") + + def check(payload: discord.RawReactionActionEvent) -> bool: + return payload.message_id == incident.id + + coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) + return self.bot.loop.create_task(coroutine) + + async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: + """ + Process a `reaction_add` event in #incidents. + + First, we check that the reaction is a recognized `Signal` member, and that it was sent by + a permitted user (at least one role in `ALLOWED_ROLES`). If not, the reaction is removed. + + If the reaction was either `Signal.ACTIONED` or `Signal.NOT_ACTIONED`, we attempt to relay + the report to #incidents-archive. If successful, the original message is deleted. + + We do not release `event_lock` until we receive the corresponding `message_delete` event. + This ensures that if there is a racing event awaiting the lock, it will fail to find the + message, and will abort. There is a `timeout` to ensure that this doesn't hold the lock + forever should something go wrong. + """ + members_roles: t.Set[int] = {role.id for role in member.roles} + if not members_roles & ALLOWED_ROLES: # Intersection is truthy on at least 1 common element + log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") + await incident.remove_reaction(reaction, member) + return + + try: + signal = Signal(reaction) + except ValueError: + log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") + await incident.remove_reaction(reaction, member) + return + + log.trace(f"Received signal: {signal}") + + if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED): + log.debug("Reaction was valid, but no action is currently defined for it") + return + + relay_successful = await self.archive(incident, signal, actioned_by=member) + if not relay_successful: + log.trace("Original message will not be deleted as we failed to relay it to the archive") + return + + timeout = 5 # Seconds + confirmation_task = self.make_confirmation_task(incident, timeout) + + log.trace("Deleting original message") + await incident.delete() + + log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") + try: + await confirmation_task + except asyncio.TimeoutError: + log.warning(f"Did not receive incident deletion confirmation within {timeout} seconds!") + else: + log.trace("Deletion was confirmed") + + async def resolve_message(self, message_id: int) -> t.Optional[discord.Message]: + """ + Get `discord.Message` for `message_id` from cache, or API. + + We first look into the local cache to see if the message is present. + + If not, we try to fetch the message from the API. This is necessary for messages + which were sent before the bot's current session. + + In an edge-case, it is also possible that the message was already deleted, and + the API will respond with a 404. In such a case, None will be returned. + This signals that the event for `message_id` should be ignored. + """ + await self.bot.wait_until_guild_available() # First make sure that the cache is ready + log.trace(f"Resolving message for: {message_id=}") + message: t.Optional[discord.Message] = self.bot._connection._get_message(message_id) + + if message is not None: + log.trace("Message was found in cache") + return message + + log.trace("Message not found, attempting to fetch") + try: + message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) + except discord.NotFound: + log.trace("Message doesn't exist, it was likely already relayed") + except Exception: + log.exception(f"Failed to fetch message {message_id}!") + else: + log.trace("Message fetched successfully!") + return message + + @Cog.listener() + async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: + """ + Pre-process `payload` and pass it to `process_event` if appropriate. + + We abort instantly if `payload` doesn't relate to a message sent in #incidents, + or if it was sent by a bot. + + If `payload` relates to a message in #incidents, we first ensure that `crawl_task` has + finished, to make sure we don't mutate channel state as we're crawling it. + + Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. + + Once we have the lock, the `discord.Message` object for this event must be resolved. + If the lock was previously held by an event which successfully relayed the incident, + this will fail and we abort the current event. + + Finally, with both the lock and the `discord.Message` instance in our hands, we delegate + to `process_event` to handle the event. + + The justification for using a raw listener is the need to receive events for messages + which were not cached in the current session. As a result, a certain amount of + complexity is introduced, but at the moment this doesn't appear to be avoidable. + """ + if payload.channel_id != Channels.incidents or payload.member.bot: + return + + log.trace(f"Received reaction add event in #incidents, waiting for crawler: {self.crawl_task.done()=}") + await self.crawl_task + + log.trace(f"Acquiring event lock: {self.event_lock.locked()=}") + async with self.event_lock: + message = await self.resolve_message(payload.message_id) + + if message is None: + log.debug("Listener will abort as related message does not exist!") + return + + if not is_incident(message): + log.debug("Ignoring event for a non-incident message") + return + + await self.process_event(str(payload.emoji), message, payload.member) + log.trace("Releasing event lock") + + @Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Pass `message` to `add_signals` if and only if it satisfies `is_incident`.""" + if is_incident(message): + await add_signals(message) + + +def setup(bot: Bot) -> None: + """Load the Incidents cog.""" + bot.add_cog(Incidents(bot)) diff --git a/bot/exts/moderation/infraction/__init__.py b/bot/exts/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py new file mode 100644 index 000000000..1310fd3d9 --- /dev/null +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -0,0 +1,463 @@ +import logging +import textwrap +import typing as t +from abc import abstractmethod +from datetime import datetime +from gettext import ngettext + +import dateutil.parser +import discord +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Colours, STAFF_CHANNELS +from bot.exts.moderation.modlog import ModLog +from bot.utils import time +from bot.utils.scheduling import Scheduler +from . import _utils +from ._utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class InfractionScheduler: + """Handles the application, pardoning, and expiration of infractions.""" + + def __init__(self, bot: Bot, supported_infractions: t.Container[str]): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + @property + def mod_log(self) -> ModLog: + """Get the currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: + """Schedule expiration for previous infractions.""" + await self.bot.wait_until_guild_available() + + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={'active': 'true'} + ) + for infraction in infractions: + if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: + self.schedule_expiration(infraction) + + async def reapply_infraction( + self, + infraction: _utils.Infraction, + apply_coro: t.Optional[t.Awaitable] + ) -> None: + """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" + # Calculate the time remaining, in seconds, for the mute. + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + delta = (expiry - datetime.utcnow()).total_seconds() + + # Mark as inactive if less than a minute remains. + if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) + await self.deactivate_infraction(infraction) + return + + # Allowing mod log since this is a passive action that should be logged. + await apply_coro + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + + async def apply_infraction( + self, + ctx: Context, + infraction: _utils.Infraction, + user: UserSnowflake, + action_coro: t.Optional[t.Awaitable] = None + ) -> None: + """Apply an infraction to the user, log the infraction, and optionally notify the user.""" + infr_type = infraction["type"] + icon = _utils.INFRACTION_ICONS[infr_type][0] + reason = infraction["reason"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") + + # Default values for the confirmation message and mod log. + confirm_msg = ":ok_hand: applied" + + # Specifying an expiry for a note or warning makes no sense. + if infr_type in ("note", "warning"): + expiry_msg = "" + else: + expiry_msg = f" until {expiry}" if expiry else " permanently" + + dm_result = "" + dm_log_text = "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" + log_title = "applied" + log_content = None + failed = False + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # This needs to happen before we apply the infraction, as the bot cannot + # send DMs to user that it doesn't share a guild with. If we were to + # apply kick/ban infractions first, this would mean that we'd make it + # impossible for us to deliver a DM. See python-discord/bot#982. + if not infraction["hidden"]: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + end_msg = "" + if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + else: + log.trace(f"Fetching total infraction count for {user}.") + + infractions = await self.bot.api_client.get( + "bot/infractions", + params={"user__id": str(user.id)} + ) + total = len(infractions) + end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" + + # Execute the necessary actions to apply the infraction on Discord. + if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") + try: + await action_coro + if expiry: + # Schedule the expiration of the infraction. + self.schedule_expiration(infraction) + except discord.HTTPException as e: + # Accordingly display that applying the infraction failed. + confirm_msg = ":x: failed to apply" + expiry_msg = "" + log_content = ctx.author.mention + log_title = "failed to apply" + + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + failed = True + + if failed: + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" + + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") + + # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=icon, + colour=Colours.soft_red, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {user.mention} (`{user.id}`) + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} + Reason: {reason} + """), + content=log_content, + footer=f"ID {infraction['id']}" + ) + + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + + async def pardon_infraction( + self, + ctx: Context, + infr_type: str, + user: UserSnowflake, + send_msg: bool = True + ) -> None: + """ + Prematurely end an infraction for a user and log the action in the mod log. + + If `send_msg` is True, then a pardoning confirmation message will be sent to + the context channel. Otherwise, no such message will be sent. + """ + log.trace(f"Pardoning {infr_type} infraction for {user}.") + + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") + response = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': user.id + } + ) + + if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") + await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") + return + + # Deactivate the infraction and cancel its scheduled expiration task. + log_text = await self.deactivate_infraction(response[0], send_log=False) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["Actor"] = str(ctx.message.author) + log_content = None + id_ = response[0]['id'] + footer = f"ID: {id_}" + + # If multiple active infractions were found, mark them as inactive in the database + # and cancel their expiration tasks. + if len(response) > 1: + log.info( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) + + footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" + + log_note = f"Found multiple **active** {infr_type} infractions in the database." + if "Note" in log_text: + log_text["Note"] = f" {log_note}" + else: + log_text["Note"] = log_note + + # deactivate_infraction() is not called again because: + # 1. Discord cannot store multiple active bans or assign multiples of the same role + # 2. It would send a pardon DM for each active infraction, which is redundant + for infraction in response[1:]: + id_ = infraction['id'] + try: + # Mark infraction as inactive in the database. + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError: + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") + # This is simpler and cleaner than trying to concatenate all the errors. + log_text["Failure"] = "See bot's logs for details." + + # Cancel pending expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Accordingly display whether the user was successfully notified via DM. + dm_emoji = "" + if log_text.get("DM") == "Sent": + dm_emoji = ":incoming_envelope: " + elif "DM" in log_text: + dm_emoji = f"{constants.Emojis.failmail} " + + # Accordingly display whether the pardon failed. + if "Failure" in log_text: + confirm_msg = ":x: failed to pardon" + log_title = "pardon failed" + log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") + else: + confirm_msg = ":ok_hand: pardoned" + log_title = "pardoned" + + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + + # Send a confirmation message to the invoking context. + if send_msg: + log.trace(f"Sending infraction #{id_} pardon confirmation message.") + await ctx.send( + f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " + f"{log_text.get('Failure', '')}" + ) + + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + + # Send a log message to the mod log. + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[infr_type][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=footer, + content=log_content, + ) + + async def deactivate_infraction( + self, + infraction: _utils.Infraction, + send_log: bool = True + ) -> t.Dict[str, str]: + """ + Deactivate an active infraction and return a dictionary of lines to send in a mod log. + + The infraction is removed from Discord, marked as inactive in the database, and has its + expiration task cancelled. If `send_log` is True, a mod log is sent for the + deactivation of the infraction. + + Infractions of unsupported types will raise a ValueError. + """ + guild = self.bot.get_guild(constants.Guild.id) + mod_role = guild.get_role(constants.Roles.moderators) + user_id = infraction["user"] + actor = infraction["actor"] + type_ = infraction["type"] + id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] + + log.info(f"Marking infraction #{id_} as inactive (expired).") + + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + + log_content = None + log_text = { + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, + } + + try: + log.trace("Awaiting the pardon action coroutine.") + returned_log = await self._pardon_action(infraction) + + if returned_log is not None: + log_text = {**log_text, **returned_log} # Merge the logs together + else: + raise ValueError( + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" + ) + except discord.Forbidden: + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") + log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" + log_content = mod_role.mention + except discord.HTTPException as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." + log_content = mod_role.mention + + # Check if the user is currently being watched by Big Brother. + try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + + active_watch = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "watch", + "user__id": user_id + } + ) + + log_text["Watching"] = "Yes" if active_watch else "No" + except ResponseCodeError: + log.exception(f"Failed to fetch watch status for user {user_id}") + log_text["Watching"] = "Unknown - failed to fetch watch status." + + try: + # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_line = f"API request failed with code {e.status}." + log_content = mod_role.mention + + # Append to an existing failure message if possible + if "Failure" in log_text: + log_text["Failure"] += f" {log_line}" + else: + log_text["Failure"] = log_line + + # Cancel the expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Send a log message to the mod log. + if send_log: + log_title = "expiration failed" if "Failure" in log_text else "expired" + + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + + log.trace(f"Sending deactivation mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[type_][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=f"ID: {id_}", + content=log_content, + ) + + return log_text + + @abstractmethod + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + raise NotImplementedError + + def schedule_expiration(self, infraction: _utils.Infraction) -> None: + """ + Marks an infraction expired after the delay from time of scheduling to time of expiration. + + At the time of expiration, the infraction is marked as inactive on the website and the + expiration task is cancelled. + """ + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py new file mode 100644 index 000000000..fb55287b6 --- /dev/null +++ b/bot/exts/moderation/infraction/_utils.py @@ -0,0 +1,201 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext.commands import Context + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons + +log = logging.getLogger(__name__) + +# apply icon, pardon icon +INFRACTION_ICONS = { + "ban": (Icons.user_ban, Icons.user_unban), + "kick": (Icons.sign_out, None), + "mute": (Icons.user_mute, Icons.user_unmute), + "note": (Icons.user_warn, None), + "superstar": (Icons.superstarify, Icons.unsuperstarify), + "warning": (Icons.user_warn, None), +} +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEALABLE_INFRACTIONS = ("ban", "mute") + +# Type aliases +UserObject = t.Union[discord.Member, discord.User] +UserSnowflake = t.Union[UserObject, discord.Object] +Infraction = t.Dict[str, t.Union[str, int, bool]] + + +async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: + """ + Create a new user in the database. + + Used when an infraction needs to be applied on a user absent in the guild. + """ + log.trace(f"Attempting to add user {user.id} to the database.") + + if not isinstance(user, (discord.Member, discord.User)): + log.debug("The user being added to the DB is not a Member or User object.") + + payload = { + 'discriminator': int(getattr(user, 'discriminator', 0)), + 'id': user.id, + 'in_guild': False, + 'name': getattr(user, 'name', 'Name unknown'), + 'roles': [] + } + + try: + response = await ctx.bot.api_client.post('bot/users', json=payload) + log.info(f"User {user.id} added to the DB.") + return response + except ResponseCodeError as e: + log.error(f"Failed to add user {user.id} to the DB. {e}") + await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") + + +async def post_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + reason: str, + expires_at: datetime = None, + hidden: bool = False, + active: bool = True +) -> t.Optional[dict]: + """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + + payload = { + "actor": ctx.message.author.id, + "hidden": hidden, + "reason": reason, + "type": infr_type, + "user": user.id, + "active": active + } + if expires_at: + payload['expires_at'] = expires_at.isoformat() + + # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. + for should_post_user in (True, False): + try: + response = await ctx.bot.api_client.post('bot/infractions', json=payload) + return response + except ResponseCodeError as e: + if e.status == 400 and 'user' in e.response_json: + # Only one attempt to add the user to the database, not two: + if not should_post_user or await post_user(ctx, user) is None: + return + else: + log.exception(f"Unexpected error while adding an infraction for {user}:") + await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") + return + + +async def get_active_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + send_msg: bool = True +) -> t.Optional[dict]: + """ + Retrieves an active infraction of the given type for the user. + + If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, + then a message for the moderator will be sent to the context channel letting them know. + Otherwise, no message will be sent. + """ + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + + active_infractions = await ctx.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': str(user.id) + } + ) + if active_infractions: + # Checks to see if the moderator should be told there is an active infraction + if send_msg: + log.trace(f"{user} has active infractions of type {infr_type}.") + await ctx.send( + f":x: According to my records, this user already has a {infr_type} infraction. " + f"See infraction **#{active_infractions[0]['id']}**." + ) + return active_infractions[0] + else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") + + +async def notify_infraction( + user: UserObject, + infr_type: str, + expires_at: t.Optional[str] = None, + reason: t.Optional[str] = None, + icon_url: str = Icons.token_removed +) -> bool: + """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + + embed = discord.Embed( + description=textwrap.shorten(text, width=2048, placeholder="..."), + colour=Colours.soft_red + ) + + embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) + embed.title = f"Please review our rules over at {RULES_URL}" + embed.url = RULES_URL + + if infr_type in APPEALABLE_INFRACTIONS: + embed.set_footer( + text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + ) + + return await send_private_embed(user, embed) + + +async def notify_pardon( + user: UserObject, + title: str, + content: str, + icon_url: str = Icons.user_verified +) -> bool: + """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + + embed = discord.Embed( + description=content, + colour=Colours.soft_green + ) + + embed.set_author(name=title, icon_url=icon_url) + + return await send_private_embed(user, embed) + + +async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: + """ + A helper method for sending an embed to a user's DMs. + + Returns a boolean indicator of DM success. + """ + try: + await user.send(embed=embed) + return True + except (discord.HTTPException, discord.Forbidden, discord.NotFound): + log.debug( + f"Infraction-related information could not be sent to user {user} ({user.id}). " + "The user either could not be retrieved or probably disabled their DMs." + ) + return False diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py new file mode 100644 index 000000000..cb459b447 --- /dev/null +++ b/bot/exts/moderation/infraction/infractions.py @@ -0,0 +1,375 @@ +import logging +import textwrap +import typing as t + +import discord +from discord import Member +from discord.ext import commands +from discord.ext.commands import Context, command + +from bot import constants +from bot.bot import Bot +from bot.constants import Event +from bot.converters import Expiry, FetchedMember +from bot.decorators import respect_role_hierarchy +from bot.utils.checks import with_role_check +from . import _utils +from ._scheduler import InfractionScheduler +from ._utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class Infractions(InfractionScheduler, commands.Cog): + """Apply and pardon infractions on users for moderation purposes.""" + + category = "Moderation" + category_description = "Server moderation tools." + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) + + self.category = "Moderation" + self._muted_role = discord.Object(constants.Roles.muted) + + @commands.Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active mute infractions for returning members.""" + active_mutes = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "mute", + "user__id": member.id + } + ) + + if active_mutes: + reason = f"Re-applying active mute: {active_mutes[0]['id']}" + action = member.add_roles(self._muted_role, reason=reason) + + await self.reapply_infraction(active_mutes[0], action) + + # region: Permanent infractions + + @command() + async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Warn a user for the given reason.""" + infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command() + async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason.""" + await self.apply_kick(ctx, user, reason) + + @command() + async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason and stop watching them with Big Brother.""" + await self.apply_ban(ctx, user, reason) + + # endregion + # region: Temporary infractions + + @command(aliases=["mute"]) + async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: + """ + Temporarily mute a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration) + + @command() + async def tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration) + + # endregion + # region: Permanent shadow infractions + + @command(hidden=True) + async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Create a private note for a user with the given reason without notifying the user.""" + infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command(hidden=True, aliases=['shadowkick', 'skick']) + async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason without notifying the user.""" + await self.apply_kick(ctx, user, reason, hidden=True) + + @command(hidden=True, aliases=['shadowban', 'sban']) + async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason without notifying the user.""" + await self.apply_ban(ctx, user, reason, hidden=True) + + # endregion + # region: Temporary shadow infractions + + @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) + async def shadow_tempmute( + self, ctx: Context, + user: Member, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily mute a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) + + @command(hidden=True, aliases=["shadowtempban, stempban"]) + async def shadow_tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) + + # endregion + # region: Remove infractions (un- commands) + + @command() + async def unmute(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active mute infraction for the user.""" + await self.pardon_infraction(ctx, "mute", user) + + @command() + async def unban(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active ban infraction for the user.""" + await self.pardon_infraction(ctx, "ban", user) + + # endregion + # region: Base apply functions + + async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a mute infraction with kwargs passed to `post_infraction`.""" + if await _utils.get_active_infraction(ctx, user, "mute"): + return + + infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_update, user.id) + + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) + + @respect_role_hierarchy() + async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a kick infraction with kwargs passed to `post_infraction`.""" + infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = user.kick(reason=reason) + await self.apply_infraction(ctx, infraction, user, action) + + @respect_role_hierarchy() + async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: + """ + Apply a ban infraction with kwargs passed to `post_infraction`. + + Will also remove the banned user from the Big Brother watch list if applicable. + """ + # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active + is_temporary = kwargs.get("expires_at") is not None + active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary) + + if active_infraction: + if is_temporary: + log.trace("Tempban ignored as it cannot overwrite an active ban.") + return + + if active_infraction.get('expires_at') is None: + log.trace("Permaban already exists, notify.") + await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") + return + + log.trace("Old tempban is being replaced by new permaban.") + await self.pardon_infraction(ctx, "ban", user, is_temporary) + + infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = ctx.guild.ban(user, reason=reason, delete_message_days=0) + await self.apply_infraction(ctx, infraction, user, action) + + if infraction.get('expires_at') is not None: + log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") + return + + bb_cog = self.bot.get_cog("Big Brother") + if not bb_cog: + log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") + return + + log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") + + bb_reason = "User has been permanently banned from the server. Automatically removed." + await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) + + # endregion + # region: Base pardon functions + + async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's muted role, DM them a notification, and return a log dict.""" + user = guild.get_member(user_id) + log_text = {} + + if user: + # Remove the muted role. + self.mod_log.ignore(Event.member_update, user.id) + await user.remove_roles(self._muted_role, reason=reason) + + # DM the user about the expiration. + notified = await _utils.notify_pardon( + user=user, + title="You have been unmuted", + content="You may now send messages in the server.", + icon_url=_utils.INFRACTION_ICONS["mute"][1] + ) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["DM"] = "Sent" if notified else "**Failed**" + else: + log.info(f"Failed to unmute user {user_id}: user not found") + log_text["Failure"] = "User was not found in the guild." + + return log_text + + async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's ban on the Discord guild and return a log dict.""" + user = discord.Object(user_id) + log_text = {} + + self.mod_log.ignore(Event.member_unban, user_id) + + try: + await guild.unban(user, reason=reason) + except discord.NotFound: + log.info(f"Failed to unban user {user_id}: no active ban found on Discord") + log_text["Note"] = "No active ban found on Discord." + + return log_text + + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + guild = self.bot.get_guild(constants.Guild.id) + user_id = infraction["user"] + reason = f"Infraction #{infraction['id']} expired or was pardoned." + + if infraction["type"] == "mute": + return await self.pardon_mute(user_id, guild, reason) + elif infraction["type"] == "ban": + return await self.pardon_ban(user_id, guild, reason) + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters or discord.Member in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Infractions cog.""" + bot.add_cog(Infractions(bot)) diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py new file mode 100644 index 000000000..eea6ac9ea --- /dev/null +++ b/bot/exts/moderation/infraction/management.py @@ -0,0 +1,310 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext import commands +from discord.ext.commands import Context + +from bot import constants +from bot.bot import Bot +from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user +from bot.exts.moderation.modlog import ModLog +from bot.pagination import LinePaginator +from bot.utils import time +from bot.utils.checks import in_whitelist_check, with_role_check +from . import _utils +from .infractions import Infractions + +log = logging.getLogger(__name__) + + +class ModManagement(commands.Cog): + """Management of infractions.""" + + category = "Moderation" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @property + def infractions_cog(self) -> Infractions: + """Get currently loaded Infractions cog instance.""" + return self.bot.get_cog("Infractions") + + # region: Edit infraction commands + + @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) + async def infraction_group(self, ctx: Context) -> None: + """Infraction manipulation commands.""" + await ctx.send_help(ctx.command) + + @infraction_group.command(name='edit') + async def infraction_edit( + self, + ctx: Context, + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 + duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 + *, + reason: str = None + ) -> None: + """ + Edit the duration and/or the reason of an infraction. + + Durations are relative to the time of updating and should be appended with a unit of time. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. + """ + if duration is None and reason is None: + # Unlike UserInputError, the error handler will show a specified message for BadArgument + raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") + + # Retrieve the previous infraction for its information. + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get("bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + ":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") + + request_data = {} + confirm_messages = [] + log_text = "" + + if duration is not None and not old_infraction['active']: + if reason is None: + await ctx.send(":x: Cannot edit the expiration of an expired infraction.") + return + confirm_messages.append("expiry unchanged (infraction already expired)") + elif isinstance(duration, str): + request_data['expires_at'] = None + confirm_messages.append("marked as permanent") + elif duration is not None: + request_data['expires_at'] = duration.isoformat() + expiry = time.format_infraction_with_duration(request_data['expires_at']) + confirm_messages.append(f"set to expire on {expiry}") + else: + confirm_messages.append("expiry unchanged") + + if reason: + request_data['reason'] = reason + confirm_messages.append("set a new reason") + log_text += f""" + Previous reason: {old_infraction['reason']} + New reason: {reason} + """.rstrip() + else: + confirm_messages.append("reason unchanged") + + # Update the infraction + new_infraction = await self.bot.api_client.patch( + f'bot/infractions/{infraction_id}', + json=request_data, + ) + + # Re-schedule infraction if the expiration has been updated + if 'expires_at' in request_data: + # A scheduled task should only exist if the old infraction wasn't permanent + if old_infraction['expires_at']: + self.infractions_cog.scheduler.cancel(new_infraction['id']) + + # If the infraction was not marked as permanent, schedule a new expiration task + if request_data['expires_at']: + self.infractions_cog.schedule_expiration(new_infraction) + + log_text += f""" + Previous expiry: {old_infraction['expires_at'] or "Permanent"} + New expiry: {new_infraction['expires_at'] or "Permanent"} + """.rstrip() + + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") + + # Get information about the infraction's user + user_id = new_infraction['user'] + user = ctx.guild.get_member(user_id) + + if user: + user_text = f"{user.mention} (`{user.id}`)" + thumbnail = user.avatar_url_as(static_format="png") + else: + user_text = f"`{user_id}`" + thumbnail = None + + # The infraction's actor + actor_id = new_infraction['actor'] + actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" + + await self.mod_log.send_log_message( + icon_url=constants.Icons.pencil, + colour=discord.Colour.blurple(), + title="Infraction edited", + thumbnail=thumbnail, + text=textwrap.dedent(f""" + Member: {user_text} + Actor: {actor} + Edited by: {ctx.message.author}{log_text} + """) + ) + + # endregion + # region: Search infractions + + @infraction_group.group(name="search", invoke_without_command=True) + async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: + """Searches for infractions in the database.""" + if isinstance(query, discord.User): + await ctx.invoke(self.search_user, query) + else: + await ctx.invoke(self.search_reason, query) + + @infraction_search_group.command(name="user", aliases=("member", "id")) + async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: + """Search for infractions by member.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'user__id': str(user.id)} + ) + embed = discord.Embed( + title=f"Infractions for {user} ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) + async def search_reason(self, ctx: Context, reason: str) -> None: + """Search for infractions by their reason. Use Re2 for matching.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'search': reason} + ) + embed = discord.Embed( + title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + # endregion + # region: Utility functions + + async def send_infraction_list( + self, + ctx: Context, + embed: discord.Embed, + infractions: t.Iterable[_utils.Infraction] + ) -> None: + """Send a paginated embed of infractions for the specified user.""" + if not infractions: + await ctx.send(":warning: No infractions could be found for that query.") + return + + lines = tuple( + self.infraction_to_string(infraction) + for infraction in infractions + ) + + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + def infraction_to_string(self, infraction: _utils.Infraction) -> str: + """Convert the infraction object to a string representation.""" + actor_id = infraction["actor"] + guild = self.bot.get_guild(constants.Guild.id) + actor = guild.get_member(actor_id) + active = infraction["active"] + user_id = infraction["user"] + hidden = infraction["hidden"] + created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + + if infraction["expires_at"] is None: + expires = "*Permanent*" + else: + date_from = datetime.strptime(created, time.INFRACTION_FORMAT) + expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) + + lines = textwrap.dedent(f""" + {"**===============**" if active else "==============="} + Status: {"__**Active**__" if active else "Inactive"} + User: {self.bot.get_user(user_id)} (`{user_id}`) + Type: **{infraction["type"]}** + Shadow: {hidden} + Created: {created} + Expires: {expires} + Remaining: {remaining} + Actor: {actor.mention if actor else actor_id} + ID: `{infraction["id"]}` + Reason: {infraction["reason"] or "*None*"} + {"**===============**" if active else "==============="} + """) + + return lines.strip() + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators inside moderator channels to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=constants.MODERATION_CHANNELS, + categories=[constants.Categories.modmail], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the ModManagement cog.""" + bot.add_cog(ModManagement(bot)) diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py new file mode 100644 index 000000000..7dc5b4691 --- /dev/null +++ b/bot/exts/moderation/infraction/superstarify.py @@ -0,0 +1,244 @@ +import json +import logging +import random +import textwrap +import typing as t +from pathlib import Path + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.converters import Expiry +from bot.utils.checks import with_role_check +from bot.utils.time import format_infraction +from . import _utils +from ._scheduler import InfractionScheduler + +log = logging.getLogger(__name__) +NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" + +with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: + STAR_NAMES = json.load(stars_file) + + +class Superstarify(InfractionScheduler, Cog): + """A set of commands to moderate terrible nicknames.""" + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"superstar"}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Revert nickname edits if the user has an active superstarify infraction.""" + if before.display_name == after.display_name: + return # User didn't change their nickname. Abort! + + log.trace( + f"{before} ({before.display_name}) is trying to change their nickname to " + f"{after.display_name}. Checking if the user is in superstar-prison..." + ) + + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": str(before.id) + } + ) + + if not active_superstarifies: + log.trace(f"{before} has no active superstar infractions.") + return + + infraction = active_superstarifies[0] + forced_nick = self.get_nick(infraction["id"], before.id) + if after.display_name == forced_nick: + return # Nick change was triggered by this event. Ignore. + + log.info( + f"{after.display_name} ({after.id}) tried to escape superstar prison. " + f"Changing the nick back to {before.display_name}." + ) + await after.edit( + nick=forced_nick, + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + notified = await _utils.notify_infraction( + user=after, + infr_type="Superstarify", + expires_at=format_infraction(infraction["expires_at"]), + reason=( + "You have tried to change your nickname on the **Python Discord** server " + f"from **{before.display_name}** to **{after.display_name}**, but as you " + "are currently in superstar-prison, you do not have permission to do so." + ), + icon_url=_utils.INFRACTION_ICONS["superstar"][0] + ) + + if not notified: + log.info("Failed to DM user about why they cannot change their nickname.") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active superstar infractions for returning members.""" + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": member.id + } + ) + + if active_superstarifies: + infraction = active_superstarifies[0] + action = member.edit( + nick=self.get_nick(infraction["id"], member.id), + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + await self.reapply_infraction(infraction, action) + + @command(name="superstarify", aliases=("force_nick", "star")) + async def superstarify( + self, + ctx: Context, + member: Member, + duration: Expiry, + *, + reason: str = None, + ) -> None: + """ + Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + + An optional reason can be provided. If no reason is given, the original name will be shown + in a generated reason. + """ + if await _utils.get_active_infraction(ctx, member, "superstar"): + return + + # Post the infraction to the API + reason = reason or f"old nick: {member.display_name}" + infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) + id_ = infraction["id"] + + old_nick = member.display_name + forced_nick = self.get_nick(id_, member.id) + expiry_str = format_infraction(infraction["expires_at"]) + + # Apply the infraction and schedule the expiration task. + log.debug(f"Changing nickname of {member} to {forced_nick}.") + self.mod_log.ignore(constants.Event.member_update, member.id) + await member.edit(nick=forced_nick, reason=reason) + self.schedule_expiration(infraction) + + # Send a DM to the user to notify them of their new infraction. + await _utils.notify_infraction( + user=member, + infr_type="Superstarify", + expires_at=expiry_str, + icon_url=_utils.INFRACTION_ICONS["superstar"][0], + reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." + ) + + # Send an embed with the infraction information to the invoking context. + log.trace(f"Sending superstar #{id_} embed.") + embed = Embed( + title="Congratulations!", + colour=constants.Colours.soft_orange, + description=( + f"Your previous nickname, **{old_nick}**, " + f"was so bad that we have decided to change it. " + f"Your new nickname will be **{forced_nick}**.\n\n" + f"You will be unable to change your nickname until **{expiry_str}**.\n\n" + "If you're confused by this, please read our " + f"[official nickname policy]({NICKNAME_POLICY_URL})." + ) + ) + await ctx.send(embed=embed) + + # Log to the mod log channel. + log.trace(f"Sending apply mod log for superstar #{id_}.") + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS["superstar"][0], + colour=Colour.gold(), + title="Member achieved superstardom", + thumbnail=member.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {member.mention} (`{member.id}`) + Actor: {ctx.message.author} + Expires: {expiry_str} + Old nickname: `{old_nick}` + New nickname: `{forced_nick}` + Reason: {reason} + """), + footer=f"ID {id_}" + ) + + @command(name="unsuperstarify", aliases=("release_nick", "unstar")) + async def unsuperstarify(self, ctx: Context, member: Member) -> None: + """Remove the superstarify infraction and allow the user to change their nickname.""" + await self.pardon_infraction(ctx, "superstar", member) + + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """Pardon a superstar infraction and return a log dict.""" + if infraction["type"] != "superstar": + return + + guild = self.bot.get_guild(constants.Guild.id) + user = guild.get_member(infraction["user"]) + + # Don't bother sending a notification if the user left the guild. + if not user: + log.debug( + "User left the guild and therefore won't be notified about superstar " + f"{infraction['id']} pardon." + ) + return {} + + # DM the user about the expiration. + notified = await _utils.notify_pardon( + user=user, + title="You are no longer superstarified", + content="You may now change your nickname on the server.", + icon_url=_utils.INFRACTION_ICONS["superstar"][1] + ) + + return { + "Member": f"{user.mention}(`{user.id}`)", + "DM": "Sent" if notified else "**Failed**" + } + + @staticmethod + def get_nick(infraction_id: int, member_id: int) -> str: + """Randomly select a nickname from the Superstarify nickname list.""" + log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") + + rng = random.Random(str(infraction_id) + str(member_id)) + return rng.choice(STAR_NAMES) + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Superstarify cog.""" + bot.add_cog(Superstarify(bot)) diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py new file mode 100644 index 000000000..c86f04b9d --- /dev/null +++ b/bot/exts/moderation/modlog.py @@ -0,0 +1,837 @@ +import asyncio +import difflib +import itertools +import logging +import typing as t +from datetime import datetime +from itertools import zip_longest + +import discord +from dateutil.relativedelta import relativedelta +from deepdiff import DeepDiff +from discord import Colour +from discord.abc import GuildChannel +from discord.ext.commands import Cog, Context +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] + +CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) +CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") +ROLE_CHANGES_UNSUPPORTED = ("colour", "permissions") + +VOICE_STATE_ATTRIBUTES = { + "channel.name": "Channel", + "self_stream": "Streaming", + "self_video": "Broadcasting", +} + + +class ModLog(Cog, name="ModLog"): + """Logging for server events and staff actions.""" + + def __init__(self, bot: Bot): + self.bot = bot + self._ignored = {event: [] for event in Event} + + self._cached_deletes = [] + self._cached_edits = [] + + async def upload_log( + self, + messages: t.Iterable[discord.Message], + actor_id: int, + attachments: t.Iterable[t.List[str]] = None + ) -> str: + """Upload message logs to the database and return a URL to a page for viewing the logs.""" + if attachments is None: + attachments = [] + + response = await self.bot.api_client.post( + 'bot/deleted-messages', + json={ + 'actor': actor_id, + 'creation': datetime.utcnow().isoformat(), + 'deletedmessage_set': [ + { + 'id': message.id, + 'author': message.author.id, + 'channel_id': message.channel.id, + 'content': message.content, + 'embeds': [embed.to_dict() for embed in message.embeds], + 'attachments': attachment, + } + for message, attachment in zip_longest(messages, attachments, fillvalue=[]) + ] + } + ) + + return f"{URLs.site_logs_view}/{response['id']}" + + def ignore(self, event: Event, *items: int) -> None: + """Add event to ignored events to suppress log emission.""" + for item in items: + if item not in self._ignored[event]: + self._ignored[event].append(item) + + async def send_log_message( + self, + icon_url: t.Optional[str], + colour: t.Union[discord.Colour, int], + title: t.Optional[str], + text: str, + thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, + channel_id: int = Channels.mod_log, + ping_everyone: bool = False, + files: t.Optional[t.List[discord.File]] = None, + content: t.Optional[str] = None, + additional_embeds: t.Optional[t.List[discord.Embed]] = None, + additional_embeds_msg: t.Optional[str] = None, + timestamp_override: t.Optional[datetime] = None, + footer: t.Optional[str] = None, + ) -> Context: + """Generate log embed and send to logging channel.""" + # Truncate string directly here to avoid removing newlines + embed = discord.Embed( + description=text[:2045] + "..." if len(text) > 2048 else text + ) + + if title and icon_url: + embed.set_author(name=title, icon_url=icon_url) + + embed.colour = colour + embed.timestamp = timestamp_override or datetime.utcnow() + + if footer: + embed.set_footer(text=footer) + + if thumbnail: + embed.set_thumbnail(url=thumbnail) + + if ping_everyone: + if content: + content = f"@everyone\n{content}" + else: + content = "@everyone" + + channel = self.bot.get_channel(channel_id) + log_message = await channel.send( + content=content, + embed=embed, + files=files, + allowed_mentions=discord.AllowedMentions(everyone=True) + ) + + if additional_embeds: + if additional_embeds_msg: + await channel.send(additional_embeds_msg) + for additional_embed in additional_embeds: + await channel.send(embed=additional_embed) + + return await self.bot.get_context(log_message) # Optionally return for use with antispam + + @Cog.listener() + async def on_guild_channel_create(self, channel: GUILD_CHANNEL) -> None: + """Log channel create event to mod log.""" + if channel.guild.id != GuildConstant.id: + return + + if isinstance(channel, discord.CategoryChannel): + title = "Category created" + message = f"{channel.name} (`{channel.id}`)" + elif isinstance(channel, discord.VoiceChannel): + title = "Voice channel created" + + if channel.category: + message = f"{channel.category}/{channel.name} (`{channel.id}`)" + else: + message = f"{channel.name} (`{channel.id}`)" + else: + title = "Text channel created" + + if channel.category: + message = f"{channel.category}/{channel.name} (`{channel.id}`)" + else: + message = f"{channel.name} (`{channel.id}`)" + + await self.send_log_message(Icons.hash_green, Colours.soft_green, title, message) + + @Cog.listener() + async def on_guild_channel_delete(self, channel: GUILD_CHANNEL) -> None: + """Log channel delete event to mod log.""" + if channel.guild.id != GuildConstant.id: + return + + if isinstance(channel, discord.CategoryChannel): + title = "Category deleted" + elif isinstance(channel, discord.VoiceChannel): + title = "Voice channel deleted" + else: + title = "Text channel deleted" + + if channel.category and not isinstance(channel, discord.CategoryChannel): + message = f"{channel.category}/{channel.name} (`{channel.id}`)" + else: + message = f"{channel.name} (`{channel.id}`)" + + await self.send_log_message( + Icons.hash_red, Colours.soft_red, + title, message + ) + + @Cog.listener() + async def on_guild_channel_update(self, before: GUILD_CHANNEL, after: GuildChannel) -> None: + """Log channel update event to mod log.""" + if before.guild.id != GuildConstant.id: + return + + if before.id in self._ignored[Event.guild_channel_update]: + self._ignored[Event.guild_channel_update].remove(before.id) + return + + # Two channel updates are sent for a single edit: 1 for topic and 1 for category change. + # TODO: remove once support is added for ignoring multiple occurrences for the same channel. + help_categories = (Categories.help_available, Categories.help_dormant, Categories.help_in_use) + if after.category and after.category.id in help_categories: + return + + diff = DeepDiff(before, after) + changes = [] + done = [] + + diff_values = diff.get("values_changed", {}) + diff_values.update(diff.get("type_changes", {})) + + for key, value in diff_values.items(): + if not key: # Not sure why, but it happens + continue + + key = key[5:] # Remove "root." prefix + + if "[" in key: + key = key.split("[", 1)[0] + + if "." in key: + key = key.split(".", 1)[0] + + if key in done or key in CHANNEL_CHANGES_SUPPRESSED: + continue + + if key in CHANNEL_CHANGES_UNSUPPORTED: + changes.append(f"**{key.title()}** updated") + else: + new = value["new_value"] + old = value["old_value"] + + # Discord does not treat consecutive backticks ("``") as an empty inline code block, so the markdown + # formatting is broken when `new` and/or `old` are empty values. "None" is used for these cases so + # formatting is preserved. + changes.append(f"**{key.title()}:** `{old or 'None'}` **→** `{new or 'None'}`") + + done.append(key) + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + if after.category: + message = f"**{after.category}/#{after.name} (`{after.id}`)**\n{message}" + else: + message = f"**#{after.name}** (`{after.id}`)\n{message}" + + await self.send_log_message( + Icons.hash_blurple, Colour.blurple(), + "Channel updated", message + ) + + @Cog.listener() + async def on_guild_role_create(self, role: discord.Role) -> None: + """Log role create event to mod log.""" + if role.guild.id != GuildConstant.id: + return + + await self.send_log_message( + Icons.crown_green, Colours.soft_green, + "Role created", f"`{role.id}`" + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: discord.Role) -> None: + """Log role delete event to mod log.""" + if role.guild.id != GuildConstant.id: + return + + await self.send_log_message( + Icons.crown_red, Colours.soft_red, + "Role removed", f"{role.name} (`{role.id}`)" + ) + + @Cog.listener() + async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: + """Log role update event to mod log.""" + if before.guild.id != GuildConstant.id: + return + + diff = DeepDiff(before, after) + changes = [] + done = [] + + diff_values = diff.get("values_changed", {}) + diff_values.update(diff.get("type_changes", {})) + + for key, value in diff_values.items(): + if not key: # Not sure why, but it happens + continue + + key = key[5:] # Remove "root." prefix + + if "[" in key: + key = key.split("[", 1)[0] + + if "." in key: + key = key.split(".", 1)[0] + + if key in done or key == "color": + continue + + if key in ROLE_CHANGES_UNSUPPORTED: + changes.append(f"**{key.title()}** updated") + else: + new = value["new_value"] + old = value["old_value"] + + changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") + + done.append(key) + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + message = f"**{after.name}** (`{after.id}`)\n{message}" + + await self.send_log_message( + Icons.crown_blurple, Colour.blurple(), + "Role updated", message + ) + + @Cog.listener() + async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: + """Log guild update event to mod log.""" + if before.id != GuildConstant.id: + return + + diff = DeepDiff(before, after) + changes = [] + done = [] + + diff_values = diff.get("values_changed", {}) + diff_values.update(diff.get("type_changes", {})) + + for key, value in diff_values.items(): + if not key: # Not sure why, but it happens + continue + + key = key[5:] # Remove "root." prefix + + if "[" in key: + key = key.split("[", 1)[0] + + if "." in key: + key = key.split(".", 1)[0] + + if key in done: + continue + + new = value["new_value"] + old = value["old_value"] + + changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") + + done.append(key) + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + message = f"**{after.name}** (`{after.id}`)\n{message}" + + await self.send_log_message( + Icons.guild_update, Colour.blurple(), + "Guild updated", message, + thumbnail=after.icon_url_as(format="png") + ) + + @Cog.listener() + async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: + """Log ban event to user log.""" + if guild.id != GuildConstant.id: + return + + if member.id in self._ignored[Event.member_ban]: + self._ignored[Event.member_ban].remove(member.id) + return + + await self.send_log_message( + Icons.user_ban, Colours.soft_red, + "User banned", f"{member} (`{member.id}`)", + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_member_join(self, member: discord.Member) -> None: + """Log member join event to user log.""" + if member.guild.id != GuildConstant.id: + return + + member_str = escape_markdown(str(member)) + message = f"{member_str} (`{member.id}`)" + now = datetime.utcnow() + difference = abs(relativedelta(now, member.created_at)) + + message += "\n\n**Account age:** " + humanize_delta(difference) + + if difference.days < 1 and difference.months < 1 and difference.years < 1: # New user account! + message = f"{Emojis.new} {message}" + + await self.send_log_message( + Icons.sign_in, Colours.soft_green, + "User joined", message, + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_member_remove(self, member: discord.Member) -> None: + """Log member leave event to user log.""" + if member.guild.id != GuildConstant.id: + return + + if member.id in self._ignored[Event.member_remove]: + self._ignored[Event.member_remove].remove(member.id) + return + + member_str = escape_markdown(str(member)) + await self.send_log_message( + Icons.sign_out, Colours.soft_red, + "User left", f"{member_str} (`{member.id}`)", + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: + """Log member unban event to mod log.""" + if guild.id != GuildConstant.id: + return + + if member.id in self._ignored[Event.member_unban]: + self._ignored[Event.member_unban].remove(member.id) + return + + member_str = escape_markdown(str(member)) + await self.send_log_message( + Icons.user_unban, Colour.blurple(), + "User unbanned", f"{member_str} (`{member.id}`)", + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.mod_log + ) + + @staticmethod + def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: + """Return a list of strings describing the roles added and removed.""" + changes = [] + before_roles = set(before) + after_roles = set(after) + + for role in (before_roles - after_roles): + changes.append(f"**Role removed:** {role.name} (`{role.id}`)") + + for role in (after_roles - before_roles): + changes.append(f"**Role added:** {role.name} (`{role.id}`)") + + return changes + + @Cog.listener() + async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: + """Log member update event to user log.""" + if before.guild.id != GuildConstant.id: + return + + if before.id in self._ignored[Event.member_update]: + self._ignored[Event.member_update].remove(before.id) + return + + changes = self.get_role_diff(before.roles, after.roles) + + # The regex is a simple way to exclude all sequence and mapping types. + diff = DeepDiff(before, after, exclude_regex_paths=r".*\[.*") + + # A type change seems to always take precedent over a value change. Furthermore, it will + # include the value change along with the type change anyway. Therefore, it's OK to + # "overwrite" values_changed; in practice there will never even be anything to overwrite. + diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} + + for attr, value in diff_values.items(): + if not attr: # Not sure why, but it happens. + continue + + attr = attr[5:] # Remove "root." prefix. + attr = attr.replace("_", " ").replace(".", " ").capitalize() + + new = value.get("new_value") + old = value.get("old_value") + + changes.append(f"**{attr}:** `{old}` **→** `{new}`") + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + member_str = escape_markdown(str(after)) + message = f"**{member_str}** (`{after.id}`)\n{message}" + + await self.send_log_message( + icon_url=Icons.user_update, + colour=Colour.blurple(), + title="Member updated", + text=message, + thumbnail=after.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_message_delete(self, message: discord.Message) -> None: + """Log message delete event to message change log.""" + channel = message.channel + author = message.author + + # Ignore DMs. + if not message.guild: + return + + if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist: + return + + self._cached_deletes.append(message.id) + + if message.id in self._ignored[Event.message_delete]: + self._ignored[Event.message_delete].remove(message.id) + return + + if author.bot: + return + + author_str = escape_markdown(str(author)) + if channel.category: + response = ( + f"**Author:** {author_str} (`{author.id}`)\n" + f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + ) + else: + response = ( + f"**Author:** {author_str} (`{author.id}`)\n" + f"**Channel:** #{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + ) + + if message.attachments: + # Prepend the message metadata with the number of attachments + response = f"**Attachments:** {len(message.attachments)}\n" + response + + # Shorten the message content if necessary + content = message.clean_content + remaining_chars = 2040 - len(response) + + if len(content) > remaining_chars: + botlog_url = await self.upload_log(messages=[message], actor_id=message.author.id) + ending = f"\n\nMessage truncated, [full message here]({botlog_url})." + truncation_point = remaining_chars - len(ending) + content = f"{content[:truncation_point]}...{ending}" + + response += f"{content}" + + await self.send_log_message( + Icons.message_delete, Colours.soft_red, + "Message deleted", + response, + channel_id=Channels.message_log + ) + + @Cog.listener() + async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: + """Log raw message delete event to message change log.""" + if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist: + return + + await asyncio.sleep(1) # Wait here in case the normal event was fired + + if event.message_id in self._cached_deletes: + # It was in the cache and the normal event was fired, so we can just ignore it + self._cached_deletes.remove(event.message_id) + return + + if event.message_id in self._ignored[Event.message_delete]: + self._ignored[Event.message_delete].remove(event.message_id) + return + + channel = self.bot.get_channel(event.channel_id) + + if channel.category: + response = ( + f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{event.message_id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) + else: + response = ( + f"**Channel:** #{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{event.message_id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) + + await self.send_log_message( + Icons.message_delete, Colours.soft_red, + "Message deleted", + response, + channel_id=Channels.message_log + ) + + @Cog.listener() + async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: + """Log message edit event to message change log.""" + if ( + not msg_before.guild + or msg_before.guild.id != GuildConstant.id + or msg_before.channel.id in GuildConstant.modlog_blacklist + or msg_before.author.bot + ): + return + + self._cached_edits.append(msg_before.id) + + if msg_before.content == msg_after.content: + return + + author = msg_before.author + author_str = escape_markdown(str(author)) + + channel = msg_before.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + # Getting the difference per words and group them by type - add, remove, same + # Note that this is intended grouping without sorting + diff = difflib.ndiff(msg_before.clean_content.split(), msg_after.clean_content.split()) + diff_groups = tuple( + (diff_type, tuple(s[2:] for s in diff_words)) + for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) + ) + + content_before: t.List[str] = [] + content_after: t.List[str] = [] + + for index, (diff_type, words) in enumerate(diff_groups): + sub = ' '.join(words) + if diff_type == '-': + content_before.append(f"[{sub}](http://o.hi)") + elif diff_type == '+': + content_after.append(f"[{sub}](http://o.hi)") + elif diff_type == ' ': + if len(words) > 2: + sub = ( + f"{words[0] if index > 0 else ''}" + " ... " + f"{words[-1] if index < len(diff_groups) - 1 else ''}" + ) + content_before.append(sub) + content_after.append(sub) + + response = ( + f"**Author:** {author_str} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{msg_before.id}`\n" + "\n" + f"**Before**:\n{' '.join(content_before)}\n" + f"**After**:\n{' '.join(content_after)}\n" + "\n" + f"[Jump to message]({msg_after.jump_url})" + ) + + if msg_before.edited_at: + # Message was previously edited, to assist with self-bot detection, use the edited_at + # datetime as the baseline and create a human-readable delta between this edit event + # and the last time the message was edited + timestamp = msg_before.edited_at + delta = humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) + footer = f"Last edited {delta} ago" + else: + # Message was not previously edited, use the created_at datetime as the baseline, no + # delta calculation needed + timestamp = msg_before.created_at + footer = None + + await self.send_log_message( + Icons.message_edit, Colour.blurple(), "Message edited", response, + channel_id=Channels.message_log, timestamp_override=timestamp, footer=footer + ) + + @Cog.listener() + async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: + """Log raw message edit event to message change log.""" + try: + channel = self.bot.get_channel(int(event.data["channel_id"])) + message = await channel.fetch_message(event.message_id) + except discord.NotFound: # Was deleted before we got the event + return + + if ( + not message.guild + or message.guild.id != GuildConstant.id + or message.channel.id in GuildConstant.modlog_blacklist + or message.author.bot + ): + return + + await asyncio.sleep(1) # Wait here in case the normal event was fired + + if event.message_id in self._cached_edits: + # It was in the cache and the normal event was fired, so we can just ignore it + self._cached_edits.remove(event.message_id) + return + + author = message.author + channel = message.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + before_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) + + after_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + f"{message.clean_content}" + ) + + await self.send_log_message( + Icons.message_edit, Colour.blurple(), "Message edited (Before)", + before_response, channel_id=Channels.message_log + ) + + await self.send_log_message( + Icons.message_edit, Colour.blurple(), "Message edited (After)", + after_response, channel_id=Channels.message_log + ) + + @Cog.listener() + async def on_voice_state_update( + self, + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState + ) -> None: + """Log member voice state changes to the voice log channel.""" + if ( + member.guild.id != GuildConstant.id + or (before.channel and before.channel.id in GuildConstant.modlog_blacklist) + ): + return + + if member.id in self._ignored[Event.voice_state_update]: + self._ignored[Event.voice_state_update].remove(member.id) + return + + # Exclude all channel attributes except the name. + diff = DeepDiff( + before, + after, + exclude_paths=("root.session_id", "root.afk"), + exclude_regex_paths=r"root\.channel\.(?!name)", + ) + + # A type change seems to always take precedent over a value change. Furthermore, it will + # include the value change along with the type change anyway. Therefore, it's OK to + # "overwrite" values_changed; in practice there will never even be anything to overwrite. + diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} + + icon = Icons.voice_state_blue + colour = Colour.blurple() + changes = [] + + for attr, values in diff_values.items(): + if not attr: # Not sure why, but it happens. + continue + + old = values["old_value"] + new = values["new_value"] + + attr = attr[5:] # Remove "root." prefix. + attr = VOICE_STATE_ATTRIBUTES.get(attr, attr.replace("_", " ").capitalize()) + + changes.append(f"**{attr}:** `{old}` **→** `{new}`") + + # Set the embed icon and colour depending on which attribute changed. + if any(name in attr for name in ("Channel", "deaf", "mute")): + if new is None or new is True: + # Left a channel or was muted/deafened. + icon = Icons.voice_state_red + colour = Colours.soft_red + elif old is None or old is True: + # Joined a channel or was unmuted/undeafened. + icon = Icons.voice_state_green + colour = Colours.soft_green + + if not changes: + return + + member_str = escape_markdown(str(member)) + message = "\n".join(f"{Emojis.bullet} {item}" for item in sorted(changes)) + message = f"**{member_str}** (`{member.id}`)\n{message}" + + await self.send_log_message( + icon_url=icon, + colour=colour, + title="Voice state updated", + text=message, + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.voice_log + ) + + +def setup(bot: Bot) -> None: + """Load the ModLog cog.""" + bot.add_cog(ModLog(bot)) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py new file mode 100644 index 000000000..4af87c724 --- /dev/null +++ b/bot/exts/moderation/silence.py @@ -0,0 +1,170 @@ +import asyncio +import logging +from contextlib import suppress +from typing import Optional + +from discord import TextChannel +from discord.ext import commands, tasks +from discord.ext.commands import Context + +from bot.bot import Bot +from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles +from bot.converters import HushDurationConverter +from bot.utils.checks import with_role_check +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + + +class SilenceNotifier(tasks.Loop): + """Loop notifier for posting notices to `alert_channel` containing added channels.""" + + def __init__(self, alert_channel: TextChannel): + super().__init__(self._notifier, seconds=1, minutes=0, hours=0, count=None, reconnect=True, loop=None) + self._silenced_channels = {} + self._alert_channel = alert_channel + + def add_channel(self, channel: TextChannel) -> None: + """Add channel to `_silenced_channels` and start loop if not launched.""" + if not self._silenced_channels: + self.start() + log.info("Starting notifier loop.") + self._silenced_channels[channel] = self._current_loop + + def remove_channel(self, channel: TextChannel) -> None: + """Remove channel from `_silenced_channels` and stop loop if no channels remain.""" + with suppress(KeyError): + del self._silenced_channels[channel] + if not self._silenced_channels: + self.stop() + log.info("Stopping notifier loop.") + + async def _notifier(self) -> None: + """Post notice of `_silenced_channels` with their silenced duration to `_alert_channel` periodically.""" + # Wait for 15 minutes between notices with pause at start of loop. + if self._current_loop and not self._current_loop/60 % 15: + log.debug( + f"Sending notice with channels: " + f"{', '.join(f'#{channel} ({channel.id})' for channel in self._silenced_channels)}." + ) + channels_text = ', '.join( + f"{channel.mention} for {(self._current_loop-start)//60} min" + for channel, start in self._silenced_channels.items() + ) + await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") + + +class Silence(commands.Cog): + """Commands for stopping channel messages for `verified` role in a channel.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.muted_channels = set() + + self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) + self._get_instance_vars_event = asyncio.Event() + + async def _get_instance_vars(self) -> None: + """Get instance variables after they're available to get from the guild.""" + await self.bot.wait_until_guild_available() + guild = self.bot.get_guild(Guild.id) + self._verified_role = guild.get_role(Roles.verified) + self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) + self._mod_log_channel = self.bot.get_channel(Channels.mod_log) + self.notifier = SilenceNotifier(self._mod_log_channel) + self._get_instance_vars_event.set() + + @commands.command(aliases=("hush",)) + async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: + """ + Silence the current channel for `duration` minutes or `forever`. + + Duration is capped at 15 minutes, passing forever makes the silence indefinite. + Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. + """ + await self._get_instance_vars_event.wait() + log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") + if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): + await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") + return + if duration is None: + await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") + return + + await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") + + self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) + + @commands.command(aliases=("unhush",)) + async def unsilence(self, ctx: Context) -> None: + """ + Unsilence the current channel. + + If the channel was silenced indefinitely, notifications for the channel will stop. + """ + await self._get_instance_vars_event.wait() + log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") + if not await self._unsilence(ctx.channel): + await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") + else: + await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") + + async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: + """ + Silence `channel` for `self._verified_role`. + + If `persistent` is `True` add `channel` to notifier. + `duration` is only used for logging; if None is passed `persistent` should be True to not log None. + Return `True` if channel permissions were changed, `False` otherwise. + """ + current_overwrite = channel.overwrites_for(self._verified_role) + if current_overwrite.send_messages is False: + log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") + return False + await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) + self.muted_channels.add(channel) + if persistent: + log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") + self.notifier.add_channel(channel) + return True + + log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") + return True + + async def _unsilence(self, channel: TextChannel) -> bool: + """ + Unsilence `channel`. + + Check if `channel` is silenced through a `PermissionOverwrite`, + if it is unsilence it and remove it from the notifier. + Return `True` if channel permissions were changed, `False` otherwise. + """ + current_overwrite = channel.overwrites_for(self._verified_role) + if current_overwrite.send_messages is False: + await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) + log.info(f"Unsilenced channel #{channel} ({channel.id}).") + self.scheduler.cancel(channel.id) + self.notifier.remove_channel(channel) + self.muted_channels.discard(channel) + return True + log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") + return False + + def cog_unload(self) -> None: + """Send alert with silenced channels and cancel scheduled tasks on unload.""" + self.scheduler.cancel_all() + if self.muted_channels: + channels_string = ''.join(channel.mention for channel in self.muted_channels) + message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" + asyncio.create_task(self._mod_alerts_channel.send(message)) + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Silence cog.""" + bot.add_cog(Silence(bot)) diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py new file mode 100644 index 000000000..1d055afac --- /dev/null +++ b/bot/exts/moderation/slowmode.py @@ -0,0 +1,97 @@ +import logging +from datetime import datetime +from typing import Optional + +from dateutil.relativedelta import relativedelta +from discord import TextChannel +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Emojis, MODERATION_ROLES +from bot.converters import DurationDelta +from bot.decorators import with_role_check +from bot.utils import time + +log = logging.getLogger(__name__) + +SLOWMODE_MAX_DELAY = 21600 # seconds + + +class Slowmode(Cog): + """Commands for getting and setting slowmode delays of text channels.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @group(name='slowmode', aliases=['sm'], invoke_without_command=True) + async def slowmode_group(self, ctx: Context) -> None: + """Get or set the slowmode delay for the text channel this was invoked in or a given text channel.""" + await ctx.send_help(ctx.command) + + @slowmode_group.command(name='get', aliases=['g']) + async def get_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: + """Get the slowmode delay for a text channel.""" + # Use the channel this command was invoked in if one was not given + if channel is None: + channel = ctx.channel + + delay = relativedelta(seconds=channel.slowmode_delay) + humanized_delay = time.humanize_delta(delay) + + await ctx.send(f'The slowmode delay for {channel.mention} is {humanized_delay}.') + + @slowmode_group.command(name='set', aliases=['s']) + async def set_slowmode(self, ctx: Context, channel: Optional[TextChannel], delay: DurationDelta) -> None: + """Set the slowmode delay for a text channel.""" + # Use the channel this command was invoked in if one was not given + if channel is None: + channel = ctx.channel + + # Convert `dateutil.relativedelta.relativedelta` to `datetime.timedelta` + # Must do this to get the delta in a particular unit of time + utcnow = datetime.utcnow() + slowmode_delay = (utcnow + delay - utcnow).total_seconds() + + humanized_delay = time.humanize_delta(delay) + + # Ensure the delay is within discord's limits + if slowmode_delay <= SLOWMODE_MAX_DELAY: + log.info(f'{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.') + + await channel.edit(slowmode_delay=slowmode_delay) + await ctx.send( + f'{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}.' + ) + + else: + log.info( + f'{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, ' + 'which is not between 0 and 6 hours.' + ) + + await ctx.send( + f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.' + ) + + @slowmode_group.command(name='reset', aliases=['r']) + async def reset_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: + """Reset the slowmode delay for a text channel to 0 seconds.""" + # Use the channel this command was invoked in if one was not given + if channel is None: + channel = ctx.channel + + log.info(f'{ctx.author} reset the slowmode delay for #{channel} to 0 seconds.') + + await channel.edit(slowmode_delay=0) + await ctx.send( + f'{Emojis.check_mark} The slowmode delay for {channel.mention} has been reset to 0 seconds.' + ) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Slowmode cog.""" + bot.add_cog(Slowmode(bot)) diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py new file mode 100644 index 000000000..0db3e800d --- /dev/null +++ b/bot/exts/moderation/verification.py @@ -0,0 +1,191 @@ +import logging +from contextlib import suppress + +from discord import Colour, Forbidden, Message, NotFound, Object +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.decorators import in_whitelist, without_role +from bot.exts.moderation.modlog import ModLog +from bot.utils.checks import InWhitelistCheckFailure, without_role_check + +log = logging.getLogger(__name__) + +WELCOME_MESSAGE = f""" +Hello! Welcome to the server, and thanks for verifying yourself! + +For your records, these are the documents you accepted: + +`1)` Our rules, here: +`2)` Our privacy policy, here: - you can find information on how to have \ +your information removed here as well. + +Feel free to review them at any point! + +Additionally, if you'd like to receive notifications for the announcements \ +we post in <#{constants.Channels.announcements}> +from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ +to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. + +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{constants.Channels.bot_commands}>. +""" + +BOT_MESSAGE_DELETE_DELAY = 10 + + +class Verification(Cog): + """User verification and role self-management.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Check new message event for messages to the checkpoint channel & process.""" + if message.channel.id != constants.Channels.verification: + return # Only listen for #checkpoint messages + + if message.author.bot: + # They're a bot, delete their message after the delay. + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return + + # if a user mentions a role or guild member + # alert the mods in mod-alerts channel + if message.mentions or message.role_mentions: + log.debug( + f"{message.author} mentioned one or more users " + f"and/or roles in {message.channel.name}" + ) + + embed_text = ( + f"{message.author.mention} sent a message in " + f"{message.channel.mention} that contained user and/or role mentions." + f"\n\n**Original message:**\n>>> {message.content}" + ) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=constants.Icons.filtering, + colour=Colour(constants.Colours.soft_red), + title=f"User/Role mentioned in {message.channel.name}", + text=embed_text, + thumbnail=message.author.avatar_url_as(static_format="png"), + channel_id=constants.Channels.mod_alerts, + ) + + ctx: Context = await self.bot.get_context(message) + if ctx.command is not None and ctx.command.name == "accept": + return + + if any(r.id == constants.Roles.verified for r in ctx.author.roles): + log.info( + f"{ctx.author} posted '{ctx.message.content}' " + "in the verification channel, but is already verified." + ) + return + + log.debug( + f"{ctx.author} posted '{ctx.message.content}' in the verification " + "channel. We are providing instructions how to verify." + ) + await ctx.send( + f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " + f"and gain access to the rest of the server.", + delete_after=20 + ) + + log.trace(f"Deleting the message posted by {ctx.author}") + with suppress(NotFound): + await ctx.message.delete() + + @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) + @without_role(constants.Roles.verified) + @in_whitelist(channels=(constants.Channels.verification,)) + async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Accept our rules and gain access to the rest of the server.""" + log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") + await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") + try: + await ctx.author.send(WELCOME_MESSAGE) + except Forbidden: + log.info(f"Sending welcome message failed for {ctx.author}.") + finally: + log.trace(f"Deleting accept message by {ctx.author}.") + with suppress(NotFound): + self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) + await ctx.message.delete() + + @command(name='subscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Subscribe to announcement notifications by assigning yourself the role.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if has_role: + await ctx.send(f"{ctx.author.mention} You're already subscribed!") + return + + log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") + await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", + ) + + @command(name='unsubscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Unsubscribe from announcement notifications by removing the role from yourself.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if not has_role: + await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") + return + + log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") + await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." + ) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Check for & ignore any InWhitelistCheckFailure.""" + if isinstance(error, InWhitelistCheckFailure): + error.handled = True + + @staticmethod + def bot_check(ctx: Context) -> bool: + """Block any command within the verification channel that is not !accept.""" + if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): + return ctx.command.name == "accept" + else: + return True + + +def setup(bot: Bot) -> None: + """Load the Verification cog.""" + bot.add_cog(Verification(bot)) diff --git a/bot/exts/moderation/watchchannels/__init__.py b/bot/exts/moderation/watchchannels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py new file mode 100644 index 000000000..013d3ee03 --- /dev/null +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -0,0 +1,348 @@ +import asyncio +import logging +import re +import textwrap +from abc import abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Optional + +import dateutil.parser +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons +from bot.exts.moderation.modlog import ModLog +from bot.pagination import LinePaginator +from bot.utils import CogABCMeta, messages +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + +URL_RE = re.compile(r"(https?://[^\s]+)") + + +@dataclass +class MessageHistory: + """Represents a watch channel's message history.""" + + last_author: Optional[int] = None + last_channel: Optional[int] = None + message_count: int = 0 + + +class WatchChannel(metaclass=CogABCMeta): + """ABC with functionality for relaying users' messages to a certain channel.""" + + @abstractmethod + def __init__( + self, + bot: Bot, + destination: int, + webhook_id: int, + api_endpoint: str, + api_default_params: dict, + logger: logging.Logger + ) -> None: + self.bot = bot + + self.destination = destination # E.g., Channels.big_brother_logs + self.webhook_id = webhook_id # E.g., Webhooks.big_brother + self.api_endpoint = api_endpoint # E.g., 'bot/infractions' + self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} + self.log = logger # Logger of the child cog for a correct name in the logs + + self._consume_task = None + self.watched_users = defaultdict(dict) + self.message_queue = defaultdict(lambda: defaultdict(deque)) + self.consumption_queue = {} + self.retries = 5 + self.retry_delay = 10 + self.channel = None + self.webhook = None + self.message_history = MessageHistory() + + self._start = self.bot.loop.create_task(self.start_watchchannel()) + + @property + def modlog(self) -> ModLog: + """Provides access to the ModLog cog for alert purposes.""" + return self.bot.get_cog("ModLog") + + @property + def consuming_messages(self) -> bool: + """Checks if a consumption task is currently running.""" + if self._consume_task is None: + return False + + if self._consume_task.done(): + exc = self._consume_task.exception() + if exc: + self.log.exception( + "The message queue consume task has failed with:", + exc_info=exc + ) + return False + + return True + + async def start_watchchannel(self) -> None: + """Starts the watch channel by getting the channel, webhook, and user cache ready.""" + await self.bot.wait_until_guild_available() + + try: + self.channel = await self.bot.fetch_channel(self.destination) + except HTTPException: + self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + if self.channel is None or self.webhook is None: + self.log.error("Failed to start the watch channel; unloading the cog.") + + message = textwrap.dedent( + f""" + An error occurred while loading the text channel or webhook. + + TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} + Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} + + The Cog has been unloaded. + """ + ) + + await self.modlog.send_log_message( + title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", + text=message, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + if not await self.fetch_user_cache(): + await self.modlog.send_log_message( + title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", + text="Could not retrieve the list of watched users from the API and messages will not be relayed.", + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + async def fetch_user_cache(self) -> bool: + """ + Fetches watched users from the API and updates the watched user cache accordingly. + + This function returns `True` if the update succeeded. + """ + try: + data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) + except ResponseCodeError as err: + self.log.exception("Failed to fetch the watched users from the API", exc_info=err) + return False + + self.watched_users = defaultdict(dict) + + for entry in data: + user_id = entry.pop('user') + self.watched_users[user_id] = entry + + return True + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Queues up messages sent by watched users.""" + if msg.author.id in self.watched_users: + if not self.consuming_messages: + self._consume_task = self.bot.loop.create_task(self.consume_messages()) + + self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") + self.message_queue[msg.author.id][msg.channel.id].append(msg) + + async def consume_messages(self, delay_consumption: bool = True) -> None: + """Consumes the message queues to log watched users' messages.""" + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) + + self.log.trace("Started consuming the message queue") + + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() + + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() + + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) + + self.consumption_queue.clear() + + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + + async def webhook_send( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Sends a message to the webhook with the specified kwargs.""" + username = messages.sub_clyde(username) + try: + await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send a message to the webhook", + exc_info=exc + ) + + async def relay_message(self, msg: Message) -> None: + """Relays the message to the relevant watch channel.""" + limit = BigBrotherConfig.header_message_limit + + if ( + msg.author.id != self.message_history.last_author + or msg.channel.id != self.message_history.last_channel + or self.message_history.message_count >= limit + ): + self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) + + await self.send_header(msg) + + cleaned_content = msg.clean_content + + if cleaned_content: + # Put all non-media URLs in a code block to prevent embeds + media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} + for url in URL_RE.findall(cleaned_content): + if url not in media_urls: + cleaned_content = cleaned_content.replace(url, f"`{url}`") + await self.webhook_send( + cleaned_content, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + + if msg.attachments: + try: + await messages.send_attachments(msg, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.webhook_send( + embed=e, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send an attachment to the webhook", + exc_info=exc + ) + + self.message_history.message_count += 1 + + async def send_header(self, msg: Message) -> None: + """Sends a header embed with information about the relayed messages to the watch channel.""" + user_id = msg.author.id + + guild = self.bot.get_guild(GuildConfig.id) + actor = guild.get_member(self.watched_users[user_id]['actor']) + actor = actor.display_name if actor else self.watched_users[user_id]['actor'] + + inserted_at = self.watched_users[user_id]['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + + reason = self.watched_users[user_id]['reason'] + + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + footer = f"Added {time_delta} by {actor} | Reason: {reason}" + embed = Embed(description=f"{msg.author.mention} {message_jump}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) + + await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) + + async def list_watched_users( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Gives an overview of the watched user list for this channel. + + The optional kwarg `oldest_first` orders the list by oldest entry. + + The optional kwarg `update_cache` specifies whether the cache should + be refreshed by polling the API. + """ + if update_cache: + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") + update_cache = False + + lines = [] + for user_id, user_data in self.watched_users.items(): + inserted_at = user_data['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + lines.append(f"• <@{user_id}> (added {time_delta})") + + if oldest_first: + lines.reverse() + + lines = lines or ("There's nothing here yet.",) + + embed = Embed( + title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", + color=Color.blue() + ) + await LinePaginator.paginate(lines, ctx, embed, empty=False) + + @staticmethod + def _get_time_delta(time_string: str) -> str: + """Returns the time in human-readable time delta format.""" + date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) + time_delta = time_since(date_time, precision="minutes", max_units=1) + + return time_delta + + def _remove_user(self, user_id: int) -> None: + """Removes a user from a watch channel.""" + self.watched_users.pop(user_id, None) + self.message_queue.pop(user_id, None) + self.consumption_queue.pop(user_id, None) + + def cog_unload(self) -> None: + """Takes care of unloading the cog and canceling the consumption task.""" + self.log.trace("Unloading the cog") + if self._consume_task and not self._consume_task.done(): + self._consume_task.cancel() + try: + self._consume_task.result() + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py new file mode 100644 index 000000000..4ac916c9e --- /dev/null +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -0,0 +1,170 @@ +import logging +import textwrap +from collections import ChainMap + +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from bot.exts.moderation.infraction._utils import post_infraction +from ._watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class BigBrother(WatchChannel, Cog, name="Big Brother"): + """Monitors users by relaying their messages to a watch channel to assist with moderation.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.big_brother_logs, + webhook_id=Webhooks.big_brother, + api_endpoint='bot/infractions', + api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, + logger=log + ) + + @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def bigbrother_group(self, ctx: Context) -> None: + """Monitors users by relaying their messages to the Big Brother watch channel.""" + await ctx.send_help(ctx.command) + + @bigbrother_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored by Big Brother. + + The optional kwarg `oldest_first` can be used to order the list by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @bigbrother_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows Big Brother monitored users ordered by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @bigbrother_group.command(name='watch', aliases=('w',)) + @with_role(*MODERATION_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#big-brother` channel. + + A `reason` for adding the user to Big Brother is required and will be displayed + in the header when relaying messages of this user to the watchchannel. + """ + await self.apply_watch(ctx, user, reason) + + @bigbrother_group.command(name='unwatch', aliases=('uw',)) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Stop relaying messages by the given `user`.""" + await self.apply_unwatch(ctx, user, reason) + + async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: + """ + Add `user` to watched users and apply a watch infraction with `reason`. + + A message indicating the result of the operation is sent to `ctx`. + The message will include `user`'s previous watch infraction history, if it exists. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched.") + return + + response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) + + if response is not None: + self.watched_users[user.id] = response + msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + 'type': 'watch', + 'ordering': '-inserted_at' + } + ) + + if len(history) > 1: + total = f"({len(history) // 2} previous infractions in total)" + end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") + start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + else: + msg = ":x: Failed to post the infraction: response was empty." + + await ctx.send(msg) + + async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: + """ + Remove `user` from watched users and mark their infraction as inactive with `reason`. + + If `send_message` is True, a message indicating the result of the operation is sent to + `ctx`. + """ + active_watches = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + if active_watches: + log.trace("Active watches for user found. Attempting to remove.") + [infraction] = active_watches + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{infraction['id']}", + json={'active': False} + ) + + await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) + + self._remove_user(user.id) + + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"Perma-banned user {user} was unwatched.") + return + log.trace("User is not banned. Sending message to channel") + message = f":white_check_mark: Messages sent by {user} will no longer be relayed." + + else: + log.trace("No active watches found for user.") + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"{user} was not on the watch list; no removal necessary.") + return + log.trace("User is not perma banned. Send the error message.") + message = ":x: The specified user is currently not being watched." + + await ctx.send(message) + + +def setup(bot: Bot) -> None: + """Load the BigBrother cog.""" + bot.add_cog(BigBrother(bot)) diff --git a/bot/exts/moderation/watchchannels/talentpool.py b/bot/exts/moderation/watchchannels/talentpool.py new file mode 100644 index 000000000..2972f56e1 --- /dev/null +++ b/bot/exts/moderation/watchchannels/talentpool.py @@ -0,0 +1,269 @@ +import logging +import textwrap +from collections import ChainMap + +from discord import Color, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils import time +from ._watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class TalentPool(WatchChannel, Cog, name="Talentpool"): + """Relays messages of helper candidates to a watch channel to observe them.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.talent_pool, + webhook_id=Webhooks.talent_pool, + api_endpoint='bot/nominations', + api_default_params={'active': 'true', 'ordering': '-inserted_at'}, + logger=log, + ) + + @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_group(self, ctx: Context) -> None: + """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" + await ctx.send_help(ctx.command) + + @nomination_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored in the talent pool. + + The optional kwarg `oldest_first` can be used to order the list by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @nomination_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows talent pool monitored users ordered by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) + @with_role(*STAFF_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#talent-pool` channel. + + A `reason` for adding the user to the talent pool is required and will be displayed + in the header when relaying messages of this user to the channel. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): + await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update the user cache; can't add {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched in the talent pool") + return + + # Manual request with `raise_for_status` as False because we want the actual response + session = self.bot.api_client.session + url = self.bot.api_client._url_for(self.api_endpoint) + kwargs = { + 'json': { + 'actor': ctx.author.id, + 'reason': reason, + 'user': user.id + }, + 'raise_for_status': False, + } + async with session.post(url, **kwargs) as resp: + response_data = await resp.json() + + if resp.status == 400 and response_data.get('user', False): + await ctx.send(":x: The specified user can't be found in the database tables") + return + else: + resp.raise_for_status() + + self.watched_users[user.id] = response_data + msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + "ordering": "-inserted_at" + } + ) + + if history: + total = f"({len(history)} previous nominations in total)" + start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" + end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + + await ctx.send(msg) + + @nomination_group.command(name='history', aliases=('info', 'search')) + @with_role(*MODERATION_ROLES) + async def history_command(self, ctx: Context, user: FetchedMember) -> None: + """Shows the specified user's nomination history.""" + result = await self.bot.api_client.get( + self.api_endpoint, + params={ + 'user__id': str(user.id), + 'ordering': "-active,-inserted_at" + } + ) + if not result: + await ctx.send(":warning: This user has never been nominated") + return + + embed = Embed( + title=f"Nominations for {user.display_name} `({user.id})`", + color=Color.blue() + ) + lines = [self._nomination_to_string(nomination) for nomination in result] + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + @nomination_group.command(name='unwatch', aliases=('end', )) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Ends the active nomination of the specified user with the given reason. + + Providing a `reason` is required. + """ + active_nomination = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + + if not active_nomination: + await ctx.send(":x: The specified user does not have an active nomination") + return + + [nomination] = active_nomination + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination['id']}", + json={'end_reason': reason, 'active': False} + ) + await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") + self._remove_user(user.id) + + @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_edit_group(self, ctx: Context) -> None: + """Commands to edit nominations.""" + await ctx.send_help(ctx.command) + + @nomination_edit_group.command(name='reason') + @with_role(*MODERATION_ROLES) + async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: + """ + Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. + + If the nomination is active, the reason for nominating the user will be edited; + If the nomination is no longer active, the reason for ending the nomination will be edited instead. + """ + try: + nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") + except ResponseCodeError as e: + if e.response.status == 404: + self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") + await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") + return + else: + raise + + field = "reason" if nomination["active"] else "end_reason" + + self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination_id}", + json={field: reason} + ) + + await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") + + def _nomination_to_string(self, nomination_object: dict) -> str: + """Creates a string representation of a nomination.""" + guild = self.bot.get_guild(Guild.id) + + actor_id = nomination_object["actor"] + actor = guild.get_member(actor_id) + + active = nomination_object["active"] + log.debug(active) + log.debug(type(nomination_object["inserted_at"])) + + start_date = time.format_infraction(nomination_object["inserted_at"]) + if active: + lines = textwrap.dedent( + f""" + =============== + Status: **Active** + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + else: + end_date = time.format_infraction(nomination_object["ended_at"]) + lines = textwrap.dedent( + f""" + =============== + Status: Inactive + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + + End date: {end_date} + Unwatch reason: {nomination_object["end_reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + + return lines.strip() + + +def setup(bot: Bot) -> None: + """Load the TalentPool cog.""" + bot.add_cog(TalentPool(bot)) diff --git a/bot/exts/off_topic_names.py b/bot/exts/off_topic_names.py new file mode 100644 index 000000000..ce95450e0 --- /dev/null +++ b/bot/exts/off_topic_names.py @@ -0,0 +1,162 @@ +import asyncio +import difflib +import logging +from datetime import datetime, timedelta + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES +from bot.converters import OffTopicName +from bot.decorators import with_role +from bot.pagination import LinePaginator + +CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) +log = logging.getLogger(__name__) + + +async def update_names(bot: Bot) -> None: + """Background updater task that performs the daily channel name update.""" + while True: + # Since we truncate the compute timedelta to seconds, we add one second to ensure + # we go past midnight in the `seconds_to_sleep` set below. + today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) + next_midnight = today_at_midnight + timedelta(days=1) + seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 + await asyncio.sleep(seconds_to_sleep) + + try: + channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( + 'bot/off-topic-channel-names', params={'random_items': 3} + ) + except ResponseCodeError as e: + log.error(f"Failed to get new off topic channel names: code {e.response.status}") + continue + channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) + + await channel_0.edit(name=f'ot0-{channel_0_name}') + await channel_1.edit(name=f'ot1-{channel_1_name}') + await channel_2.edit(name=f'ot2-{channel_2_name}') + log.debug( + "Updated off-topic channel names to" + f" {channel_0_name}, {channel_1_name} and {channel_2_name}" + ) + + +class OffTopicNames(Cog): + """Commands related to managing the off-topic category channel names.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.updater_task = None + + self.bot.loop.create_task(self.init_offtopic_updater()) + + def cog_unload(self) -> None: + """Cancel any running updater tasks on cog unload.""" + if self.updater_task is not None: + self.updater_task.cancel() + + async def init_offtopic_updater(self) -> None: + """Start off-topic channel updating event loop if it hasn't already started.""" + await self.bot.wait_until_guild_available() + if self.updater_task is None: + coro = update_names(self.bot) + self.updater_task = self.bot.loop.create_task(coro) + + @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def otname_group(self, ctx: Context) -> None: + """Add or list items from the off-topic channel name rotation.""" + await ctx.send_help(ctx.command) + + @otname_group.command(name='add', aliases=('a',)) + @with_role(*MODERATION_ROLES) + async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """ + Adds a new off-topic name to the rotation. + + The name is not added if it is too similar to an existing name. + """ + existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') + close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) + + if close_match: + match = close_match[0] + log.info( + f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" + ) + await ctx.send( + f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " + "Use `!otn forceadd` to override this check." + ) + else: + await self._add_name(ctx, name) + + @otname_group.command(name='forceadd', aliases=('fa',)) + @with_role(*MODERATION_ROLES) + async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Forcefully adds a new off-topic name to the rotation.""" + await self._add_name(ctx, name) + + async def _add_name(self, ctx: Context, name: str) -> None: + """Adds an off-topic channel name to the site storage.""" + await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) + + log.info(f"{ctx.author} added the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Added `{name}` to the names list.") + + @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Removes a off-topic name from the rotation.""" + await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') + + log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Removed `{name}` from the names list.") + + @otname_group.command(name='list', aliases=('l',)) + @with_role(*MODERATION_ROLES) + async def list_command(self, ctx: Context) -> None: + """ + Lists all currently known off-topic channel names in a paginator. + + Restricted to Moderator and above to not spoil the surprise. + """ + result = await self.bot.api_client.get('bot/off-topic-channel-names') + lines = sorted(f"• {name}" for name in result) + embed = Embed( + title=f"Known off-topic names (`{len(result)}` total)", + colour=Colour.blue() + ) + if result: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + + @otname_group.command(name='search', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: + """Search for an off-topic name.""" + result = await self.bot.api_client.get('bot/off-topic-channel-names') + in_matches = {name for name in result if query in name} + close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) + lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) + embed = Embed( + title="Query results", + colour=Colour.blue() + ) + + if lines: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Nothing found." + await ctx.send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the OffTopicNames cog.""" + bot.add_cog(OffTopicNames(bot)) diff --git a/bot/exts/utils/__init__.py b/bot/exts/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py new file mode 100644 index 000000000..866fd2b68 --- /dev/null +++ b/bot/exts/utils/bot.py @@ -0,0 +1,385 @@ +import ast +import logging +import re +import time +from typing import Optional, Tuple + +from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord.ext.commands import Cog, Context, command, group + +from bot.bot import Bot +from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs +from bot.decorators import with_role +from bot.exts.filters.token_remover import TokenRemover +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +RE_MARKDOWN = re.compile(r'([*_~`|>])') + + +class BotCog(Cog, name="Bot"): + """Bot information commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + # Stores allowed channels plus epoch time since last call. + self.channel_cooldowns = { + Channels.python_discussion: 0, + } + + # These channels will also work, but will not be subject to cooldown + self.channel_whitelist = ( + Channels.bot_commands, + ) + + # Stores improperly formatted Python codeblock message ids and the corresponding bot message + self.codeblock_message_ids = {} + + @group(invoke_without_command=True, name="bot", hidden=True) + @with_role(Roles.verified) + async def botinfo_group(self, ctx: Context) -> None: + """Bot informational commands.""" + await ctx.send_help(ctx.command) + + @botinfo_group.command(name='about', aliases=('info',), hidden=True) + @with_role(Roles.verified) + async def about_command(self, ctx: Context) -> None: + """Get information about the bot.""" + embed = Embed( + description="A utility bot designed just for the Python server! Try `!help` for more info.", + url="https://github.com/python-discord/bot" + ) + + embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=URLs.bot_avatar + ) + + await ctx.send(embed=embed) + + @command(name='echo', aliases=('print',)) + @with_role(*MODERATION_ROLES) + async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Repeat the given message in either a specified channel or the current channel.""" + if channel is None: + await ctx.send(text) + else: + await channel.send(text) + + @command(name='embed') + @with_role(*MODERATION_ROLES) + async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Send the input within an embed to either a specified channel or the current channel.""" + embed = Embed(description=text) + + if channel is None: + await ctx.send(embed=embed) + else: + await channel.send(embed=embed) + + def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: + """ + Strip msg in order to find Python code. + + Tries to strip out Python code out of msg and returns the stripped block or + None if the block is a valid Python codeblock. + """ + if msg.count("\n") >= 3: + # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. + if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: + log.trace( + "Someone wrote a message that was already a " + "valid Python syntax highlighted code block. No action taken." + ) + return None + + else: + # Stripping backticks from every line of the message. + log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") + content = "" + for line in msg.splitlines(keepends=True): + content += line.strip("`") + + content = content.strip() + + # Remove "Python" or "Py" from start of the message if it exists. + log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") + pycode = False + if content.lower().startswith("python"): + content = content[6:] + pycode = True + elif content.lower().startswith("py"): + content = content[2:] + pycode = True + + if pycode: + content = content.splitlines(keepends=True) + + # Check if there might be code in the first line, and preserve it. + first_line = content[0] + if " " in content[0]: + first_space = first_line.index(" ") + content[0] = first_line[first_space:] + content = "".join(content) + + # If there's no code we can just get rid of the first line. + else: + content = "".join(content[1:]) + + # Strip it again to remove any leading whitespace. This is neccessary + # if the first line of the message looked like ```python + old = content.strip() + + # Strips REPL code out of the message if there is any. + content, repl_code = self.repl_stripping(old) + if old != content: + return (content, old), repl_code + + # Try to apply indentation fixes to the code. + content = self.fix_indentation(content) + + # Check if the code contains backticks, if it does ignore the message. + if "`" in content: + log.trace("Detected ` inside the code, won't reply") + return None + else: + log.trace(f"Returning message.\n\n{content}\n\n") + return (content,), repl_code + + def fix_indentation(self, msg: str) -> str: + """Attempts to fix badly indented code.""" + def unindent(code: str, skip_spaces: int = 0) -> str: + """Unindents all code down to the number of spaces given in skip_spaces.""" + final = "" + current = code[0] + leading_spaces = 0 + + # Get numbers of spaces before code in the first line. + while current == " ": + current = code[leading_spaces + 1] + leading_spaces += 1 + leading_spaces -= skip_spaces + + # If there are any, remove that number of spaces from every line. + if leading_spaces > 0: + for line in code.splitlines(keepends=True): + line = line[leading_spaces:] + final += line + return final + else: + return code + + # Apply fix for "all lines are overindented" case. + msg = unindent(msg) + + # If the first line does not end with a colon, we can be + # certain the next line will be on the same indentation level. + # + # If it does end with a colon, we will need to indent all successive + # lines one additional level. + first_line = msg.splitlines()[0] + code = "".join(msg.splitlines(keepends=True)[1:]) + if not first_line.endswith(":"): + msg = f"{first_line}\n{unindent(code)}" + else: + msg = f"{first_line}\n{unindent(code, 4)}" + return msg + + def repl_stripping(self, msg: str) -> Tuple[str, bool]: + """ + Strip msg in order to extract Python code out of REPL output. + + Tries to strip out REPL Python code out of msg and returns the stripped msg. + + Returns True for the boolean if REPL code was found in the input msg. + """ + final = "" + for line in msg.splitlines(keepends=True): + if line.startswith(">>>") or line.startswith("..."): + final += line[4:] + log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") + if not final: + log.trace(f"Found no REPL code in \n\n{msg}\n\n") + return msg, False + else: + log.trace(f"Found REPL code in \n\n{msg}\n\n") + return final.rstrip(), True + + def has_bad_ticks(self, msg: Message) -> bool: + """Check to see if msg contains ticks that aren't '`'.""" + not_backticks = [ + "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", + "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", + "\u3003\u3003\u3003" + ] + + return msg.content[:3] in not_backticks + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Detect poorly formatted Python code in new messages. + + If poorly formatted code is detected, send the user a helpful message explaining how to do + properly formatted Python syntax highlighting codeblocks. + """ + is_help_channel = ( + getattr(msg.channel, "category", None) + and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) + ) + parse_codeblock = ( + ( + is_help_channel + or msg.channel.id in self.channel_cooldowns + or msg.channel.id in self.channel_whitelist + ) + and not msg.author.bot + and len(msg.content.splitlines()) > 3 + and not TokenRemover.find_token_in_message(msg) + ) + + if parse_codeblock: # no token in the msg + on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 + if not on_cooldown or DEBUG_MODE: + try: + if self.has_bad_ticks(msg): + ticks = msg.content[:3] + content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) + if content is None: + return + + content, repl_code = content + + if len(content) == 2: + content = content[1] + else: + content = content[0] + + space_left = 204 + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto = ( + "It looks like you are trying to paste code into this channel.\n\n" + "You seem to be using the wrong symbols to indicate where the codeblock should start. " + f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" + "**Here is an example of how it should look:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + else: + howto = "" + content = self.codeblock_stripping(msg.content, False) + if content is None: + return + + content, repl_code = content + # Attempts to parse the message into an AST node. + # Invalid Python code will raise a SyntaxError. + tree = ast.parse(content[0]) + + # Multiple lines of single words could be interpreted as expressions. + # This check is to avoid all nodes being parsed as expressions. + # (e.g. words over multiple lines) + if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: + # Shorten the code to 10 lines and/or 204 characters. + space_left = 204 + if content and repl_code: + content = content[1] + else: + content = content[0] + + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto += ( + "It looks like you're trying to paste code into this channel.\n\n" + "Discord has support for Markdown, which allows you to post code with full " + "syntax highlighting. Please use these whenever you paste code, as this " + "helps improve the legibility and makes it easier for us to help you.\n\n" + f"**To do this, use the following method:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + log.debug(f"{msg.author} posted something that needed to be put inside python code " + "blocks. Sending the user some instructions.") + else: + log.trace("The code consists only of expressions, not sending instructions") + + if howto != "": + # Increase amount of codeblock correction in stats + self.bot.stats.incr("codeblock_corrections") + howto_embed = Embed(description=howto) + bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) + self.codeblock_message_ids[msg.id] = bot_message.id + + self.bot.loop.create_task( + wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) + ) + else: + return + + if msg.channel.id not in self.channel_whitelist: + self.channel_cooldowns[msg.channel.id] = time.time() + + except SyntaxError: + log.trace( + f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " + "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " + f"The message that was posted was:\n\n{msg.content}\n\n" + ) + + @Cog.listener() + async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: + """Check to see if an edited message (previously called out) still contains poorly formatted code.""" + if ( + # Checks to see if the message was called out by the bot + payload.message_id not in self.codeblock_message_ids + # Makes sure that there is content in the message + or payload.data.get("content") is None + # Makes sure there's a channel id in the message payload + or payload.data.get("channel_id") is None + ): + return + + # Retrieve channel and message objects for use later + channel = self.bot.get_channel(int(payload.data.get("channel_id"))) + user_message = await channel.fetch_message(payload.message_id) + + # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None + has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) + + # If the message is fixed, delete the bot message and the entry from the id dictionary + if has_fixed_codeblock is None: + bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) + await bot_message.delete() + del self.codeblock_message_ids[payload.message_id] + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + + +def setup(bot: Bot) -> None: + """Load the Bot cog.""" + bot.add_cog(BotCog(bot)) diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py new file mode 100644 index 000000000..d9a7aafe1 --- /dev/null +++ b/bot/exts/utils/clean.py @@ -0,0 +1,272 @@ +import logging +import random +import re +from typing import Iterable, Optional + +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext import commands +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import ( + Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES +) +from bot.decorators import with_role +from bot.exts.moderation.modlog import ModLog + +log = logging.getLogger(__name__) + + +class Clean(Cog): + """ + A cog that allows messages to be deleted in bulk, while applying various filters. + + You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a + specific regular expression. + + The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be + used to view the messages in the Discord dark theme style. + """ + + def __init__(self, bot: Bot): + self.bot = bot + self.cleaning = False + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def _clean_messages( + self, + amount: int, + ctx: Context, + channels: Iterable[TextChannel], + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + until_message: Optional[Message] = None, + ) -> None: + """A helper function that does the actual message cleaning.""" + def predicate_bots_only(message: Message) -> bool: + """Return True if the message was sent by a bot.""" + return message.author.bot + + def predicate_specific_user(message: Message) -> bool: + """Return True if the message was sent by the user provided in the _clean_messages call.""" + return message.author == user + + def predicate_regex(message: Message) -> bool: + """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" + content = [message.content] + + # Add the content for all embed attributes + for embed in message.embeds: + content.append(embed.title) + content.append(embed.description) + content.append(embed.footer.text) + content.append(embed.author.name) + for field in embed.fields: + content.append(field.name) + content.append(field.value) + + # Get rid of empty attributes and turn it into a string + content = [attr for attr in content if attr] + content = "\n".join(content) + + # Now let's see if there's a regex match + if not content: + return False + else: + return bool(re.search(regex.lower(), content.lower())) + + # Is this an acceptable amount of messages to clean? + if amount > CleanMessages.message_limit: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description=f"You cannot clean more than {CleanMessages.message_limit} messages." + ) + await ctx.send(embed=embed) + return + + # Are we already performing a clean? + if self.cleaning: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description="Please wait for the currently ongoing clean operation to complete." + ) + await ctx.send(embed=embed) + return + + # Set up the correct predicate + if bots_only: + predicate = predicate_bots_only # Delete messages from bots + elif user: + predicate = predicate_specific_user # Delete messages from specific user + elif regex: + predicate = predicate_regex # Delete messages that match regex + else: + predicate = None # Delete all messages + + # Default to using the invoking context's channel + if not channels: + channels = [ctx.channel] + + # Delete the invocation first + self.mod_log.ignore(Event.message_delete, ctx.message.id) + await ctx.message.delete() + + messages = [] + message_ids = [] + self.cleaning = True + + # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. + for channel in channels: + async for message in channel.history(limit=amount): + + # If at any point the cancel command is invoked, we should stop. + if not self.cleaning: + return + + # If we are looking for specific message. + if until_message: + + # we could use ID's here however in case if the message we are looking for gets deleted, + # we won't have a way to figure that out thus checking for datetime should be more reliable + if message.created_at < until_message.created_at: + # means we have found the message until which we were supposed to be deleting. + break + + # Since we will be using `delete_messages` method of a TextChannel and we need message objects to + # use it as well as to send logs we will start appending messages here instead adding them from + # purge. + messages.append(message) + + # If the message passes predicate, let's save it. + if predicate is None or predicate(message): + message_ids.append(message.id) + + self.cleaning = False + + # Now let's delete the actual messages with purge. + self.mod_log.ignore(Event.message_delete, *message_ids) + for channel in channels: + if until_message: + for i in range(0, len(messages), 100): + # while purge automatically handles the amount of messages + # delete_messages only allows for up to 100 messages at once + # thus we need to paginate the amount to always be <= 100 + await channel.delete_messages(messages[i:i + 100]) + else: + messages += await channel.purge(limit=amount, check=predicate) + + # Reverse the list to restore chronological order + if messages: + messages = reversed(messages) + log_url = await self.mod_log.upload_log(messages, ctx.author.id) + else: + # Can't build an embed, nothing to clean! + embed = Embed( + color=Colour(Colours.soft_red), + description="No matching messages could be found." + ) + await ctx.send(embed=embed, delete_after=10) + return + + # Build the embed and send it + target_channels = ", ".join(channel.mention for channel in channels) + + message = ( + f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" + f"A log of the deleted messages can be found [here]({log_url})." + ) + + await self.mod_log.send_log_message( + icon_url=Icons.message_bulk_delete, + colour=Colour(Colours.soft_red), + title="Bulk message delete", + text=message, + channel_id=Channels.mod_log, + ) + + @group(invoke_without_command=True, name="clean", aliases=["purge"]) + @with_role(*MODERATION_ROLES) + async def clean_group(self, ctx: Context) -> None: + """Commands for cleaning messages in channels.""" + await ctx.send_help(ctx.command) + + @clean_group.command(name="user", aliases=["users"]) + @with_role(*MODERATION_ROLES) + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, user=user, channels=channels) + + @clean_group.command(name="all", aliases=["everything"]) + @with_role(*MODERATION_ROLES) + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, channels=channels) + + @clean_group.command(name="bots", aliases=["bot"]) + @with_role(*MODERATION_ROLES) + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, bots_only=True, channels=channels) + + @clean_group.command(name="regex", aliases=["word", "expression"]) + @with_role(*MODERATION_ROLES) + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, regex=regex, channels=channels) + + @clean_group.command(name="message", aliases=["messages"]) + @with_role(*MODERATION_ROLES) + async def clean_message(self, ctx: Context, message: Message) -> None: + """Delete all messages until certain message, stop cleaning after hitting the `message`.""" + await self._clean_messages( + CleanMessages.message_limit, + ctx, + channels=[message.channel], + until_message=message + ) + + @clean_group.command(name="stop", aliases=["cancel", "abort"]) + @with_role(*MODERATION_ROLES) + async def clean_cancel(self, ctx: Context) -> None: + """If there is an ongoing cleaning process, attempt to immediately cancel it.""" + self.cleaning = False + + embed = Embed( + color=Colour.blurple(), + description="Clean interrupted." + ) + await ctx.send(embed=embed, delete_after=10) + + +def setup(bot: Bot) -> None: + """Load the Clean cog.""" + bot.add_cog(Clean(bot)) diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/eval.py new file mode 100644 index 000000000..eb8bfb1cf --- /dev/null +++ b/bot/exts/utils/eval.py @@ -0,0 +1,202 @@ +import contextlib +import inspect +import logging +import pprint +import re +import textwrap +import traceback +from io import StringIO +from typing import Any, Optional, Tuple + +import discord +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role +from bot.interpreter import Interpreter + +log = logging.getLogger(__name__) + + +class CodeEval(Cog): + """Owner and admin feature that evaluates code and returns the result to the channel.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.env = {} + self.ln = 0 + self.stdout = StringIO() + + self.interpreter = Interpreter(bot) + + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: + """Format the eval output into a string & attempt to format it into an Embed.""" + self._ = out + + res = "" + + # Erase temp input we made + if inp.startswith("_ = "): + inp = inp[4:] + + # Get all non-empty lines + lines = [line for line in inp.split("\n") if line.strip()] + if len(lines) != 1: + lines += [""] + + # Create the input dialog + for i, line in enumerate(lines): + if i == 0: + # Start dialog + start = f"In [{self.ln}]: " + + else: + # Indent the 3 dots correctly; + # Normally, it's something like + # In [X]: + # ...: + # + # But if it's + # In [XX]: + # ...: + # + # You can see it doesn't look right. + # This code simply indents the dots + # far enough to align them. + # we first `str()` the line number + # then we get the length + # and use `str.rjust()` + # to indent it. + start = "...: ".rjust(len(str(self.ln)) + 7) + + if i == len(lines) - 2: + if line.startswith("return"): + line = line[6:].strip() + + # Combine everything + res += (start + line + "\n") + + self.stdout.seek(0) + text = self.stdout.read() + self.stdout.close() + self.stdout = StringIO() + + if text: + res += (text + "\n") + + if out is None: + # No output, return the input statement + return (res, None) + + res += f"Out[{self.ln}]: " + + if isinstance(out, discord.Embed): + # We made an embed? Send that as embed + res += "" + res = (res, out) + + else: + if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): + # Leave out the traceback message + out = "\n" + "\n".join(out.split("\n")[1:]) + + if isinstance(out, str): + pretty = out + else: + pretty = pprint.pformat(out, compact=True, width=60) + + if pretty != str(out): + # We're using the pretty version, start on the next line + res += "\n" + + if pretty.count("\n") > 20: + # Text too long, shorten + li = pretty.split("\n") + + pretty = ("\n".join(li[:3]) # First 3 lines + + "\n ...\n" # Ellipsis to indicate removed lines + + "\n".join(li[-3:])) # last 3 lines + + # Add the output + res += pretty + res = (res, None) + + return res # Return (text, embed) + + async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: + """Eval the input code string & send an embed to the invoking context.""" + self.ln += 1 + + if code.startswith("exit"): + self.ln = 0 + self.env = {} + return await ctx.send("```Reset history!```") + + env = { + "message": ctx.message, + "author": ctx.message.author, + "channel": ctx.channel, + "guild": ctx.guild, + "ctx": ctx, + "self": self, + "bot": self.bot, + "inspect": inspect, + "discord": discord, + "contextlib": contextlib + } + + self.env.update(env) + + # Ignore this code, it works + code_ = """ +async def func(): # (None,) -> Any + try: + with contextlib.redirect_stdout(self.stdout): +{0} + if '_' in locals(): + if inspect.isawaitable(_): + _ = await _ + return _ + finally: + self.env.update(locals()) +""".format(textwrap.indent(code, ' ')) + + try: + exec(code_, self.env) # noqa: B102,S102 + func = self.env['func'] + res = await func() + + except Exception: + res = traceback.format_exc() + + out, embed = self._format(code, res) + await ctx.send(f"```py\n{out}```", embed=embed) + + @group(name='internal', aliases=('int',)) + @with_role(Roles.owners, Roles.admins) + async def internal_group(self, ctx: Context) -> None: + """Internal commands. Top secret!""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @internal_group.command(name='eval', aliases=('e',)) + @with_role(Roles.admins, Roles.owners) + async def eval(self, ctx: Context, *, code: str) -> None: + """Run eval in a REPL-like format.""" + code = code.strip("`") + if re.match('py(thon)?\n', code): + code = "\n".join(code.split("\n")[1:]) + + if not re.search( # Check if it's an expression + r"^(return|import|for|while|def|class|" + r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( + code.split("\n")) == 1: + code = "_ = " + code + + await self._eval(ctx, code) + + +def setup(bot: Bot) -> None: + """Load the CodeEval cog.""" + bot.add_cog(CodeEval(bot)) diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py new file mode 100644 index 000000000..671397650 --- /dev/null +++ b/bot/exts/utils/extensions.py @@ -0,0 +1,289 @@ +import functools +import importlib +import inspect +import logging +import pkgutil +import typing as t +from enum import Enum + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group + +from bot import exts +from bot.bot import Bot +from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +def walk_extensions() -> t.Iterator[str]: + """Yield extension names from the bot.exts subpackage.""" + + def on_error(name: str) -> t.NoReturn: + raise ImportError(name=name) # pragma: no cover + + for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error): + if module.name.rsplit(".", maxsplit=1)[-1].startswith("_"): + # Ignore module/package names starting with an underscore. + continue + + if module.ispkg: + imported = importlib.import_module(module.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. + continue + + yield module.name + + +UNLOAD_BLACKLIST = {f"{exts.__name__}.utils.extensions", f"{exts.__name__}.moderation.modlog"} +EXTENSIONS = frozenset(walk_extensions()) +BASE_PATH_LEN = len(exts.__name__.split(".")) + + +class Action(Enum): + """Represents an action to perform on an extension.""" + + # Need to be partial otherwise they are considered to be function definitions. + LOAD = functools.partial(Bot.load_extension) + UNLOAD = functools.partial(Bot.unload_extension) + RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + name = ext.rsplit(".", maxsplit=1)[-1] + if argument == name: + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise commands.BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): + """Extension management commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) + async def extensions_group(self, ctx: Context) -> None: + """Load, unload, reload, and list loaded extensions.""" + await ctx.send_help(ctx.command) + + @extensions_group.command(name="load", aliases=("l",)) + async def load_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Load extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "*" in extensions or "**" in extensions: + extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + + msg = self.batch_manage(Action.LOAD, *extensions) + await ctx.send(msg) + + @extensions_group.command(name="unload", aliases=("ul",)) + async def unload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Unload currently loaded extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + + if blacklisted: + msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" + else: + if "*" in extensions or "**" in extensions: + extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + + msg = self.batch_manage(Action.UNLOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="reload", aliases=("r",)) + async def reload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Reload extensions given their fully qualified or unqualified names. + + If an extension fails to be reloaded, it will be rolled-back to the prior working state. + + If '\*' is given as the name, all currently loaded extensions will be reloaded. + If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "**" in extensions: + extensions = EXTENSIONS + elif "*" in extensions: + extensions = set(self.bot.extensions.keys()) | set(extensions) + extensions.remove("*") + + msg = self.batch_manage(Action.RELOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="list", aliases=("all",)) + async def list_command(self, ctx: Context) -> None: + """ + Get a list of all extensions, including their loaded status. + + Grey indicates that the extension is unloaded. + Green indicates that the extension is currently loaded. + """ + embed = Embed(colour=Colour.blurple()) + embed.set_author( + name="Extensions List", + url=URLs.github_bot_repo, + icon_url=URLs.bot_avatar + ) + + lines = [] + categories = self.group_extension_statuses() + for category, extensions in sorted(categories.items()): + # Treat each category as a single line by concatenating everything. + # This ensures the paginator will not cut off a page in the middle of a category. + category = category.replace("_", " ").title() + extensions = "\n".join(sorted(extensions)) + lines.append(f"**{category}**\n{extensions}\n") + + log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") + await LinePaginator.paginate(lines, ctx, embed, scale_to_size=700, empty=False) + + def group_extension_statuses(self) -> t.Mapping[str, str]: + """Return a mapping of extension names and statuses to their categories.""" + categories = {} + + for ext in EXTENSIONS: + if ext in self.bot.extensions: + status = Emojis.status_online + else: + status = Emojis.status_offline + + path = ext.split(".") + if len(path) > BASE_PATH_LEN + 1: + category = " - ".join(path[BASE_PATH_LEN:-1]) + else: + category = "uncategorised" + + categories.setdefault(category, []).append(f"{status} {path[-1]}") + + return categories + + def batch_manage(self, action: Action, *extensions: str) -> str: + """ + Apply an action to multiple extensions and return a message with the results. + + If only one extension is given, it is deferred to `manage()`. + """ + if len(extensions) == 1: + msg, _ = self.manage(action, extensions[0]) + return msg + + verb = action.name.lower() + failures = {} + + for extension in extensions: + _, error = self.manage(action, extension) + if error: + failures[extension] = error + + emoji = ":x:" if failures else ":ok_hand:" + msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + + if failures: + failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) + msg += f"\nFailures:```{failures}```" + + log.debug(f"Batch {verb}ed extensions.") + + return msg + + def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: + """Apply an action to an extension and return the status message and any error message.""" + verb = action.name.lower() + error_msg = None + + try: + action.value(self.bot, ext) + except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): + if action is Action.RELOAD: + # When reloading, just load the extension if it was not loaded. + return self.manage(Action.LOAD, ext) + + msg = f":x: Extension `{ext}` is already {verb}ed." + log.debug(msg[4:]) + except Exception as e: + if hasattr(e, "original"): + e = e.original + + log.exception(f"Extension '{ext}' failed to {verb}.") + + error_msg = f"{e.__class__.__name__}: {e}" + msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" + else: + msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." + log.debug(msg[10:]) + + return msg, error_msg + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators and core developers to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Handle BadArgument errors locally to prevent the help command from showing.""" + if isinstance(error, commands.BadArgument): + await ctx.send(str(error)) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Extensions cog.""" + bot.add_cog(Extensions(bot)) diff --git a/bot/exts/utils/jams.py b/bot/exts/utils/jams.py new file mode 100644 index 000000000..b3102db2f --- /dev/null +++ b/bot/exts/utils/jams.py @@ -0,0 +1,150 @@ +import logging +import typing as t + +from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role +from discord.ext import commands +from more_itertools import unique_everseen + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role + +log = logging.getLogger(__name__) + +MAX_CHANNELS = 50 +CATEGORY_NAME = "Code Jam" + + +class CodeJams(commands.Cog): + """Manages the code-jam related parts of our server.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command() + @with_role(Roles.admins) + async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: + """ + Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. + + The first user passed will always be the team leader. + """ + # Ignore duplicate members + members = list(unique_everseen(members)) + + # We had a little issue during Code Jam 4 here, the greedy converter did it's job + # and ignored anything which wasn't a valid argument which left us with teams of + # two members or at some times even 1 member. This fixes that by checking that there + # are always 3 members in the members list. + if len(members) < 3: + await ctx.send( + ":no_entry_sign: One of your arguments was invalid\n" + f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" + " members" + ) + return + + team_channel = await self.create_channels(ctx.guild, team_name, members) + await self.add_roles(ctx.guild, members) + + await ctx.send( + f":ok_hand: Team created: {team_channel}\n" + f"**Team Leader:** {members[0].mention}\n" + f"**Team Members:** {' '.join(member.mention for member in members[1:])}" + ) + + async def get_category(self, guild: Guild) -> CategoryChannel: + """ + Return a code jam category. + + If all categories are full or none exist, create a new category. + """ + for category in guild.categories: + # Need 2 available spaces: one for the text channel and one for voice. + if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: + return category + + return await self.create_category(guild) + + @staticmethod + async def create_category(guild: Guild) -> CategoryChannel: + """Create a new code jam category and return it.""" + log.info("Creating a new code jam category.") + + category_overwrites = { + guild.default_role: PermissionOverwrite(read_messages=False), + guild.me: PermissionOverwrite(read_messages=True) + } + + return await guild.create_category_channel( + CATEGORY_NAME, + overwrites=category_overwrites, + reason="It's code jam time!" + ) + + @staticmethod + def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: + """Get code jam team channels permission overwrites.""" + # First member is always the team leader + team_channel_overwrites = { + members[0]: PermissionOverwrite( + manage_messages=True, + read_messages=True, + manage_webhooks=True, + connect=True + ), + guild.default_role: PermissionOverwrite(read_messages=False, connect=False), + guild.get_role(Roles.verified): PermissionOverwrite( + read_messages=False, + connect=False + ) + } + + # Rest of members should just have read_messages + for member in members[1:]: + team_channel_overwrites[member] = PermissionOverwrite( + read_messages=True, + connect=True + ) + + return team_channel_overwrites + + async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: + """Create team text and voice channels. Return the mention for the text channel.""" + # Get permission overwrites and category + team_channel_overwrites = self.get_overwrites(members, guild) + code_jam_category = await self.get_category(guild) + + # Create a text channel for the team + team_channel = await guild.create_text_channel( + team_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + # Create a voice channel for the team + team_voice_name = " ".join(team_name.split("-")).title() + + await guild.create_voice_channel( + team_voice_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + return team_channel.mention + + @staticmethod + async def add_roles(guild: Guild, members: t.List[Member]) -> None: + """Assign team leader and jammer roles.""" + # Assign team leader role + await members[0].add_roles(guild.get_role(Roles.team_leaders)) + + # Assign rest of roles + jammer_role = guild.get_role(Roles.jammers) + for member in members: + await member.add_roles(jammer_role) + + +def setup(bot: Bot) -> None: + """Load the CodeJams cog.""" + bot.add_cog(CodeJams(bot)) diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py new file mode 100644 index 000000000..670493bcf --- /dev/null +++ b/bot/exts/utils/reminders.py @@ -0,0 +1,427 @@ +import asyncio +import logging +import random +import textwrap +import typing as t +from datetime import datetime, timedelta +from operator import itemgetter + +import discord +from dateutil.parser import isoparse +from dateutil.relativedelta import relativedelta +from discord.ext.commands import Cog, Context, Greedy, group + +from bot.bot import Bot +from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES +from bot.converters import Duration +from bot.pagination import LinePaginator +from bot.utils.checks import without_role_check +from bot.utils.messages import send_denial +from bot.utils.scheduling import Scheduler +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +WHITELISTED_CHANNELS = Guild.reminder_whitelist +MAXIMUM_REMINDERS = 5 + +Mentionable = t.Union[discord.Member, discord.Role] + + +class Reminders(Cog): + """Provide in-channel reminder functionality.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_reminders()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + async def reschedule_reminders(self) -> None: + """Get all current reminders from the API and reschedule them.""" + await self.bot.wait_until_guild_available() + response = await self.bot.api_client.get( + 'bot/reminders', + params={'active': 'true'} + ) + + now = datetime.utcnow() + + for reminder in response: + is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + if not is_valid: + continue + + remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) + + # If the reminder is already overdue ... + if remind_at < now: + late = relativedelta(now, remind_at) + await self.send_reminder(reminder, late) + else: + self.schedule_reminder(reminder) + + def ensure_valid_reminder( + self, + reminder: dict, + cancel_task: bool = True + ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" + user = self.bot.get_user(reminder['author']) + channel = self.bot.get_channel(reminder['channel_id']) + is_valid = True + if not user or not channel: + is_valid = False + log.info( + f"Reminder {reminder['id']} invalid: " + f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." + ) + asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + + return is_valid, user, channel + + @staticmethod + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: str, + delivery_dt: t.Optional[datetime], + ) -> None: + """Send an embed confirming the reminder change was made successfully.""" + embed = discord.Embed() + embed.colour = discord.Colour.green() + embed.title = random.choice(POSITIVE_REPLIES) + embed.description = on_success + + footer_str = f"ID: {reminder_id}" + if delivery_dt: + # Reminder deletion will have a `None` `delivery_dt` + footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + + embed.set_footer(text=footer_str) + + await ctx.send(embed=embed) + + @staticmethod + async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: + """ + Returns whether or not the list of mentions is allowed. + + Conditions: + - Role reminders are Mods+ + - Reminders for other users are Helpers+ + + If mentions aren't allowed, also return the type of mention(s) disallowed. + """ + if without_role_check(ctx, *STAFF_ROLES): + return False, "members/roles" + elif without_role_check(ctx, *MODERATION_ROLES): + return all(isinstance(mention, discord.Member) for mention in mentions), "roles" + else: + return True, "" + + @staticmethod + async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: + """ + Filter mentions to see if the user can mention, and sends a denial if not allowed. + + Returns whether or not the validation is successful. + """ + mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) + + if not mentions or mentions_allowed: + return True + else: + await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") + return False + + def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: + """Converts Role and Member ids to their corresponding objects if possible.""" + guild = self.bot.get_guild(Guild.id) + for mention_id in mention_ids: + if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): + yield mentionable + + def schedule_reminder(self, reminder: dict) -> None: + """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" + reminder_id = reminder["id"] + reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) + + async def _remind() -> None: + await self.send_reminder(reminder) + + log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") + await self._delete_reminder(reminder_id) + + self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) + + async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: + """Delete a reminder from the database, given its ID, and cancel the running task.""" + await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) + + if cancel_task: + # Now we can remove it from the schedule list + self.scheduler.cancel(reminder_id) + + async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: + """ + Edits a reminder in the database given the ID and payload. + + Returns the edited reminder. + """ + # Send the request to update the reminder in the database + reminder = await self.bot.api_client.patch( + 'bot/reminders/' + str(reminder_id), + json=payload + ) + return reminder + + async def _reschedule_reminder(self, reminder: dict) -> None: + """Reschedule a reminder object.""" + log.trace(f"Cancelling old task #{reminder['id']}") + self.scheduler.cancel(reminder["id"]) + + log.trace(f"Scheduling new task #{reminder['id']}") + self.schedule_reminder(reminder) + + async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: + """Send the reminder.""" + is_valid, user, channel = self.ensure_valid_reminder(reminder) + if not is_valid: + return + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.set_author( + icon_url=Icons.remind_blurple, + name="It has arrived!" + ) + + embed.description = f"Here's your reminder: `{reminder['content']}`." + + if reminder.get("jump_url"): # keep backward compatibility + embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" + + if late: + embed.colour = discord.Colour.red() + embed.set_author( + icon_url=Icons.remind_red, + name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" + ) + + additional_mentions = ' '.join( + mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) + ) + + await channel.send( + content=f"{user.mention} {additional_mentions}", + embed=embed + ) + await self._delete_reminder(reminder["id"]) + + @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) + async def remind_group( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """Commands for managing your reminders.""" + await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) + + @remind_group.command(name="new", aliases=("add", "create")) + async def new_reminder( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """ + Set yourself a simple reminder. + + Expiration is parsed per: http://strftime.org/ + """ + # If the user is not staff, we need to verify whether or not to make a reminder at all. + if without_role_check(ctx, *STAFF_ROLES): + + # If they don't have permission to set a reminder in this channel + if ctx.channel.id not in WHITELISTED_CHANNELS: + await send_denial(ctx, "Sorry, you can't do that here!") + return + + # Get their current active reminders + active_reminders = await self.bot.api_client.get( + 'bot/reminders', + params={ + 'author__id': str(ctx.author.id) + } + ) + + # Let's limit this, so we don't get 10 000 + # reminders from kip or something like that :P + if len(active_reminders) > MAXIMUM_REMINDERS: + await send_denial(ctx, "You have too many active reminders!") + return + + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + + # Now we can attempt to actually set the reminder. + reminder = await self.bot.api_client.post( + 'bot/reminders', + json={ + 'author': ctx.author.id, + 'channel_id': ctx.message.channel.id, + 'jump_url': ctx.message.jump_url, + 'content': content, + 'expiration': expiration.isoformat(), + 'mentions': mention_ids, + } + ) + + now = datetime.utcnow() - timedelta(seconds=1) + humanized_delta = humanize_delta(relativedelta(expiration, now)) + mention_string = ( + f"Your reminder will arrive in {humanized_delta} " + f"and will mention {len(mentions)} other(s)!" + ) + + # Confirm to the user that it worked. + await self._send_confirmation( + ctx, + on_success=mention_string, + reminder_id=reminder["id"], + delivery_dt=expiration, + ) + + self.schedule_reminder(reminder) + + @remind_group.command(name="list") + async def list_reminders(self, ctx: Context) -> None: + """View a paginated embed of all reminders for your user.""" + # Get all the user's reminders from the database. + data = await self.bot.api_client.get( + 'bot/reminders', + params={'author__id': str(ctx.author.id)} + ) + + now = datetime.utcnow() + + # Make a list of tuples so it can be sorted by time. + reminders = sorted( + ( + (rem['content'], rem['expiration'], rem['id'], rem['mentions']) + for rem in data + ), + key=itemgetter(1) + ) + + lines = [] + + for content, remind_at, id_, mentions in reminders: + # Parse and humanize the time, make it pretty :D + remind_datetime = isoparse(remind_at).replace(tzinfo=None) + time = humanize_delta(relativedelta(remind_datetime, now)) + + mentions = ", ".join( + # Both Role and User objects have the `name` attribute + mention.name for mention in self.get_mentionables(mentions) + ) + mention_string = f"\n**Mentions:** {mentions}" if mentions else "" + + text = textwrap.dedent(f""" + **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} + {content} + """).strip() + + lines.append(text) + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.title = f"Reminders for {ctx.author}" + + # Remind the user that they have no reminders :^) + if not lines: + embed.description = "No active reminders could be found." + await ctx.send(embed=embed) + return + + # Construct the embed and paginate it. + embed.colour = discord.Colour.blurple() + + await LinePaginator.paginate( + lines, + ctx, embed, + max_lines=3, + empty=True + ) + + @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) + async def edit_reminder_group(self, ctx: Context) -> None: + """Commands for modifying your current reminders.""" + await ctx.send_help(ctx.command) + + @edit_reminder_group.command(name="duration", aliases=("time",)) + async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: + """ + Edit one of your reminder's expiration. + + Expiration is parsed per: http://strftime.org/ + """ + await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) + + @edit_reminder_group.command(name="content", aliases=("reason",)) + async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: + """Edit one of your reminder's content.""" + await self.edit_reminder(ctx, id_, {"content": content}) + + @edit_reminder_group.command(name="mentions", aliases=("pings",)) + async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: + """Edit one of your reminder's mentions.""" + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) + + async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: + """Edits a reminder with the given payload, then sends a confirmation message.""" + reminder = await self._edit_reminder(id_, payload) + + # Parse the reminder expiration back into a datetime + expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) + + # Send a confirmation message to the channel + await self._send_confirmation( + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, + ) + await self._reschedule_reminder(reminder) + + @remind_group.command("delete", aliases=("remove", "cancel")) + async def delete_reminder(self, ctx: Context, id_: int) -> None: + """Delete one of your active reminders.""" + await self._delete_reminder(id_) + await self._send_confirmation( + ctx, + on_success="That reminder has been deleted successfully!", + reminder_id=id_, + delivery_dt=None, + ) + + +def setup(bot: Bot) -> None: + """Load the Reminders cog.""" + bot.add_cog(Reminders(bot)) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py new file mode 100644 index 000000000..52c8b6f88 --- /dev/null +++ b/bot/exts/utils/snekbox.py @@ -0,0 +1,349 @@ +import asyncio +import contextlib +import datetime +import logging +import re +import textwrap +from functools import partial +from signal import Signals +from typing import Optional, Tuple + +from discord import HTTPException, Message, NotFound, Reaction, User +from discord.ext.commands import Cog, Context, command, guild_only + +from bot.bot import Bot +from bot.constants import Categories, Channels, Roles, URLs +from bot.decorators import in_whitelist +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") +FORMATTED_CODE_REGEX = re.compile( + r"^\s*" # any leading whitespace from the beginning of the string + r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)" # match the exact same delimiter from the start again + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive +) +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL # "." also matches newlines +) + +MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) +EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) +EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) + +SIGKILL = 9 + +REEVAL_EMOJI = '\U0001f501' # :repeat: +REEVAL_TIMEOUT = 30 + + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_eval(self, code: str) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the eval output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + + url = URLs.paste_service.format(key="documents") + try: + async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: + data = await resp.json() + + if "key" in data: + return URLs.paste_service.format(key=data["key"]) + except Exception: + # 400 (Bad Request) means there are too many characters + log.exception("Failed to upload full output to paste service!") + + @staticmethod + def prepare_input(code: str) -> str: + """Extract code from the Markdown, format it, and insert it into the code template.""" + match = FORMATTED_CODE_REGEX.fullmatch(code) + if match: + code, block, lang, delim = match.group("code", "block", "lang", "delim") + code = textwrap.dedent(code) + if block: + info = (f"'{lang}' highlighted" if lang else "plain") + " code block" + else: + info = f"{delim}-enclosed inline code" + log.trace(f"Extracted {info} for evaluation:\n{code}") + else: + code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) + log.trace( + f"Eval message contains unformatted or badly formatted code, " + f"stripping whitespace only:\n{code}" + ) + + return code + + @staticmethod + def get_results_message(results: dict) -> Tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + stdout, returncode = results["stdout"], results["returncode"] + msg = f"Your eval job has completed with return code {returncode}" + error = "" + + if returncode is None: + msg = "Your eval job has failed" + error = stdout.strip() + elif returncode == 128 + SIGKILL: + msg = "Your eval job timed out or ran out of memory" + elif returncode == 255: + msg = "Your eval job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + try: + name = Signals(returncode - 128).name + msg = f"{msg} ({name})" + except ValueError: + pass + + return msg, error + + @staticmethod + def get_status_emoji(results: dict) -> str: + """Return an emoji corresponding to the status code or lack of output in result.""" + if not results["stdout"].strip(): # No output + return ":warning:" + elif results["returncode"] == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + + async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + """ + Format the output and return a tuple of the formatted output and a URL to the full output. + + Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters + and upload the full output to a paste service. + """ + log.trace("Formatting output...") + + output = output.rstrip("\n") + original_output = output # To be uploaded to a pasting service if needed + paste_link = None + + if "<@" in output: + output = output.replace("<@", "<@\u200B") # Zero-width space + + if " 0: + output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] + output = output[:11] # Limiting to only 11 lines + output = "\n".join(output) + + if lines > 10: + truncated = True + if len(output) >= 1000: + output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + else: + output = f"{output}\n... (truncated - too many lines)" + elif len(output) >= 1000: + truncated = True + output = f"{output[:1000]}\n... (truncated - too long)" + + if truncated: + paste_link = await self.upload_output(original_output) + + output = output or "[No output]" + + return output, paste_link + + async def send_eval(self, ctx: Context, code: str) -> Message: + """ + Evaluate code, format it, and send the output to the corresponding channel. + + Return the bot response. + """ + async with ctx.typing(): + results = await self.post_eval(code) + msg, error = self.get_results_message(results) + + if error: + output, paste_link = error, None + else: + output, paste_link = await self.format_output(results["stdout"]) + + icon = self.get_status_emoji(results) + msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + # Collect stats of eval fails + successes + if icon == ":x:": + self.bot.stats.incr("snekbox.python.fail") + else: + self.bot.stats.incr("snekbox.python.success") + + filter_cog = self.bot.get_cog("Filtering") + filter_triggered = False + if filter_cog: + filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + if filter_triggered: + response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + else: + response = await ctx.send(msg) + self.bot.loop.create_task( + wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) + ) + + log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + return response + + async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + """ + Check if the eval session should continue. + + Return the new code to evaluate or None if the eval session should be terminated. + """ + _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + + with contextlib.suppress(NotFound): + try: + _, new_message = await self.bot.wait_for( + 'message_edit', + check=_predicate_eval_message_edit, + timeout=REEVAL_TIMEOUT + ) + await ctx.message.add_reaction(REEVAL_EMOJI) + await self.bot.wait_for( + 'reaction_add', + check=_predicate_emoji_reaction, + timeout=10 + ) + + code = await self.get_code(new_message) + await ctx.message.clear_reactions() + with contextlib.suppress(HTTPException): + await response.delete() + + except asyncio.TimeoutError: + await ctx.message.clear_reactions() + return None + + return code + + async def get_code(self, message: Message) -> Optional[str]: + """ + Return the code from `message` to be evaluated. + + If the message is an invocation of the eval command, return the first argument or None if it + doesn't exist. Otherwise, return the full content of the message. + """ + log.trace(f"Getting context for message {message.id}.") + new_ctx = await self.bot.get_context(message) + + if new_ctx.command is self.eval_command: + log.trace(f"Message {message.id} invokes eval command.") + split = message.content.split(maxsplit=1) + code = split[1] if len(split) > 1 else None + else: + log.trace(f"Message {message.id} does not invoke eval command.") + code = message.content + + return code + + @command(name="eval", aliases=("e",)) + @guild_only() + @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) + async def eval_command(self, ctx: Context, *, code: str = None) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + if ctx.author.id in self.jobs: + await ctx.send( + f"{ctx.author.mention} You've already got a job running - " + "please wait for it to finish!" + ) + return + + if not code: # None or empty string + await ctx.send_help(ctx.command) + return + + if Roles.helpers in (role.id for role in ctx.author.roles): + self.bot.stats.incr("snekbox_usages.roles.helpers") + else: + self.bot.stats.incr("snekbox_usages.roles.developers") + + if ctx.channel.category_id == Categories.help_in_use: + self.bot.stats.incr("snekbox_usages.channels.help") + elif ctx.channel.id == Channels.bot_commands: + self.bot.stats.incr("snekbox_usages.channels.bot_commands") + else: + self.bot.stats.incr("snekbox_usages.channels.topical") + + log.info(f"Received code from {ctx.author} for evaluation:\n{code}") + + while True: + self.jobs[ctx.author.id] = datetime.datetime.now() + code = self.prepare_input(code) + try: + response = await self.send_eval(ctx, code) + finally: + del self.jobs[ctx.author.id] + + code = await self.continue_eval(ctx, response) + if not code: + break + log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + + +def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: + """Return True if the edited message is the context message and the content was indeed modified.""" + return new_msg.id == ctx.message.id and old_msg.content != new_msg.content + + +def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI + + +def setup(bot: Bot) -> None: + """Load the Snekbox cog.""" + bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py new file mode 100644 index 000000000..d96abbd5a --- /dev/null +++ b/bot/exts/utils/utils.py @@ -0,0 +1,265 @@ +import difflib +import logging +import re +import unicodedata +from email.parser import HeaderParser +from io import StringIO +from typing import Tuple, Union + +from discord import Colour, Embed, utils +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils import messages + +log = logging.getLogger(__name__) + +ZEN_OF_PYTHON = """\ +Beautiful is better than ugly. +Explicit is better than implicit. +Simple is better than complex. +Complex is better than complicated. +Flat is better than nested. +Sparse is better than dense. +Readability counts. +Special cases aren't special enough to break the rules. +Although practicality beats purity. +Errors should never pass silently. +Unless explicitly silenced. +In the face of ambiguity, refuse the temptation to guess. +There should be one-- and preferably only one --obvious way to do it. +Although that way may not be obvious at first unless you're Dutch. +Now is better than never. +Although never is often better than *right* now. +If the implementation is hard to explain, it's a bad idea. +If the implementation is easy to explain, it may be a good idea. +Namespaces are one honking great idea -- let's do more of those! +""" + +ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + + +class Utils(Cog): + """A selection of utilities which don't have a clear category.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.base_pep_url = "http://www.python.org/dev/peps/pep-" + self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" + + @command(name='pep', aliases=('get_pep', 'p')) + async def pep_command(self, ctx: Context, pep_number: str) -> None: + """Fetches information about a PEP and sends it to the channel.""" + if pep_number.isdigit(): + pep_number = int(pep_number) + else: + await ctx.send_help(ctx.command) + return + + # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. + if pep_number == 0: + return await self.send_pep_zero(ctx) + + possible_extensions = ['.txt', '.rst'] + found_pep = False + for extension in possible_extensions: + # Attempt to fetch the PEP + pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" + log.trace(f"Requesting PEP {pep_number} with {pep_url}") + response = await self.bot.http_session.get(pep_url) + + if response.status == 200: + log.trace("PEP found") + found_pep = True + + pep_content = await response.text() + + # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 + pep_header = HeaderParser().parse(StringIO(pep_content)) + + # Assemble the embed + pep_embed = Embed( + title=f"**PEP {pep_number} - {pep_header['Title']}**", + description=f"[Link]({self.base_pep_url}{pep_number:04})", + ) + + pep_embed.set_thumbnail(url=ICON_URL) + + # Add the interesting information + fields_to_check = ("Status", "Python-Version", "Created", "Type") + for field in fields_to_check: + # Check for a PEP metadata field that is present but has an empty value + # embed field values can't contain an empty string + if pep_header.get(field, ""): + pep_embed.add_field(name=field, value=pep_header[field]) + + elif response.status != 404: + # any response except 200 and 404 is expected + found_pep = True # actually not, but it's easier to display this way + log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " + f"{response.status}.\n{response.text}") + + error_message = "Unexpected HTTP error during PEP search. Please let us know." + pep_embed = Embed(title="Unexpected error", description=error_message) + pep_embed.colour = Colour.red() + break + + if not found_pep: + log.trace("PEP was not found") + not_found = f"PEP {pep_number} does not exist." + pep_embed = Embed(title="PEP not found", description=not_found) + pep_embed.colour = Colour.red() + + await ctx.message.channel.send(embed=pep_embed) + + @command() + @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) + async def charinfo(self, ctx: Context, *, characters: str) -> None: + """Shows you information on up to 50 unicode characters.""" + match = re.match(r"<(a?):(\w+):(\d+)>", characters) + if match: + return await messages.send_denial( + ctx, + "**Non-Character Detected**\n" + "Only unicode characters can be processed, but a custom Discord emoji " + "was found. Please remove it and try again." + ) + + if len(characters) > 50: + return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") + + def get_info(char: str) -> Tuple[str, str]: + digit = f"{ord(char):x}" + if len(digit) <= 4: + u_code = f"\\u{digit:>04}" + else: + u_code = f"\\U{digit:>08}" + url = f"https://www.compart.com/en/unicode/U+{digit:>04}" + name = f"[{unicodedata.name(char, '')}]({url})" + info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" + return info, u_code + + char_list, raw_list = zip(*(get_info(c) for c in characters)) + embed = Embed().set_author(name="Character Info") + + if len(characters) > 1: + # Maximum length possible is 502 out of 1024, so there's no need to truncate. + embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) + + await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) + + @command() + async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: + """ + Show the Zen of Python. + + Without any arguments, the full Zen will be produced. + If an integer is provided, the line with that index will be produced. + If a string is provided, the line which matches best will be produced. + """ + embed = Embed( + colour=Colour.blurple(), + title="The Zen of Python", + description=ZEN_OF_PYTHON + ) + + if search_value is None: + embed.title += ", by Tim Peters" + await ctx.send(embed=embed) + return + + zen_lines = ZEN_OF_PYTHON.splitlines() + + # handle if it's an index int + if isinstance(search_value, int): + upper_bound = len(zen_lines) - 1 + lower_bound = -1 * upper_bound + if not (lower_bound <= search_value <= upper_bound): + raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") + + embed.title += f" (line {search_value % len(zen_lines)}):" + embed.description = zen_lines[search_value] + await ctx.send(embed=embed) + return + + # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead + # exact word. + for i, line in enumerate(zen_lines): + for word in line.split(): + if word.lower() == search_value.lower(): + embed.title += f" (line {i}):" + embed.description = line + await ctx.send(embed=embed) + return + + # handle if it's a search string and not exact word + matcher = difflib.SequenceMatcher(None, search_value.lower()) + + best_match = "" + match_index = 0 + best_ratio = 0 + + for index, line in enumerate(zen_lines): + matcher.set_seq2(line.lower()) + + # the match ratio needs to be adjusted because, naturally, + # longer lines will have worse ratios than shorter lines when + # fuzzy searching for keywords. this seems to work okay. + adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() + + if adjusted_ratio > best_ratio: + best_ratio = adjusted_ratio + best_match = line + match_index = index + + if not best_match: + raise BadArgument("I didn't get a match! Please try again with a different search term.") + + embed.title += f" (line {match_index}):" + embed.description = best_match + await ctx.send(embed=embed) + + @command(aliases=("poll",)) + @with_role(*MODERATION_ROLES) + async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: + """ + Build a quick voting poll with matching reactions with the provided options. + + A maximum of 20 options can be provided, as Discord supports a max of 20 + reactions on a single message. + """ + if len(title) > 256: + raise BadArgument("The title cannot be longer than 256 characters.") + if len(options) < 2: + raise BadArgument("Please provide at least 2 options.") + if len(options) > 20: + raise BadArgument("I can only handle 20 options!") + + codepoint_start = 127462 # represents "regional_indicator_a" unicode value + options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} + embed = Embed(title=title, description="\n".join(options.values())) + message = await ctx.send(embed=embed) + for reaction in options: + await message.add_reaction(reaction) + + async def send_pep_zero(self, ctx: Context) -> None: + """Send information about PEP 0.""" + pep_embed = Embed( + title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", + description="[Link](https://www.python.org/dev/peps/)" + ) + pep_embed.set_thumbnail(url=ICON_URL) + pep_embed.add_field(name="Status", value="Active") + pep_embed.add_field(name="Created", value="13-Jul-2000") + pep_embed.add_field(name="Type", value="Informational") + + await ctx.send(embed=pep_embed) + + +def setup(bot: Bot) -> None: + """Load the Utils cog.""" + bot.add_cog(Utils(bot)) diff --git a/tests/bot/cogs/__init__.py b/tests/bot/cogs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/backend/__init__.py b/tests/bot/cogs/backend/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/backend/sync/__init__.py b/tests/bot/cogs/backend/sync/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/backend/sync/test_base.py b/tests/bot/cogs/backend/sync/test_base.py deleted file mode 100644 index 3009aacb6..000000000 --- a/tests/bot/cogs/backend/sync/test_base.py +++ /dev/null @@ -1,404 +0,0 @@ -import asyncio -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend.sync._syncers import Syncer, _Diff -from tests import helpers - - -class TestSyncer(Syncer): - """Syncer subclass with mocks for abstract methods for testing purposes.""" - - name = "test" - _get_diff = mock.AsyncMock() - _sync = mock.AsyncMock() - - -class SyncerBaseTests(unittest.TestCase): - """Tests for the syncer base class.""" - - def setUp(self): - self.bot = helpers.MockBot() - - def test_instantiation_fails_without_abstract_methods(self): - """The class must have abstract methods implemented.""" - with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): - Syncer(self.bot) - - -class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): - """Tests for sending the sync confirmation prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - - def mock_get_channel(self): - """Fixture to return a mock channel and message for when `get_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - mock_channel.send.return_value = mock_message - self.bot.get_channel.return_value = mock_channel - - return mock_channel, mock_message - - def mock_fetch_channel(self): - """Fixture to return a mock channel and message for when `fetch_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - self.bot.get_channel.return_value = None - mock_channel.send.return_value = mock_message - self.bot.fetch_channel.return_value = mock_channel - - return mock_channel, mock_message - - async def test_send_prompt_edits_and_returns_message(self): - """The given message should be edited to display the prompt and then should be returned.""" - msg = helpers.MockMessage() - ret_val = await self.syncer._send_prompt(msg) - - msg.edit.assert_called_once() - self.assertIn("content", msg.edit.call_args[1]) - self.assertEqual(ret_val, msg) - - async def test_send_prompt_gets_dev_core_channel(self): - """The dev-core channel should be retrieved if an extant message isn't given.""" - subtests = ( - (self.bot.get_channel, self.mock_get_channel), - (self.bot.fetch_channel, self.mock_fetch_channel), - ) - - for method, mock_ in subtests: - with self.subTest(method=method, msg=mock_.__name__): - mock_() - await self.syncer._send_prompt() - - method.assert_called_once_with(constants.Channels.dev_core) - - async def test_send_prompt_returns_none_if_channel_fetch_fails(self): - """None should be returned if there's an HTTPException when fetching the channel.""" - self.bot.get_channel.return_value = None - self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") - - ret_val = await self.syncer._send_prompt() - - self.assertIsNone(ret_val) - - async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): - """A new message mentioning core devs should be sent and returned if message isn't given.""" - for mock_ in (self.mock_get_channel, self.mock_fetch_channel): - with self.subTest(msg=mock_.__name__): - mock_channel, mock_message = mock_() - ret_val = await self.syncer._send_prompt() - - mock_channel.send.assert_called_once() - self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) - self.assertEqual(ret_val, mock_message) - - async def test_send_prompt_adds_reactions(self): - """The message should have reactions for confirmation added.""" - extant_message = helpers.MockMessage() - subtests = ( - (extant_message, lambda: (None, extant_message)), - (None, self.mock_get_channel), - (None, self.mock_fetch_channel), - ) - - for message_arg, mock_ in subtests: - subtest_msg = "Extant message" if mock_.__name__ == "" else mock_.__name__ - - with self.subTest(msg=subtest_msg): - _, mock_message = mock_() - await self.syncer._send_prompt(message_arg) - - calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] - mock_message.add_reaction.assert_has_calls(calls) - - -class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): - """Tests for waiting for a sync confirmation reaction on the prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) - - @staticmethod - def get_message_reaction(emoji): - """Fixture to return a mock message an reaction from the given `emoji`.""" - message = helpers.MockMessage() - reaction = helpers.MockReaction(emoji=emoji, message=message) - - return message, reaction - - def test_reaction_check_for_valid_emoji_and_authors(self): - """Should return True if authors are identical or are a bot and a core dev, respectively.""" - user_subtests = ( - ( - helpers.MockMember(id=77), - helpers.MockMember(id=77), - "identical users", - ), - ( - helpers.MockMember(id=77, bot=True), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "bot author and core-dev reactor", - ), - ) - - for emoji in self.syncer._REACTION_EMOJIS: - for author, user, msg in user_subtests: - with self.subTest(author=author, user=user, emoji=emoji, msg=msg): - message, reaction = self.get_message_reaction(emoji) - ret_val = self.syncer._reaction_check(author, message, reaction, user) - - self.assertTrue(ret_val) - - def test_reaction_check_for_invalid_reactions(self): - """Should return False for invalid reaction events.""" - valid_emoji = self.syncer._REACTION_EMOJIS[0] - subtests = ( - ( - helpers.MockMember(id=77), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "users are not identical", - ), - ( - helpers.MockMember(id=77, bot=True), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43), - "reactor lacks the core-dev role", - ), - ( - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - "reactor is a bot", - ), - ( - helpers.MockMember(id=77), - helpers.MockMessage(id=95), - helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), - helpers.MockMember(id=77), - "messages are not identical", - ), - ( - helpers.MockMember(id=77), - *self.get_message_reaction("InVaLiD"), - helpers.MockMember(id=77), - "emoji is invalid", - ), - ) - - for *args, msg in subtests: - kwargs = dict(zip(("author", "message", "reaction", "user"), args)) - with self.subTest(**kwargs, msg=msg): - ret_val = self.syncer._reaction_check(*args) - self.assertFalse(ret_val) - - async def test_wait_for_confirmation(self): - """The message should always be edited and only return True if the emoji is a check mark.""" - subtests = ( - (constants.Emojis.check_mark, True, None), - ("InVaLiD", False, None), - (None, False, asyncio.TimeoutError), - ) - - for emoji, ret_val, side_effect in subtests: - for bot in (True, False): - with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): - # Set up mocks - message = helpers.MockMessage() - member = helpers.MockMember(bot=bot) - - self.bot.wait_for.reset_mock() - self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) - self.bot.wait_for.side_effect = side_effect - - # Call the function - actual_return = await self.syncer._wait_for_confirmation(member, message) - - # Perform assertions - self.bot.wait_for.assert_called_once() - self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) - - message.edit.assert_called_once() - kwargs = message.edit.call_args[1] - self.assertIn("content", kwargs) - - # Core devs should only be mentioned if the author is a bot. - if bot: - self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - else: - self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - - self.assertIs(actual_return, ret_val) - - -class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for main function orchestrating the sync.""" - - def setUp(self): - self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) - self.syncer = TestSyncer(self.bot) - - async def test_sync_respects_confirmation_result(self): - """The sync should abort if confirmation fails and continue if confirmed.""" - mock_message = helpers.MockMessage() - subtests = ( - (True, mock_message), - (False, None), - ) - - for confirmed, message in subtests: - with self.subTest(confirmed=confirmed): - self.syncer._sync.reset_mock() - self.syncer._get_diff.reset_mock() - - diff = _Diff({1, 2, 3}, {4, 5}, None) - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(confirmed, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - - if confirmed: - self.syncer._sync.assert_called_once_with(diff) - else: - self.syncer._sync.assert_not_called() - - async def test_sync_diff_size(self): - """The diff size should be correctly calculated.""" - subtests = ( - (6, _Diff({1, 2}, {3, 4}, {5, 6})), - (5, _Diff({1, 2, 3}, None, {4, 5})), - (0, _Diff(None, None, None)), - (0, _Diff(set(), set(), set())), - ) - - for size, diff in subtests: - with self.subTest(size=size, diff=diff): - self.syncer._get_diff.reset_mock() - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) - - async def test_sync_message_edited(self): - """The message should be edited if one was sent, even if the sync has an API error.""" - subtests = ( - (None, None, False), - (helpers.MockMessage(), None, True), - (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), - ) - - for message, side_effect, should_edit in subtests: - with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): - self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(True, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - if should_edit: - message.edit.assert_called_once() - self.assertIn("content", message.edit.call_args[1]) - - async def test_sync_confirmation_context_redirect(self): - """If ctx is given, a new message should be sent and author should be ctx's author.""" - mock_member = helpers.MockMember() - subtests = ( - (None, self.bot.user, None), - (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), - ) - - for ctx, author, message in subtests: - with self.subTest(ctx=ctx, author=author, message=message): - if ctx is not None: - ctx.send.return_value = message - - # Make sure `_get_diff` returns a MagicMock, not an AsyncMock - self.syncer._get_diff.return_value = mock.MagicMock() - - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild, ctx) - - if ctx is not None: - ctx.send.assert_called_once() - - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_small_diff(self): - """Should always return True and the given message if the diff size is too small.""" - author = helpers.MockMember() - expected_message = helpers.MockMessage() - - for size in (3, 2): # pragma: no cover - with self.subTest(size=size): - self.syncer._send_prompt = mock.AsyncMock() - self.syncer._wait_for_confirmation = mock.AsyncMock() - - coro = self.syncer._get_confirmation_result(size, author, expected_message) - result, actual_message = await coro - - self.assertTrue(result) - self.assertEqual(actual_message, expected_message) - self.syncer._send_prompt.assert_not_called() - self.syncer._wait_for_confirmation.assert_not_called() - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_large_diff(self): - """Should return True if confirmed and False if _send_prompt fails or aborted.""" - author = helpers.MockMember() - mock_message = helpers.MockMessage() - - subtests = ( - (True, mock_message, True, "confirmed"), - (False, None, False, "_send_prompt failed"), - (False, mock_message, False, "aborted"), - ) - - for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover - with self.subTest(msg=msg): - self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) - - coro = self.syncer._get_confirmation_result(4, author) - actual_result, actual_message = await coro - - self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None - self.assertIs(actual_result, expected_result) - self.assertEqual(actual_message, expected_message) - - if expected_message: - self.syncer._wait_for_confirmation.assert_called_once_with( - author, expected_message - ) diff --git a/tests/bot/cogs/backend/sync/test_cog.py b/tests/bot/cogs/backend/sync/test_cog.py deleted file mode 100644 index e40552817..000000000 --- a/tests/bot/cogs/backend/sync/test_cog.py +++ /dev/null @@ -1,416 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend import sync -from bot.cogs.backend.sync._cog import Sync -from bot.cogs.backend.sync._syncers import Syncer -from tests import helpers -from tests.base import CommandTestCase - - -class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): - """Tests for the sync extension.""" - - @staticmethod - def test_extension_setup(): - """The Sync cog should be added.""" - bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() - - -class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): - """Base class for Sync cog tests. Sets up patches for syncers.""" - - def setUp(self): - self.bot = helpers.MockBot() - - self.role_syncer_patcher = mock.patch( - "bot.cogs.backend.sync._syncers.RoleSyncer", - autospec=Syncer, - spec_set=True - ) - self.user_syncer_patcher = mock.patch( - "bot.cogs.backend.sync._syncers.UserSyncer", - autospec=Syncer, - spec_set=True - ) - self.RoleSyncer = self.role_syncer_patcher.start() - self.UserSyncer = self.user_syncer_patcher.start() - - self.cog = Sync(self.bot) - - def tearDown(self): - self.role_syncer_patcher.stop() - self.user_syncer_patcher.stop() - - @staticmethod - def response_error(status: int) -> ResponseCodeError: - """Fixture to return a ResponseCodeError with the given status code.""" - response = mock.MagicMock() - response.status = status - - return ResponseCodeError(response) - - -class SyncCogTests(SyncCogTestCase): - """Tests for the Sync cog.""" - - @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. - self.RoleSyncer.reset_mock() - self.UserSyncer.reset_mock() - self.bot.loop.create_task = mock.MagicMock() - - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - Sync(self.bot) - - self.RoleSyncer.assert_called_once_with(self.bot) - self.UserSyncer.assert_called_once_with(self.bot) - sync_guild.assert_called_once_with() - self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) - - async def test_sync_cog_sync_guild(self): - """Roles and users should be synced only if a guild is successfully retrieved.""" - for guild in (helpers.MockGuild(), None): - with self.subTest(guild=guild): - self.bot.reset_mock() - self.cog.role_syncer.reset_mock() - self.cog.user_syncer.reset_mock() - - self.bot.get_guild = mock.MagicMock(return_value=guild) - - await self.cog.sync_guild() - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.get_guild.assert_called_once_with(constants.Guild.id) - - if guild is None: - self.cog.role_syncer.sync.assert_not_called() - self.cog.user_syncer.sync.assert_not_called() - else: - self.cog.role_syncer.sync.assert_called_once_with(guild) - self.cog.user_syncer.sync.assert_called_once_with(guild) - - async def patch_user_helper(self, side_effect: BaseException) -> None: - """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" - self.bot.api_client.patch.reset_mock(side_effect=True) - self.bot.api_client.patch.side_effect = side_effect - - user_id, updated_information = 5, {"key": 123} - await self.cog.patch_user(user_id, updated_information) - - self.bot.api_client.patch.assert_called_once_with( - f"bot/users/{user_id}", - json=updated_information, - ) - - async def test_sync_cog_patch_user(self): - """A PATCH request should be sent and 404 errors ignored.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - await self.patch_user_helper(side_effect) - - async def test_sync_cog_patch_user_non_404(self): - """A PATCH request should be sent and the error raised if it's not a 404.""" - with self.assertRaises(ResponseCodeError): - await self.patch_user_helper(self.response_error(500)) - - -class SyncCogListenerTests(SyncCogTestCase): - """Tests for the listeners of the Sync cog.""" - - def setUp(self): - super().setUp() - self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - - self.guild_id_patcher = mock.patch("bot.cogs.backend.sync._cog.constants.Guild.id", 5) - self.guild_id = self.guild_id_patcher.start() - - self.guild = helpers.MockGuild(id=self.guild_id) - self.other_guild = helpers.MockGuild(id=0) - - def tearDown(self): - self.guild_id_patcher.stop() - - async def test_sync_cog_on_guild_role_create(self): - """A POST request should be sent with the new role's data.""" - self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - role = helpers.MockRole(**role_data, guild=self.guild) - await self.cog.on_guild_role_create(role) - - self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) - - async def test_sync_cog_on_guild_role_create_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_create(role) - self.bot.api_client.post.assert_not_awaited() - - async def test_sync_cog_on_guild_role_delete(self): - """A DELETE request should be sent.""" - self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - - role = helpers.MockRole(id=99, guild=self.guild) - await self.cog.on_guild_role_delete(role) - - self.bot.api_client.delete.assert_called_once_with("bot/roles/99") - - async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_delete(role) - self.bot.api_client.delete.assert_not_awaited() - - async def test_sync_cog_on_guild_role_update(self): - """A PUT request should be sent if the colour, name, permissions, or position changes.""" - self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - subtests = ( - (True, ("colour", "name", "permissions", "position")), - (False, ("hoist", "mentionable")), - ) - - for should_put, attributes in subtests: - for attribute in attributes: - with self.subTest(should_put=should_put, changed_attribute=attribute): - self.bot.api_client.put.reset_mock() - - after_role_data = role_data.copy() - after_role_data[attribute] = 876 - - before_role = helpers.MockRole(**role_data, guild=self.guild) - after_role = helpers.MockRole(**after_role_data, guild=self.guild) - - await self.cog.on_guild_role_update(before_role, after_role) - - if should_put: - self.bot.api_client.put.assert_called_once_with( - f"bot/roles/{after_role.id}", - json=after_role_data - ) - else: - self.bot.api_client.put.assert_not_called() - - async def test_sync_cog_on_guild_role_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_update(role, role) - self.bot.api_client.put.assert_not_awaited() - - async def test_sync_cog_on_member_remove(self): - """Member should be patched to set in_guild as False.""" - self.assertTrue(self.cog.on_member_remove.__cog_listener__) - - member = helpers.MockMember(guild=self.guild) - await self.cog.on_member_remove(member) - - self.cog.patch_user.assert_called_once_with( - member.id, - json={"in_guild": False} - ) - - async def test_sync_cog_on_member_remove_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_remove(member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_member_update_roles(self): - """Members should be patched if their roles have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - # Roles are intentionally unsorted. - before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles, guild=self.guild) - after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - data = {"roles": sorted(role.id for role in after_member.roles)} - self.cog.patch_user.assert_called_once_with(after_member.id, json=data) - - async def test_sync_cog_on_member_update_other(self): - """Members should not be patched if other attributes have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - subtests = ( - ("activities", discord.Game("Pong"), discord.Game("Frogger")), - ("nick", "old nick", "new nick"), - ("status", discord.Status.online, discord.Status.offline), - ) - - for attribute, old_value, new_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) - after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - self.cog.patch_user.assert_not_called() - - async def test_sync_cog_on_member_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_update(member, member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_user_update(self): - """A user should be patched only if the name, discriminator, or avatar changes.""" - self.assertTrue(self.cog.on_user_update.__cog_listener__) - - before_data = { - "name": "old name", - "discriminator": "1234", - "bot": False, - } - - subtests = ( - (True, "name", "name", "new name", "new name"), - (True, "discriminator", "discriminator", "8765", 8765), - (False, "bot", "bot", True, True), - ) - - for should_patch, attribute, api_field, value, api_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - after_data = before_data.copy() - after_data[attribute] = value - before_user = helpers.MockUser(**before_data) - after_user = helpers.MockUser(**after_data) - - await self.cog.on_user_update(before_user, after_user) - - if should_patch: - self.cog.patch_user.assert_called_once() - - # Don't care if *all* keys are present; only the changed one is required - call_args = self.cog.patch_user.call_args - self.assertEqual(call_args.args[0], after_user.id) - self.assertIn("json", call_args.kwargs) - - self.assertIn("ignore_404", call_args.kwargs) - self.assertTrue(call_args.kwargs["ignore_404"]) - - json = call_args.kwargs["json"] - self.assertIn(api_field, json) - self.assertEqual(json[api_field], api_value) - else: - self.cog.patch_user.assert_not_called() - - async def on_member_join_helper(self, side_effect: Exception) -> dict: - """ - Helper to set `side_effect` for on_member_join and assert a PUT request was sent. - - The request data for the mock member is returned. All exceptions will be re-raised. - """ - member = helpers.MockMember( - discriminator="1234", - roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], - guild=self.guild, - ) - - data = { - "discriminator": int(member.discriminator), - "id": member.id, - "in_guild": True, - "name": member.name, - "roles": sorted(role.id for role in member.roles) - } - - self.bot.api_client.put.reset_mock(side_effect=True) - self.bot.api_client.put.side_effect = side_effect - - try: - await self.cog.on_member_join(member) - except Exception: - raise - finally: - self.bot.api_client.put.assert_called_once_with( - f"bot/users/{member.id}", - json=data - ) - - return data - - async def test_sync_cog_on_member_join(self): - """Should PUT user's data or POST it if the user doesn't exist.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - self.bot.api_client.post.reset_mock() - data = await self.on_member_join_helper(side_effect) - - if side_effect: - self.bot.api_client.post.assert_called_once_with("bot/users", json=data) - else: - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_non_404(self): - """ResponseCodeError should be re-raised if status code isn't a 404.""" - with self.assertRaises(ResponseCodeError): - await self.on_member_join_helper(self.response_error(500)) - - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_join(member) - self.bot.api_client.post.assert_not_awaited() - self.bot.api_client.put.assert_not_awaited() - - -class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): - """Tests for the commands in the Sync cog.""" - - async def test_sync_roles_command(self): - """sync() should be called on the RoleSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_roles_command.callback(self.cog, ctx) - - self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_sync_users_command(self): - """sync() should be called on the UserSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_users_command.callback(self.cog, ctx) - - self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_commands_require_admin(self): - """The sync commands should only run if the author has the administrator permission.""" - cmds = ( - self.cog.sync_group, - self.cog.sync_roles_command, - self.cog.sync_users_command, - ) - - for cmd in cmds: - with self.subTest(cmd=cmd): - await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/backend/sync/test_roles.py b/tests/bot/cogs/backend/sync/test_roles.py deleted file mode 100644 index 99d682ede..000000000 --- a/tests/bot/cogs/backend/sync/test_roles.py +++ /dev/null @@ -1,157 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot.cogs.backend.sync._syncers import RoleSyncer, _Diff, _Role -from tests import helpers - - -def fake_role(**kwargs): - """Fixture to return a dictionary representing a role with default values set.""" - kwargs.setdefault("id", 9) - kwargs.setdefault("name", "fake role") - kwargs.setdefault("colour", 7) - kwargs.setdefault("permissions", 0) - kwargs.setdefault("position", 55) - - return kwargs - - -class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between roles in the DB and roles in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - @staticmethod - def get_guild(*roles): - """Fixture to return a guild object with the given roles.""" - guild = helpers.MockGuild() - guild.roles = [] - - for role in roles: - mock_role = helpers.MockRole(**role) - mock_role.colour = discord.Colour(role["colour"]) - mock_role.permissions = discord.Permissions(role["permissions"]) - guild.roles.append(mock_role) - - return guild - - async def test_empty_diff_for_identical_roles(self): - """No differences should be found if the roles in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_roles(self): - """Only updated roles should be added to the 'updated' set of the diff.""" - updated_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] - guild = self.get_guild(updated_role, fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_Role(**updated_role)}, set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_roles(self): - """Only new roles should be added to the 'created' set of the diff.""" - new_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role(), new_role) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new_role)}, set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_deleted_roles(self): - """Only deleted roles should be added to the 'deleted' set of the diff.""" - deleted_role = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [fake_role(), deleted_role] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), {_Role(**deleted_role)}) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_deleted_roles(self): - """When roles are added, updated, and removed, all of them are returned properly.""" - new = fake_role(id=41, name="new") - updated = fake_role(id=71, name="updated") - deleted = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [ - fake_role(), - fake_role(id=71, name="updated name"), - deleted, - ] - guild = self.get_guild(fake_role(), new, updated) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) - - self.assertEqual(actual_diff, expected_diff) - - -class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync roles.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - async def test_sync_created_roles(self): - """Only POST requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(role_tuples, set(), set()) - await self.syncer._sync(diff) - - calls = [mock.call("bot/roles", json=role) for role in roles] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(roles)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_roles(self): - """Only PUT requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), role_tuples, set()) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_deleted_roles(self): - """Only DELETE requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), set(), role_tuples) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] - self.bot.api_client.delete.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/backend/sync/test_users.py b/tests/bot/cogs/backend/sync/test_users.py deleted file mode 100644 index 51dcbe48a..000000000 --- a/tests/bot/cogs/backend/sync/test_users.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest -from unittest import mock - -from bot.cogs.backend.sync._syncers import UserSyncer, _Diff, _User -from tests import helpers - - -def fake_user(**kwargs): - """Fixture to return a dictionary representing a user with default values set.""" - kwargs.setdefault("id", 43) - kwargs.setdefault("name", "bob the test man") - kwargs.setdefault("discriminator", 1337) - kwargs.setdefault("roles", (666,)) - kwargs.setdefault("in_guild", True) - - return kwargs - - -class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between users in the DB and users in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - @staticmethod - def get_guild(*members): - """Fixture to return a guild object with the given members.""" - guild = helpers.MockGuild() - guild.members = [] - - for member in members: - member = member.copy() - del member["in_guild"] - - mock_member = helpers.MockMember(**member) - mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] - - guild.members.append(mock_member) - - return guild - - async def test_empty_diff_for_no_users(self): - """When no users are given, an empty diff should be returned.""" - guild = self.get_guild() - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_identical_users(self): - """No differences should be found if the users in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_users(self): - """Only updated users should be added to the 'updated' set of the diff.""" - updated_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] - guild = self.get_guild(updated_user, fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**updated_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_users(self): - """Only new users should be added to the 'created' set of the diff.""" - new_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user(), new_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_sets_in_guild_false_for_leaving_users(self): - """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_leaving_users(self): - """When users are added, updated, and removed, all of them are returned properly.""" - new_user = fake_user(id=99, name="new") - updated_user = fake_user(id=55, name="updated") - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] - guild = self.get_guild(fake_user(), new_user, updated_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_db_users_not_in_guild(self): - """When the DB knows a user the guild doesn't, no difference is found.""" - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - -class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync users.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - async def test_sync_created_users(self): - """Only POST requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(user_tuples, set(), None) - await self.syncer._sync(diff) - - calls = [mock.call("bot/users", json=user) for user in users] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(users)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_users(self): - """Only PUT requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(set(), user_tuples, None) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(users)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/backend/test_logging.py b/tests/bot/cogs/backend/test_logging.py deleted file mode 100644 index c867773e2..000000000 --- a/tests/bot/cogs/backend/test_logging.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot import constants -from bot.cogs.backend.logging import Logging -from tests.helpers import MockBot, MockTextChannel - - -class LoggingTests(unittest.IsolatedAsyncioTestCase): - """Test cases for connected login.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Logging(self.bot) - self.dev_log = MockTextChannel(id=1234, name="dev-log") - - @patch("bot.cogs.backend.logging.DEBUG_MODE", False) - async def test_debug_mode_false(self): - """Should send connected message to dev-log.""" - self.bot.get_channel.return_value = self.dev_log - - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) - self.dev_log.send.assert_awaited_once() - - @patch("bot.cogs.backend.logging.DEBUG_MODE", True) - async def test_debug_mode_true(self): - """Should not send anything to dev-log.""" - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_not_called() diff --git a/tests/bot/cogs/filters/__init__.py b/tests/bot/cogs/filters/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/filters/test_antimalware.py b/tests/bot/cogs/filters/test_antimalware.py deleted file mode 100644 index b00211f47..000000000 --- a/tests/bot/cogs/filters/test_antimalware.py +++ /dev/null @@ -1,165 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.cogs.filters import antimalware -from bot.constants import Channels, STAFF_ROLES -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): - """Test the AntiMalware cog.""" - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = MockBot() - self.bot.filter_list_cache = { - "FILE_FORMAT.True": { - ".first": {}, - ".second": {}, - ".third": {}, - } - } - self.cog = antimalware.AntiMalware(self.bot) - self.message = MockMessage() - self.whitelist = [".first", ".second", ".third"] - - async def test_message_with_allowed_attachment(self): - """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_message_without_attachment(self): - """Messages without attachments should result in no action.""" - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_direct_message_with_attachment(self): - """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.guild = None - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_message_with_illegal_extension_gets_deleted(self): - """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_called_once() - - async def test_message_send_by_staff(self): - """A message send by a member of staff should be ignored.""" - staff_role = MockRole(id=STAFF_ROLES[0]) - self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_python_file_redirect_embed_description(self): - """A message containing a .py file should result in an embed redirecting the user to our paste site""" - attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - - self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - - async def test_txt_file_redirect_embed_description(self): - """A message containing a .txt file should result in the correct embed.""" - attachment = MockAttachment(filename="python.txt") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.TXT_EMBED_DESCRIPTION = Mock() - antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - cmd_channel = self.bot.get_channel(Channels.bot_commands) - - self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) - antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - - async def test_other_disallowed_extension_embed_description(self): - """Test the description for a non .py/.txt disallowed extension.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - meta_channel = self.bot.get_channel(Channels.meta) - - self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( - joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", - meta_channel_mention=meta_channel.mention - ) - - async def test_removing_deleted_message_logs(self): - """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - self.message.delete.assert_called_once() - - async def test_message_with_illegal_attachment_logs(self): - """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - - async def test_get_disallowed_extensions(self): - """The return value should include all non-whitelisted extensions.""" - test_values = ( - ([], []), - (self.whitelist, []), - ([".first"], []), - ([".first", ".disallowed"], [".disallowed"]), - ([".disallowed"], [".disallowed"]), - ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), - ) - - for extensions, expected_disallowed_extensions in test_values: - with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog._get_disallowed_extensions(self.message) - self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.TestCase): - """Tests setup of the `AntiMalware` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_antispam.py b/tests/bot/cogs/filters/test_antispam.py deleted file mode 100644 index 8a3d8d02e..000000000 --- a/tests/bot/cogs/filters/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.cogs.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): - """Tests validation of the antispam cog configuration.""" - - def test_default_antispam_config_is_valid(self): - """The default antispam configuration is valid.""" - validation_errors = antispam.validate_config() - self.assertEqual(validation_errors, {}) - - def test_unknown_rule_returns_error(self): - """Configuring an unknown rule returns an error.""" - self.assertEqual( - antispam.validate_config({'invalid-rule': {}}), - {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} - ) - - def test_missing_keys_returns_error(self): - """Not configuring required keys returns an error.""" - keys = (('interval', 'max'), ('max', 'interval')) - for configured_key, unconfigured_key in keys: - with self.subTest( - configured_key=configured_key, - unconfigured_key=unconfigured_key - ): - config = {'burst': {configured_key: 10}} - error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - - self.assertEqual( - antispam.validate_config(config), - {'burst': error} - ) diff --git a/tests/bot/cogs/filters/test_security.py b/tests/bot/cogs/filters/test_security.py deleted file mode 100644 index 82679f69c..000000000 --- a/tests/bot/cogs/filters/test_security.py +++ /dev/null @@ -1,54 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from discord.ext.commands import NoPrivateMessage - -from bot.cogs.filters import security -from tests.helpers import MockBot, MockContext - - -class SecurityCogTests(unittest.TestCase): - """Tests the `Security` cog.""" - - def setUp(self): - """Attach an instance of the cog to the class for tests.""" - self.bot = MockBot() - self.cog = security.Security(self.bot) - self.ctx = MockContext() - - def test_check_additions(self): - """The cog should add its checks after initialization.""" - self.bot.check.assert_any_call(self.cog.check_on_guild) - self.bot.check.assert_any_call(self.cog.check_not_bot) - - def test_check_not_bot_returns_false_for_humans(self): - """The bot check should return `True` when invoked with human authors.""" - self.ctx.author.bot = False - self.assertTrue(self.cog.check_not_bot(self.ctx)) - - def test_check_not_bot_returns_true_for_robots(self): - """The bot check should return `False` when invoked with robotic authors.""" - self.ctx.author.bot = True - self.assertFalse(self.cog.check_not_bot(self.ctx)) - - def test_check_on_guild_raises_when_outside_of_guild(self): - """When invoked outside of a guild, `check_on_guild` should cause an error.""" - self.ctx.guild = None - - with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): - self.cog.check_on_guild(self.ctx) - - def test_check_on_guild_returns_true_inside_of_guild(self): - """When invoked inside of a guild, `check_on_guild` should return `True`.""" - self.ctx.guild = "lemon's lemonade stand" - self.assertTrue(self.cog.check_on_guild(self.ctx)) - - -class SecurityCogLoadTests(unittest.TestCase): - """Tests loading the `Security` cog.""" - - def test_security_cog_load(self): - """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_token_remover.py b/tests/bot/cogs/filters/test_token_remover.py deleted file mode 100644 index 55b284ef9..000000000 --- a/tests/bot/cogs/filters/test_token_remover.py +++ /dev/null @@ -1,310 +0,0 @@ -import unittest -from re import Match -from unittest import mock -from unittest.mock import MagicMock - -from discord import Colour, NotFound - -from bot import constants -from bot.cogs.filters import token_remover -from bot.cogs.filters.token_remover import Token, TokenRemover -from bot.cogs.moderation.modlog import ModLog -from tests.helpers import MockBot, MockMessage, autospec - - -class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): - """Tests the `TokenRemover` cog.""" - - def setUp(self): - """Adds the cog, a bot, and a message to the instance for usage in tests.""" - self.bot = MockBot() - self.cog = TokenRemover(bot=self.bot) - - self.msg = MockMessage(id=555, content="hello world") - self.msg.channel.mention = "#lemonade-stand" - self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) - self.msg.author.avatar_url_as.return_value = "picture-lemon.png" - - def test_is_valid_user_id_valid(self): - """Should consider user IDs valid if they decode entirely to ASCII digits.""" - ids = ( - "NDcyMjY1OTQzMDYyNDEzMzMy", - "NDc1MDczNjI5Mzk5NTQ3OTA0", - "NDY3MjIzMjMwNjUwNzc3NjQx", - ) - - for user_id in ids: - with self.subTest(user_id=user_id): - result = TokenRemover.is_valid_user_id(user_id) - self.assertTrue(result) - - def test_is_valid_user_id_invalid(self): - """Should consider non-digit and non-ASCII IDs invalid.""" - ids = ( - ("SGVsbG8gd29ybGQ", "non-digit ASCII"), - ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), - ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), - ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), - ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for user_id, msg in ids: - with self.subTest(msg=msg): - result = TokenRemover.is_valid_user_id(user_id) - self.assertFalse(result) - - def test_is_valid_timestamp_valid(self): - """Should consider timestamps valid if they're greater than the Discord epoch.""" - timestamps = ( - "XsyRkw", - "Xrim9Q", - "XsyR-w", - "XsySD_", - "Dn9r_A", - ) - - for timestamp in timestamps: - with self.subTest(timestamp=timestamp): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertTrue(result) - - def test_is_valid_timestamp_invalid(self): - """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" - timestamps = ( - ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), - ("ew", "123"), - ("AoIKgA", "42076800"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for timestamp, msg in timestamps: - with self.subTest(msg=msg): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertFalse(result) - - def test_mod_log_property(self): - """The `mod_log` property should ask the bot to return the `ModLog` cog.""" - self.bot.get_cog.return_value = 'lemon' - self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) - self.bot.get_cog.assert_called_once_with('ModLog') - - async def test_on_message_edit_uses_on_message(self): - """The edit listener should delegate handling of the message to the normal listener.""" - self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) - - await self.cog.on_message_edit(MockMessage(), self.msg) - self.cog.on_message.assert_awaited_once_with(self.msg) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_takes_action(self, find_token_in_message, take_action): - """Should take action if a valid token is found when a message is sent.""" - cog = TokenRemover(self.bot) - found_token = "foobar" - find_token_in_message.return_value = found_token - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_awaited_once_with(cog, self.msg, found_token) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): - """Shouldn't take action if a valid token isn't found when a message is sent.""" - cog = TokenRemover(self.bot) - find_token_in_message.return_value = False - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_not_awaited() - - @autospec(TokenRemover, "find_token_in_message") - async def test_on_message_ignores_dms_bots(self, find_token_in_message): - """Shouldn't parse a message if it is a DM or authored by a bot.""" - cog = TokenRemover(self.bot) - dm_msg = MockMessage(guild=None) - bot_msg = MockMessage(author=MagicMock(bot=True)) - - for msg in (dm_msg, bot_msg): - await cog.on_message(msg) - find_token_in_message.assert_not_called() - - @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") - def test_find_token_no_matches(self, token_re): - """None should be returned if the regex matches no tokens in a message.""" - token_re.finditer.return_value = () - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - @autospec("bot.cogs.filters.token_remover", "Token") - @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") - def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): - """The first match with a valid user ID and timestamp should be returned as a `Token`.""" - matches = [ - mock.create_autospec(Match, spec_set=True, instance=True), - mock.create_autospec(Match, spec_set=True, instance=True), - ] - tokens = [ - mock.create_autospec(Token, spec_set=True, instance=True), - mock.create_autospec(Token, spec_set=True, instance=True), - ] - - token_re.finditer.return_value = matches - token_cls.side_effect = tokens - is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. - is_valid_timestamp.return_value = True - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertEqual(tokens[1], return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - @autospec("bot.cogs.filters.token_remover", "Token") - @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") - def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): - """None should be returned if no matches have valid user IDs or timestamps.""" - token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] - token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) - is_valid_id.return_value = False - is_valid_timestamp.return_value = False - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - def test_regex_invalid_tokens(self): - """Messages without anything looking like a token are not matched.""" - tokens = ( - "", - "lemon wins", - "..", - "x.y", - "x.y.", - ".y.z", - ".y.", - "..z", - "x..z", - " . . ", - "\n.\n.\n", - "hellö.world.bye", - "base64.nötbåse64.morebase64", - "19jd3J.dfkm3d.€víł§tüff", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.findall(token) - self.assertEqual(len(results), 0) - - def test_regex_valid_tokens(self): - """Messages that look like tokens should be matched.""" - # Don't worry, these tokens have been invalidated. - tokens = ( - "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", - "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", - "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", - "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.fullmatch(token) - self.assertIsNotNone(results, f"{token} was not matched by the regex") - - def test_regex_matches_multiple_valid(self): - """Should support multiple matches in the middle of a string.""" - token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" - token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" - message = f"garbage {token_1} hello {token_2} world" - - results = token_remover.TOKEN_RE.finditer(message) - results = [match[0] for match in results] - self.assertCountEqual((token_1, token_2), results) - - @autospec("bot.cogs.filters.token_remover", "LOG_MESSAGE") - def test_format_log_message(self, log_message): - """Should correctly format the log message with info from the message and token.""" - token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - log_message.format.return_value = "Howdy" - - return_value = TokenRemover.format_log_message(self.msg, token) - - self.assertEqual(return_value, log_message.format.return_value) - log_message.format.assert_called_once_with( - author=self.msg.author, - author_id=self.msg.author.id, - channel=self.msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac="x" * len(token.hmac), - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - @autospec("bot.cogs.filters.token_remover", "log") - @autospec(TokenRemover, "format_log_message") - async def test_take_action(self, format_log_message, logger, mod_log_property): - """Should delete the message and send a mod log.""" - cog = TokenRemover(self.bot) - mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) - token = mock.create_autospec(Token, spec_set=True, instance=True) - log_msg = "testing123" - - mod_log_property.return_value = mod_log - format_log_message.return_value = log_msg - - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) - ) - - format_log_message.assert_called_once_with(self.msg, token) - logger.debug.assert_called_with(log_msg) - self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - - mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) - mod_log.send_log_message.assert_called_once_with( - icon_url=constants.Icons.token_removed, - colour=Colour(constants.Colours.soft_red), - title="Token removed!", - text=log_msg, - thumbnail=self.msg.author.avatar_url_as.return_value, - channel_id=constants.Channels.mod_alerts - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - async def test_take_action_delete_failure(self, mod_log_property): - """Shouldn't send any messages if the token message can't be deleted.""" - cog = TokenRemover(self.bot) - mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) - self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) - - token = mock.create_autospec(Token, spec_set=True, instance=True) - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_not_awaited() - - -class TokenRemoverExtensionTests(unittest.TestCase): - """Tests for the token_remover extension.""" - - @autospec("bot.cogs.filters.token_remover", "TokenRemover") - def test_extension_setup(self, cog): - """The TokenRemover cog should be added.""" - bot = MockBot() - token_remover.setup(bot) - - cog.assert_called_once_with(bot) - bot.add_cog.assert_called_once() - self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/cogs/info/__init__.py b/tests/bot/cogs/info/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/info/test_information.py b/tests/bot/cogs/info/test_information.py deleted file mode 100644 index 895a8328e..000000000 --- a/tests/bot/cogs/info/test_information.py +++ /dev/null @@ -1,584 +0,0 @@ -import asyncio -import textwrap -import unittest -import unittest.mock - -import discord - -from bot import constants -from bot.cogs.info import information -from bot.utils.checks import InWhitelistCheckFailure -from tests import helpers - -COG_PATH = "bot.cogs.info.information.Information" - - -class InformationCogTests(unittest.TestCase): - """Tests the Information cog.""" - - @classmethod - def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = helpers.MockBot() - - self.cog = information.Information(self.bot) - - self.ctx = helpers.MockContext() - self.ctx.author.roles.append(self.moderator_role) - - def test_roles_command_command(self): - """Test if the `role_info` command correctly returns the `moderator_role`.""" - self.ctx.guild.roles.append(self.moderator_role) - - self.cog.roles_info.can_run = unittest.mock.AsyncMock() - self.cog.roles_info.can_run.return_value = True - - coroutine = self.cog.roles_info.callback(self.cog, self.ctx) - - self.assertIsNone(asyncio.run(coroutine)) - self.ctx.send.assert_called_once() - - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - - self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - - def test_role_info_command(self): - """Tests the `role info` command.""" - dummy_role = helpers.MockRole( - name="Dummy", - id=112233445566778899, - colour=discord.Colour.blurple(), - position=10, - members=[self.ctx.author], - permissions=discord.Permissions(0) - ) - - admin_role = helpers.MockRole( - name="Admins", - id=998877665544332211, - colour=discord.Colour.red(), - position=3, - members=[self.ctx.author], - permissions=discord.Permissions(0), - ) - - self.ctx.guild.roles.append([dummy_role, admin_role]) - - self.cog.role_info.can_run = unittest.mock.AsyncMock() - self.cog.role_info.can_run.return_value = True - - coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) - - self.assertIsNone(asyncio.run(coroutine)) - - self.assertEqual(self.ctx.send.call_count, 2) - - (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list - - dummy_embed = dummy_kwargs["embed"] - admin_embed = admin_kwargs["embed"] - - self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) - - self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) - self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") - self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") - self.assertEqual(dummy_embed.fields[3].value, "1") - self.assertEqual(dummy_embed.fields[4].value, "10") - self.assertEqual(dummy_embed.fields[5].value, "0") - - self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, discord.Colour.red()) - - @unittest.mock.patch('bot.cogs.info.information.time_since') - def test_server_info_command(self, time_since_patch): - time_since_patch.return_value = '2 days ago' - - self.ctx.guild = helpers.MockGuild( - features=('lemons', 'apples'), - region="The Moon", - roles=[self.moderator_role], - channels=[ - discord.TextChannel( - state={}, - guild=self.ctx.guild, - data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} - ), - discord.CategoryChannel( - state={}, - guild=self.ctx.guild, - data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} - ), - discord.VoiceChannel( - state={}, - guild=self.ctx.guild, - data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} - ) - ], - members=[ - *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), - *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), - *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), - *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), - ], - member_count=1_234, - icon_url='a-lemon.jpg', - ) - - coroutine = self.cog.server_info.callback(self.cog, self.ctx) - self.assertIsNone(asyncio.run(coroutine)) - - time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual( - embed.description, - textwrap.dedent( - f""" - **Server information** - Created: {time_since_patch.return_value} - Voice region: {self.ctx.guild.region} - Features: {', '.join(self.ctx.guild.features)} - - **Channel counts** - Category channels: 1 - Text channels: 1 - Voice channels: 1 - Staff channels: 0 - - **Member counts** - Members: {self.ctx.guild.member_count:,} - Staff members: 0 - Roles: {len(self.ctx.guild.roles)} - - **Member statuses** - {constants.Emojis.status_online} 2 - {constants.Emojis.status_idle} 1 - {constants.Emojis.status_dnd} 4 - {constants.Emojis.status_offline} 3 - """ - ) - ) - self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') - - -class UserInfractionHelperMethodTests(unittest.TestCase): - """Tests for the helper methods of the `!user` command.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - self.member = helpers.MockMember(id=1234) - - def test_user_command_helper_method_get_requests(self): - """The helper methods should form the correct get requests.""" - test_values = ( - { - "helper_method": self.cog.basic_user_infraction_counts, - "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.expanded_user_infraction_counts, - "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.user_nomination_counts, - "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), - }, - ) - - for test_value in test_values: - helper_method = test_value["helper_method"] - endpoint, params = test_value["expected_args"] - - with self.subTest(method=helper_method, endpoint=endpoint, params=params): - asyncio.run(helper_method(self.member)) - self.bot.api_client.get.assert_called_once_with(endpoint, params=params) - self.bot.api_client.get.reset_mock() - - def _method_subtests(self, method, test_values, default_header): - """Helper method that runs the subtests for the different helper methods.""" - for test_value in test_values: - api_response = test_value["api response"] - expected_lines = test_value["expected_lines"] - - with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): - self.bot.api_client.get.return_value = api_response - - expected_output = "\n".join(default_header + expected_lines) - actual_output = asyncio.run(method(self.member)) - - self.assertEqual(expected_output, actual_output) - - def test_basic_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list both the total and active number of non-hidden infractions.""" - test_values = ( - # No infractions means zero counts - { - "api response": [], - "expected_lines": ["Total: 0", "Active: 0"], - }, - # Simple, single-infraction dictionaries - { - "api response": [{"type": "ban", "active": True}], - "expected_lines": ["Total: 1", "Active: 1"], - }, - { - "api response": [{"type": "ban", "active": False}], - "expected_lines": ["Total: 1", "Active: 0"], - }, - # Multiple infractions with various `active` status - { - "api response": [ - {"type": "ban", "active": True}, - {"type": "kick", "active": False}, - {"type": "ban", "active": True}, - {"type": "ban", "active": False}, - ], - "expected_lines": ["Total: 4", "Active: 2"], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) - - def test_expanded_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list the total and active number of all infractions split by infraction type.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never received an infraction."], - }, - # Shows non-hidden inactive infraction as expected - { - "api response": [{"type": "kick", "active": False, "hidden": False}], - "expected_lines": ["Kicks: 1"], - }, - # Shows non-hidden active infraction as expected - { - "api response": [{"type": "mute", "active": True, "hidden": False}], - "expected_lines": ["Mutes: 1 (1 active)"], - }, - # Shows hidden inactive infraction as expected - { - "api response": [{"type": "superstar", "active": False, "hidden": True}], - "expected_lines": ["Superstars: 1"], - }, - # Shows hidden active infraction as expected - { - "api response": [{"type": "ban", "active": True, "hidden": True}], - "expected_lines": ["Bans: 1 (1 active)"], - }, - # Correctly displays tally of multiple infractions of mixed properties in alphabetical order - { - "api response": [ - {"type": "kick", "active": False, "hidden": True}, - {"type": "ban", "active": True, "hidden": True}, - {"type": "superstar", "active": True, "hidden": True}, - {"type": "mute", "active": True, "hidden": True}, - {"type": "ban", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - {"type": "note", "active": False, "hidden": True}, - {"type": "warn", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - ], - "expected_lines": [ - "Bans: 2 (1 active)", - "Kicks: 1", - "Mutes: 1 (1 active)", - "Notes: 3", - "Superstars: 1 (1 active)", - "Warns: 1", - ], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) - - def test_user_nomination_counts_returns_correct_strings(self): - """The method should list the number of active and historical nominations for the user.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never been nominated."], - }, - { - "api response": [{'active': True}], - "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], - }, - { - "api response": [{'active': True}, {'active': False}], - "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], - }, - { - "api response": [{'active': False}], - "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], - }, - { - "api response": [{'active': False}, {'active': False}], - "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], - }, - - ) - - header = ["**Nominations**"] - - self._method_subtests(self.cog.user_nomination_counts, test_values, header) - - -@unittest.mock.patch("bot.cogs.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) -@unittest.mock.patch("bot.cogs.info.information.constants.MODERATION_CHANNELS", new=[50]) -class UserEmbedTests(unittest.TestCase): - """Tests for the creation of the `!user` embed.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): - """The embed should use the string representation of the user if they don't have a nick.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = None - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Mr. Hemlock") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_nick_in_title_if_available(self): - """The embed should use the nick if it's available.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = "Cat lover" - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_ignores_everyone_role(self): - """Created `!user` embeds should not contain mention of the @everyone-role.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - admins_role = helpers.MockRole(name='Admins') - admins_role.colour = 100 - - # A `MockMember` has the @Everyone role by default; we add the Admins to that. - user = helpers.MockMember(roles=[admins_role], top_role=admins_role) - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertIn("&Admins", embed.description) - self.assertNotIn("&Everyone", embed.description) - - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): - """The embed should contain expanded infractions and nomination info in mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "expanded infractions info" - nomination_counts.return_value = "nomination info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - nomination_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - expanded infractions info - - nomination info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): - """The embed should contain only basic infraction data outside of mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "basic infractions info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - basic infractions info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): - """The embed should be created with the colour of the top role, if a top role is available.""" - ctx = helpers.MockContext() - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): - """The embed should be created with a blurple colour if the user has no assigned roles.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour.blurple()) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): - """The embed thumbnail should be set to the user's avatar in `png` format.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - user.avatar_url_as.return_value = "avatar url" - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - user.avatar_url_as.assert_called_once_with(static_format="png") - self.assertEqual(embed.thumbnail.url, "avatar url") - - -@unittest.mock.patch("bot.cogs.info.information.constants") -class UserCommandTests(unittest.TestCase): - """Tests for the `!user` command.""" - - def setUp(self): - """Set up steps executed before each test is run.""" - self.bot = helpers.MockBot() - self.cog = information.Information(self.bot) - - self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) - self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) - self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) - - self.author = helpers.MockMember(id=1, name="syntaxaire") - self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) - self.target = helpers.MockMember(id=3, name="__fluzz__") - - def test_regular_member_cannot_target_another_member(self, constants): - """A regular user should not be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.author) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") - - def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): - """A regular user should not be able to use this command outside of bot-commands.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) - - msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InWhitelistCheckFailure, msg=msg): - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): - """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): - """A user should target itself with `!user` when a `user` argument was not provided.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): - """Staff members should be able to bypass the bot-commands channel restriction.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.moderator) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_moderators_can_target_another_member(self, create_embed, constants): - """A moderator should be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - create_embed.assert_called_once_with(ctx, self.target) - ctx.send.assert_called_once() diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/cogs/moderation/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/moderation/infraction/__init__.py b/tests/bot/cogs/moderation/infraction/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/moderation/infraction/test_infractions.py b/tests/bot/cogs/moderation/infraction/test_infractions.py deleted file mode 100644 index 2df61d431..000000000 --- a/tests/bot/cogs/moderation/infraction/test_infractions.py +++ /dev/null @@ -1,55 +0,0 @@ -import textwrap -import unittest -from unittest.mock import AsyncMock, Mock, patch - -from bot.cogs.moderation.infraction.infractions import Infractions -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole - - -class TruncationTests(unittest.IsolatedAsyncioTestCase): - """Tests for ban and kick command reason truncation.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Infractions(self.bot) - self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) - self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) - self.guild = MockGuild(id=4567) - self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - - @patch("bot.cogs.moderation.infraction._utils.get_active_infraction") - @patch("bot.cogs.moderation.infraction._utils.post_infraction") - async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): - """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = None - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.bot.get_cog.return_value = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.ctx.guild.ban = Mock() - - await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) - self.ctx.guild.ban.assert_called_once_with( - self.target, - reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), - delete_message_days=0 - ) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value - ) - - @patch("bot.cogs.moderation.infraction._utils.post_infraction") - async def test_apply_kick_reason_truncation(self, post_infraction_mock): - """Should truncate reason for `Member.kick`.""" - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.target.kick = Mock() - - await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) - self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value - ) diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py deleted file mode 100644 index 5e4d90251..000000000 --- a/tests/bot/cogs/moderation/test_incidents.py +++ /dev/null @@ -1,770 +0,0 @@ -import asyncio -import enum -import logging -import typing as t -import unittest -from unittest.mock import AsyncMock, MagicMock, call, patch - -import aiohttp -import discord - -from bot.cogs.moderation import incidents -from bot.constants import Colours -from tests.helpers import ( - MockAsyncWebhook, - MockAttachment, - MockBot, - MockMember, - MockMessage, - MockReaction, - MockRole, - MockTextChannel, - MockUser, -) - - -class MockAsyncIterable: - """ - Helper for mocking asynchronous for loops. - - It does not appear that the `unittest` library currently provides anything that would - allow us to simply mock an async iterator, such as `discord.TextChannel.history`. - - We therefore write our own helper to wrap a regular synchronous iterable, and feed - its values via `__anext__` rather than `__next__`. - - This class was written for the purposes of testing the `Incidents` cog - it may not - be generic enough to be placed in the `tests.helpers` module. - """ - - def __init__(self, messages: t.Iterable): - """Take a sync iterable to be wrapped.""" - self.iter_messages = iter(messages) - - def __aiter__(self): - """Return `self` as we provide the `__anext__` method.""" - return self - - async def __anext__(self): - """ - Feed the next item, or raise `StopAsyncIteration`. - - Since we're wrapping a sync iterator, it will communicate that it has been depleted - by raising a `StopIteration`. The `async for` construct does not expect it, and we - therefore need to substitute it for the appropriate exception type. - """ - try: - return next(self.iter_messages) - except StopIteration: - raise StopAsyncIteration - - -class MockSignal(enum.Enum): - A = "A" - B = "B" - - -mock_404 = discord.NotFound( - response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response - message="Not found", -) - - -class TestDownloadFile(unittest.IsolatedAsyncioTestCase): - """Collection of tests for the `download_file` helper function.""" - - async def test_download_file_success(self): - """If `to_file` succeeds, function returns the acquired `discord.File`.""" - file = MagicMock(discord.File, filename="bigbadlemon.jpg") - attachment = MockAttachment(to_file=AsyncMock(return_value=file)) - - acquired_file = await incidents.download_file(attachment) - self.assertIs(file, acquired_file) - - async def test_download_file_404(self): - """If `to_file` encounters a 404, function handles the exception & returns None.""" - attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) - - acquired_file = await incidents.download_file(attachment) - self.assertIsNone(acquired_file) - - async def test_download_file_fail(self): - """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" - arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") - attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) - - with self.assertLogs(logger=incidents.log, level=logging.ERROR): - acquired_file = await incidents.download_file(attachment) - - self.assertIsNone(acquired_file) - - -class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): - """Collection of tests for the `make_embed` helper function.""" - - async def test_make_embed_actioned(self): - """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) - - self.assertEqual(embed.colour.value, Colours.soft_green) - self.assertIn("Actioned", embed.footer.text) - - async def test_make_embed_not_actioned(self): - """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) - - self.assertEqual(embed.colour.value, Colours.soft_red) - self.assertIn("Rejected", embed.footer.text) - - async def test_make_embed_content(self): - """Incident content appears as embed description.""" - incident = MockMessage(content="this is an incident") - embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - - self.assertEqual(incident.content, embed.description) - - async def test_make_embed_with_attachment_succeeds(self): - """Incident's attachment is downloaded and displayed in the embed's image field.""" - file = MagicMock(discord.File, filename="bigbadjoe.jpg") - attachment = MockAttachment(filename="bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) - - # Patch `download_file` to return our `file` - with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=file)): - embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - - self.assertIs(file, returned_file) - self.assertEqual("attachment://bigbadjoe.jpg", embed.image.url) - - async def test_make_embed_with_attachment_fails(self): - """Incident's attachment fails to download, proxy url is linked instead.""" - attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) - - # Patch `download_file` to return None as if the download failed - with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=None)): - embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - - self.assertIsNone(returned_file) - - # The author name field is simply expected to have something in it, we do not assert the message - self.assertGreater(len(embed.author.name), 0) - self.assertEqual(embed.author.url, "discord.com/bigbadjoe.jpg") # However, it should link the exact url - - -@patch("bot.constants.Channels.incidents", 123) -class TestIsIncident(unittest.TestCase): - """ - Collection of tests for the `is_incident` helper function. - - In `setUp`, we will create a mock message which should qualify as an incident. Each - test case will then mutate this instance to make it **not** qualify, in various ways. - - Notice that we patch the #incidents channel id globally for this class. - """ - - def setUp(self) -> None: - """Prepare a mock message which should qualify as an incident.""" - self.incident = MockMessage( - channel=MockTextChannel(id=123), - content="this is an incident", - author=MockUser(bot=False), - pinned=False, - ) - - def test_is_incident_true(self): - """Message qualifies as an incident if unchanged.""" - self.assertTrue(incidents.is_incident(self.incident)) - - def check_false(self): - """Assert that `self.incident` does **not** qualify as an incident.""" - self.assertFalse(incidents.is_incident(self.incident)) - - def test_is_incident_false_channel(self): - """Message doesn't qualify if sent outside of #incidents.""" - self.incident.channel = MockTextChannel(id=456) - self.check_false() - - def test_is_incident_false_content(self): - """Message doesn't qualify if content begins with hash symbol.""" - self.incident.content = "# this is a comment message" - self.check_false() - - def test_is_incident_false_author(self): - """Message doesn't qualify if author is a bot.""" - self.incident.author = MockUser(bot=True) - self.check_false() - - def test_is_incident_false_pinned(self): - """Message doesn't qualify if it is pinned.""" - self.incident.pinned = True - self.check_false() - - -class TestOwnReactions(unittest.TestCase): - """Assertions for the `own_reactions` function.""" - - def test_own_reactions(self): - """Only bot's own emoji are extracted from the input incident.""" - reactions = ( - MockReaction(emoji="A", me=True), - MockReaction(emoji="B", me=True), - MockReaction(emoji="C", me=False), - ) - message = MockMessage(reactions=reactions) - self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) - - -@patch("bot.cogs.moderation.incidents.ALL_SIGNALS", {"A", "B"}) -class TestHasSignals(unittest.TestCase): - """ - Assertions for the `has_signals` function. - - We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` - as appropriate. - """ - - def test_has_signals_true(self): - """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" - message = MockMessage() - own_reactions = MagicMock(return_value={"A", "B"}) - - with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): - self.assertTrue(incidents.has_signals(message)) - - def test_has_signals_false(self): - """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" - message = MockMessage() - own_reactions = MagicMock(return_value={"A", "C"}) - - with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): - self.assertFalse(incidents.has_signals(message)) - - -@patch("bot.cogs.moderation.incidents.Signal", MockSignal) -class TestAddSignals(unittest.IsolatedAsyncioTestCase): - """ - Assertions for the `add_signals` coroutine. - - These are all fairly similar and could go into a single test function, but I found the - patching & sub-testing fairly awkward in that case and decided to split them up - to avoid unnecessary syntax noise. - """ - - def setUp(self): - """Prepare a mock incident message for tests to use.""" - self.incident = MockMessage() - - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value=set())) - async def test_add_signals_missing(self): - """All emoji are added when none are present.""" - await incidents.add_signals(self.incident) - self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) - - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) - async def test_add_signals_partial(self): - """Only missing emoji are added when some are present.""" - await incidents.add_signals(self.incident) - self.incident.add_reaction.assert_has_calls([call("B")]) - - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) - async def test_add_signals_present(self): - """No emoji are added when all are present.""" - await incidents.add_signals(self.incident) - self.incident.add_reaction.assert_not_called() - - -class TestIncidents(unittest.IsolatedAsyncioTestCase): - """ - Tests for bound methods of the `Incidents` cog. - - Use this as a base class for `Incidents` tests - it will prepare a fresh instance - for each test function, but not make any assertions on its own. Tests can mutate - the instance as they wish. - """ - - def setUp(self): - """ - Prepare a fresh `Incidents` instance for each test. - - Note that this will not schedule `crawl_incidents` in the background, as everything - is being mocked. The `crawl_task` attribute will end up being None. - """ - self.cog_instance = incidents.Incidents(MockBot()) - - -@patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test -class TestCrawlIncidents(TestIncidents): - """ - Tests for the `Incidents.crawl_incidents` coroutine. - - Apart from `test_crawl_incidents_waits_until_cache_ready`, all tests in this class - will patch the return values of `is_incident` and `has_signal` and then observe - whether the `AsyncMock` for `add_signals` was awaited or not. - - The `add_signals` mock is added by each test separately to ensure it is clean (has not - been awaited by another test yet). The mock can be reset, but this appears to be the - cleaner way. - - For each test, we inject a mock channel with a history of 1 message only (see: `setUp`). - """ - - def setUp(self): - """For each test, ensure `bot.get_channel` returns a channel with 1 arbitrary message.""" - super().setUp() # First ensure we get `cog_instance` from parent - - incidents_history = MagicMock(return_value=MockAsyncIterable([MockMessage()])) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(history=incidents_history)) - - async def test_crawl_incidents_waits_until_cache_ready(self): - """ - The coroutine will await the `wait_until_guild_available` event. - - Since this task is schedule in the `__init__`, it is critical that it waits for the - cache to be ready, so that it can safely get the #incidents channel. - """ - await self.cog_instance.crawl_incidents() - self.cog_instance.bot.wait_until_guild_available.assert_awaited() - - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) - async def test_crawl_incidents_noop_if_is_not_incident(self): - """Signals are not added for a non-incident message.""" - await self.cog_instance.crawl_incidents() - incidents.add_signals.assert_not_awaited() - - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals - async def test_crawl_incidents_noop_if_message_already_has_signals(self): - """Signals are not added for messages which already have them.""" - await self.cog_instance.crawl_incidents() - incidents.add_signals.assert_not_awaited() - - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals - async def test_crawl_incidents_add_signals_called(self): - """Message has signals added as it does not have them yet and qualifies as an incident.""" - await self.cog_instance.crawl_incidents() - incidents.add_signals.assert_awaited_once() - - -class TestArchive(TestIncidents): - """Tests for the `Incidents.archive` coroutine.""" - - async def test_archive_webhook_not_found(self): - """ - Method recovers and returns False when the webhook is not found. - - Implicitly, this also tests that the error is handled internally and doesn't - propagate out of the method, which is just as important. - """ - self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) - self.assertFalse( - await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) - ) - - async def test_archive_relays_incident(self): - """ - If webhook is found, method relays `incident` properly. - - This test will assert that the fetched webhook's `send` method is fed the correct arguments, - and that the `archive` method returns True. - """ - webhook = MockAsyncWebhook() - self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook - - # Define our own `incident` to be archived - incident = MockMessage( - content="this is an incident", - author=MockUser(name="author_name", avatar_url="author_avatar"), - id=123, - ) - built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this - - with patch("bot.cogs.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): - archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) - - # Now we check that the webhook was given the correct args, and that `archive` returned True - webhook.send.assert_called_once_with( - embed=built_embed, - username="author_name", - avatar_url="author_avatar", - file=None, - ) - self.assertTrue(archive_return) - - async def test_archive_clyde_username(self): - """ - The archive webhook username is cleansed using `sub_clyde`. - - Discord will reject any webhook with "clyde" in the username field, as it impersonates - the official Clyde bot. Since we do not control what the username will be (the incident - author name is used), we must ensure the name is cleansed, otherwise the relay may fail. - - This test assumes the username is passed as a kwarg. If this test fails, please review - whether the passed argument is being retrieved correctly. - """ - webhook = MockAsyncWebhook() - self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) - await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) - - self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) - - -class TestMakeConfirmationTask(TestIncidents): - """ - Tests for the `Incidents.make_confirmation_task` method. - - Writing tests for this method is difficult, as it mostly just delegates the provided - information elsewhere. There is very little internal logic. Whether our approach - works conceptually is difficult to prove using unit tests. - """ - - def test_make_confirmation_task_check(self): - """ - The internal check will recognize the passed incident. - - This is a little tricky - we first pass a message with a specific `id` in, and then - retrieve the built check from the `call_args` of the `wait_for` method. This relies - on the check being passed as a kwarg. - - Once the check is retrieved, we assert that it gives True for our incident's `id`, - and False for any other. - - If this function begins to fail, first check that `created_check` is being retrieved - correctly. It should be the function that is built locally in the tested method. - """ - self.cog_instance.make_confirmation_task(MockMessage(id=123)) - - self.cog_instance.bot.wait_for.assert_called_once() - created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"] - - # The `message_id` matches the `id` of our incident - self.assertTrue(created_check(payload=MagicMock(message_id=123))) - - # This `message_id` does not match - self.assertFalse(created_check(payload=MagicMock(message_id=0))) - - -@patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2}) -@patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable -class TestProcessEvent(TestIncidents): - """Tests for the `Incidents.process_event` coroutine.""" - - async def test_process_event_bad_role(self): - """The reaction is removed when the author lacks all allowed roles.""" - incident = MockMessage() - member = MockMember(roles=[MockRole(id=0)]) # Must have role 1 or 2 - - await self.cog_instance.process_event("reaction", incident, member) - incident.remove_reaction.assert_called_once_with("reaction", member) - - async def test_process_event_bad_emoji(self): - """ - The reaction is removed when an invalid emoji is used. - - This requires that we pass in a `member` with valid roles, as we need the role check - to succeed. - """ - incident = MockMessage() - member = MockMember(roles=[MockRole(id=1)]) # Member has allowed role - - await self.cog_instance.process_event("invalid_signal", incident, member) - incident.remove_reaction.assert_called_once_with("invalid_signal", member) - - async def test_process_event_no_archive_on_investigating(self): - """Message is not archived on `Signal.INVESTIGATING`.""" - with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: - await self.cog_instance.process_event( - reaction=incidents.Signal.INVESTIGATING.value, - incident=MockMessage(), - member=MockMember(roles=[MockRole(id=1)]), - ) - - mocked_archive.assert_not_called() - - async def test_process_event_no_delete_if_archive_fails(self): - """ - Original message is not deleted when `Incidents.archive` returns False. - - This is the way of signaling that the relay failed, and we should not remove the original, - as that would result in losing the incident record. - """ - incident = MockMessage() - - with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): - await self.cog_instance.process_event( - reaction=incidents.Signal.ACTIONED.value, - incident=incident, - member=MockMember(roles=[MockRole(id=1)]) - ) - - incident.delete.assert_not_called() - - async def test_process_event_confirmation_task_is_awaited(self): - """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" - mock_task = AsyncMock() - - with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): - await self.cog_instance.process_event( - reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(), - member=MockMember(roles=[MockRole(id=1)]) - ) - - mock_task.assert_awaited() - - async def test_process_event_confirmation_task_timeout_is_handled(self): - """ - Confirmation task `asyncio.TimeoutError` is handled gracefully. - - We have `make_confirmation_task` return a mock with a side effect, and then catch the - exception should it propagate out of `process_event`. This is so that we can then manually - fail the test with a more informative message than just the plain traceback. - """ - mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) - - try: - with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): - await self.cog_instance.process_event( - reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(), - member=MockMember(roles=[MockRole(id=1)]) - ) - except asyncio.TimeoutError: - self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!") - - -class TestResolveMessage(TestIncidents): - """Tests for the `Incidents.resolve_message` coroutine.""" - - async def test_resolve_message_pass_message_id(self): - """Method will call `_get_message` with the passed `message_id`.""" - await self.cog_instance.resolve_message(123) - self.cog_instance.bot._connection._get_message.assert_called_once_with(123) - - async def test_resolve_message_in_cache(self): - """ - No API call is made if the queried message exists in the cache. - - We mock the `_get_message` return value regardless of input. Whether it finds the message - internally is considered d.py's responsibility, not ours. - """ - cached_message = MockMessage(id=123) - self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message) - - return_value = await self.cog_instance.resolve_message(123) - - self.assertIs(return_value, cached_message) - self.cog_instance.bot.get_channel.assert_not_called() # The `fetch_message` line was never hit - - async def test_resolve_message_not_in_cache(self): - """ - The message is retrieved from the API if it isn't cached. - - This is desired behaviour for messages which exist, but were sent before the bot's - current session. - """ - self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - - # API returns our message - uncached_message = MockMessage() - fetch_message = AsyncMock(return_value=uncached_message) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) - - retrieved_message = await self.cog_instance.resolve_message(123) - self.assertIs(retrieved_message, uncached_message) - - async def test_resolve_message_doesnt_exist(self): - """ - If the API returns a 404, the function handles it gracefully and returns None. - - This is an edge-case happening with racing events - event A will relay the message - to the archive and delete the original. Once event B acquires the `event_lock`, - it will not find the message in the cache, and will ask the API. - """ - self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - - fetch_message = AsyncMock(side_effect=mock_404) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) - - self.assertIsNone(await self.cog_instance.resolve_message(123)) - - async def test_resolve_message_fetch_fails(self): - """ - Non-404 errors are handled, logged & None is returned. - - In contrast with a 404, this should make an error-level log. We assert that at least - one such log was made - we do not make any assertions about the log's message. - """ - self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - - arbitrary_error = discord.HTTPException( - response=MagicMock(aiohttp.ClientResponse), - message="Arbitrary error", - ) - fetch_message = AsyncMock(side_effect=arbitrary_error) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) - - with self.assertLogs(logger=incidents.log, level=logging.ERROR): - self.assertIsNone(await self.cog_instance.resolve_message(123)) - - -@patch("bot.constants.Channels.incidents", 123) -class TestOnRawReactionAdd(TestIncidents): - """ - Tests for the `Incidents.on_raw_reaction_add` listener. - - Writing tests for this listener comes with additional complexity due to the listener - awaiting the `crawl_task` task. See `asyncSetUp` for further details, which attempts - to make unit testing this function possible. - """ - - def setUp(self): - """ - Prepare & assign `payload` attribute. - - This attribute represents an *ideal* payload which will not be rejected by the - listener. As each test will receive a fresh instance, it can be mutated to - observe how the listener's behaviour changes with different attributes on - the passed payload. - """ - super().setUp() # Ensure `cog_instance` is assigned - - self.payload = MagicMock( - discord.RawReactionActionEvent, - channel_id=123, # Patched at class level - message_id=456, - member=MockMember(bot=False), - emoji="reaction", - ) - - async def asyncSetUp(self): # noqa: N802 - """ - Prepare an empty task and assign it as `crawl_task`. - - It appears that the `unittest` framework does not provide anything for mocking - asyncio tasks. An `AsyncMock` instance can be called and then awaited, however, - it does not provide the `done` method or any other parts of the `asyncio.Task` - interface. - - Although we do not need to make any assertions about the task itself while - testing the listener, the code will still await it and call the `done` method, - and so we must inject something that will not fail on either action. - - Note that this is done in an `asyncSetUp`, which runs after `setUp`. - The justification is that creating an actual task requires the event - loop to be ready, which is not the case in the `setUp`. - """ - mock_task = asyncio.create_task(AsyncMock()()) # Mock async func, then a coro - self.cog_instance.crawl_task = mock_task - - async def test_on_raw_reaction_add_wrong_channel(self): - """ - Events outside of #incidents will be ignored. - - We check this by asserting that `resolve_message` was never queried. - """ - self.payload.channel_id = 0 - self.cog_instance.resolve_message = AsyncMock() - - await self.cog_instance.on_raw_reaction_add(self.payload) - self.cog_instance.resolve_message.assert_not_called() - - async def test_on_raw_reaction_add_user_is_bot(self): - """ - Events dispatched by bot accounts will be ignored. - - We check this by asserting that `resolve_message` was never queried. - """ - self.payload.member = MockMember(bot=True) - self.cog_instance.resolve_message = AsyncMock() - - await self.cog_instance.on_raw_reaction_add(self.payload) - self.cog_instance.resolve_message.assert_not_called() - - async def test_on_raw_reaction_add_message_doesnt_exist(self): - """ - Listener gracefully handles the case where `resolve_message` gives None. - - We check this by asserting that `process_event` was never called. - """ - self.cog_instance.process_event = AsyncMock() - self.cog_instance.resolve_message = AsyncMock(return_value=None) - - await self.cog_instance.on_raw_reaction_add(self.payload) - self.cog_instance.process_event.assert_not_called() - - async def test_on_raw_reaction_add_message_is_not_an_incident(self): - """ - The event won't be processed if the related message is not an incident. - - This is an edge-case that can happen if someone manually leaves a reaction - on a pinned message, or a comment. - - We check this by asserting that `process_event` was never called. - """ - self.cog_instance.process_event = AsyncMock() - self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) - - with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)): - await self.cog_instance.on_raw_reaction_add(self.payload) - - self.cog_instance.process_event.assert_not_called() - - async def test_on_raw_reaction_add_valid_event_is_processed(self): - """ - If the reaction event is valid, it is passed to `process_event`. - - This is the case when everything goes right: - * The reaction was placed in #incidents, and not by a bot - * The message was found successfully - * The message qualifies as an incident - - Additionally, we check that all arguments were passed as expected. - """ - incident = MockMessage(id=1) - - self.cog_instance.process_event = AsyncMock() - self.cog_instance.resolve_message = AsyncMock(return_value=incident) - - with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)): - await self.cog_instance.on_raw_reaction_add(self.payload) - - self.cog_instance.process_event.assert_called_with( - "reaction", # Defined in `self.payload` - incident, - self.payload.member, - ) - - -class TestOnMessage(TestIncidents): - """ - Tests for the `Incidents.on_message` listener. - - Notice the decorators mocking the `is_incident` return value. The `is_incidents` - function is tested in `TestIsIncident` - here we do not worry about it. - """ - - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) - async def test_on_message_incident(self): - """Messages qualifying as incidents are passed to `add_signals`.""" - incident = MockMessage() - - with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: - await self.cog_instance.on_message(incident) - - mock_add_signals.assert_called_once_with(incident) - - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) - async def test_on_message_non_incident(self): - """Messages not qualifying as incidents are ignored.""" - with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: - await self.cog_instance.on_message(MockMessage()) - - mock_add_signals.assert_not_called() diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py deleted file mode 100644 index f2809f40a..000000000 --- a/tests/bot/cogs/moderation/test_modlog.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - -import discord - -from bot.cogs.moderation.modlog import ModLog -from tests.helpers import MockBot, MockTextChannel - - -class ModLogTests(unittest.IsolatedAsyncioTestCase): - """Tests for moderation logs.""" - - def setUp(self): - self.bot = MockBot() - self.cog = ModLog(self.bot) - self.channel = MockTextChannel() - - async def test_log_entry_description_truncation(self): - """Test that embed description for ModLog entry is truncated.""" - self.bot.get_channel.return_value = self.channel - await self.cog.send_log_message( - icon_url="foo", - colour=discord.Colour.blue(), - title="bar", - text="foo bar" * 3000 - ) - embed = self.channel.send.call_args[1]["embed"] - self.assertEqual( - embed.description, ("foo bar" * 3000)[:2045] + "..." - ) diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py deleted file mode 100644 index ab3d0742a..000000000 --- a/tests/bot/cogs/moderation/test_silence.py +++ /dev/null @@ -1,261 +0,0 @@ -import unittest -from unittest import mock -from unittest.mock import MagicMock, Mock - -from discord import PermissionOverwrite - -from bot.cogs.moderation.silence import Silence, SilenceNotifier -from bot.constants import Channels, Emojis, Guild, Roles -from tests.helpers import MockBot, MockContext, MockTextChannel - - -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - self.alert_channel = MockTextChannel() - self.notifier = SilenceNotifier(self.alert_channel) - self.notifier.stop = self.notifier_stop_mock = Mock() - self.notifier.start = self.notifier_start_mock = Mock() - - def test_add_channel_adds_channel(self): - """Channel in FirstHash with current loop is added to internal set.""" - channel = Mock() - with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: - self.notifier.add_channel(channel) - silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) - - def test_add_channel_starts_loop(self): - """Loop is started if `_silenced_channels` was empty.""" - self.notifier.add_channel(Mock()) - self.notifier_start_mock.assert_called_once() - - def test_add_channel_skips_start_with_channels(self): - """Loop start is not called when `_silenced_channels` is not empty.""" - with mock.patch.object(self.notifier, "_silenced_channels"): - self.notifier.add_channel(Mock()) - self.notifier_start_mock.assert_not_called() - - def test_remove_channel_removes_channel(self): - """Channel in FirstHash is removed from `_silenced_channels`.""" - channel = Mock() - with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: - self.notifier.remove_channel(channel) - silenced_channels.__delitem__.assert_called_with(channel) - - def test_remove_channel_stops_loop(self): - """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" - with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): - self.notifier.remove_channel(Mock()) - self.notifier_stop_mock.assert_called_once() - - def test_remove_channel_skips_stop_with_channels(self): - """Notifier loop is not stopped if `_silenced_channels` is not empty after remove.""" - self.notifier.remove_channel(Mock()) - self.notifier_stop_mock.assert_not_called() - - async def test_notifier_private_sends_alert(self): - """Alert is sent on 15 min intervals.""" - test_cases = (900, 1800, 2700) - for current_loop in test_cases: - with self.subTest(current_loop=current_loop): - with mock.patch.object(self.notifier, "_current_loop", new=current_loop): - await self.notifier._notifier() - self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") - self.alert_channel.send.reset_mock() - - async def test_notifier_skips_alert(self): - """Alert is skipped on first loop or not an increment of 900.""" - test_cases = (0, 15, 5000) - for current_loop in test_cases: - with self.subTest(current_loop=current_loop): - with mock.patch.object(self.notifier, "_current_loop", new=current_loop): - await self.notifier._notifier() - self.alert_channel.send.assert_not_called() - - -class SilenceTests(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - self.bot = MockBot() - self.cog = Silence(self.bot) - self.ctx = MockContext() - self.cog._verified_role = None - # Set event so command callbacks can continue. - self.cog._get_instance_vars_event.set() - - async def test_instance_vars_got_guild(self): - """Bot got guild after it became available.""" - await self.cog._get_instance_vars() - self.bot.wait_until_guild_available.assert_called_once() - self.bot.get_guild.assert_called_once_with(Guild.id) - - async def test_instance_vars_got_role(self): - """Got `Roles.verified` role from guild.""" - await self.cog._get_instance_vars() - guild = self.bot.get_guild() - guild.get_role.assert_called_once_with(Roles.verified) - - async def test_instance_vars_got_channels(self): - """Got channels from bot.""" - await self.cog._get_instance_vars() - self.bot.get_channel.called_once_with(Channels.mod_alerts) - self.bot.get_channel.called_once_with(Channels.mod_log) - - @mock.patch("bot.cogs.moderation.silence.SilenceNotifier") - async def test_instance_vars_got_notifier(self, notifier): - """Notifier was started with channel.""" - mod_log = MockTextChannel() - self.bot.get_channel.side_effect = (None, mod_log) - await self.cog._get_instance_vars() - notifier.assert_called_once_with(mod_log) - self.bot.get_channel.side_effect = None - - async def test_silence_sent_correct_discord_message(self): - """Check if proper message was sent when called with duration in channel with previous state.""" - test_cases = ( - (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), - (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), - (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), - ) - for duration, result_message, _silence_patch_return in test_cases: - with self.subTest( - silence_duration=duration, - result_message=result_message, - starting_unsilenced_state=_silence_patch_return - ): - with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): - await self.cog.silence.callback(self.cog, self.ctx, duration) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_unsilence_sent_correct_discord_message(self): - """Check if proper message was sent when unsilencing channel.""" - test_cases = ( - (True, f"{Emojis.check_mark} unsilenced current channel."), - (False, f"{Emojis.cross_mark} current channel was not silenced.") - ) - for _unsilence_patch_return, result_message in test_cases: - with self.subTest( - starting_silenced_state=_unsilence_patch_return, - result_message=result_message - ): - with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): - await self.cog.unsilence.callback(self.cog, self.ctx) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_silence_private_for_false(self): - """Permissions are not set and `False` is returned in an already silenced channel.""" - perm_overwrite = Mock(send_messages=False) - channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - - self.assertFalse(await self.cog._silence(channel, True, None)) - channel.set_permissions.assert_not_called() - - async def test_silence_private_silenced_channel(self): - """Channel had `send_message` permissions revoked.""" - channel = MockTextChannel() - self.assertTrue(await self.cog._silence(channel, False, None)) - channel.set_permissions.assert_called_once() - self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) - - async def test_silence_private_preserves_permissions(self): - """Previous permissions were preserved when channel was silenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite() - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._silence(channel, False, None) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - async def test_silence_private_notifier(self): - """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" - channel = MockTextChannel() - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=True): - await self.cog._silence(channel, True, None) - self.cog.notifier.add_channel.assert_called_once() - - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=False): - await self.cog._silence(channel, False, None) - self.cog.notifier.add_channel.assert_not_called() - - async def test_silence_private_added_muted_channel(self): - """Channel was added to `muted_channels` on silence.""" - channel = MockTextChannel() - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._silence(channel, False, None) - muted_channels.add.assert_called_once_with(channel) - - async def test_unsilence_private_for_false(self): - """Permissions are not set and `False` is returned in an unsilenced channel.""" - channel = Mock() - self.assertFalse(await self.cog._unsilence(channel)) - channel.set_permissions.assert_not_called() - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_unsilenced_channel(self, _): - """Channel had `send_message` permissions restored""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - self.assertTrue(await self.cog._unsilence(channel)) - channel.set_permissions.assert_called_once() - self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_notifier(self, notifier): - """Channel was removed from `notifier` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - await self.cog._unsilence(channel) - notifier.remove_channel.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_muted_channel(self, _): - """Channel was removed from `muted_channels` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._unsilence(channel) - muted_channels.discard.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_preserves_permissions(self, _): - """Previous permissions were preserved when channel was unsilenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite(send_messages=False) - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._unsilence(channel) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - @mock.patch("bot.cogs.moderation.silence.asyncio") - @mock.patch.object(Silence, "_mod_alerts_channel", create=True) - def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): - """Task for sending an alert was created with present `muted_channels`.""" - with mock.patch.object(self.cog, "muted_channels"): - self.cog.cog_unload() - alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") - asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - - @mock.patch("bot.cogs.moderation.silence.asyncio") - def test_cog_unload_skips_task_start(self, asyncio_mock): - """No task created with no channels.""" - self.cog.cog_unload() - asyncio_mock.create_task.assert_not_called() - - @mock.patch("bot.cogs.moderation.silence.with_role_check") - @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/moderation/test_slowmode.py b/tests/bot/cogs/moderation/test_slowmode.py deleted file mode 100644 index f442814c8..000000000 --- a/tests/bot/cogs/moderation/test_slowmode.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest -from unittest import mock - -from dateutil.relativedelta import relativedelta - -from bot.cogs.moderation.slowmode import Slowmode -from bot.constants import Emojis -from tests.helpers import MockBot, MockContext, MockTextChannel - - -class SlowmodeTests(unittest.IsolatedAsyncioTestCase): - - def setUp(self) -> None: - self.bot = MockBot() - self.cog = Slowmode(self.bot) - self.ctx = MockContext() - - async def test_get_slowmode_no_channel(self) -> None: - """Get slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) - - await self.cog.get_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") - - async def test_get_slowmode_with_channel(self) -> None: - """Get slowmode with a given channel.""" - text_channel = MockTextChannel(name='python-language', slowmode_delay=2) - - await self.cog.get_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') - - async def test_set_slowmode_no_channel(self) -> None: - """Set slowmode without a given channel.""" - test_cases = ( - ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), - ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), - ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - self.ctx.channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) - - if edited: - self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - self.ctx.channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_set_slowmode_with_channel(self) -> None: - """Set slowmode with a given channel.""" - test_cases = ( - ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), - ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), - ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - text_channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) - - if edited: - text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - text_channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_reset_slowmode_no_channel(self) -> None: - """Reset slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) - - await self.cog.reset_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' - ) - - async def test_reset_slowmode_with_channel(self) -> None: - """Reset slowmode with a given channel.""" - text_channel = MockTextChannel(name='meta', slowmode_delay=1) - - await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' - ) - - @mock.patch("bot.cogs.moderation.slowmode.with_role_check") - @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py deleted file mode 100644 index fdda59a8f..000000000 --- a/tests/bot/cogs/test_cogs.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Test suite for general tests which apply to all cogs.""" - -import importlib -import pkgutil -import typing as t -import unittest -from collections import defaultdict -from types import ModuleType -from unittest import mock - -from discord.ext import commands - -from bot import cogs - - -class CommandNameTests(unittest.TestCase): - """Tests for shadowing command names and aliases.""" - - @staticmethod - def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: - """An iterator that recursively walks through `cog`'s commands and subcommands.""" - # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. - for command in cog.__cog_commands__: - if command.parent is None: - yield command - if isinstance(command, commands.GroupMixin): - # Annoyingly it returns duplicates for each alias so use a set to fix that - yield from set(command.walk_commands()) - - @staticmethod - def walk_modules() -> t.Iterator[ModuleType]: - """Yield imported modules from the bot.cogs subpackage.""" - def on_error(name: str) -> t.NoReturn: - raise ImportError(name=name) # pragma: no cover - - # The mock prevents asyncio.get_event_loop() from being called. - with mock.patch("discord.ext.tasks.loop"): - for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): - if not module.ispkg: - yield importlib.import_module(module.name) - - @staticmethod - def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: - """Yield all cogs defined in an extension.""" - for obj in module.__dict__.values(): - # Check if it's a class type cause otherwise issubclass() may raise a TypeError. - is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) - if is_cog and obj.__module__ == module.__name__: - yield obj - - @staticmethod - def get_qualified_names(command: commands.Command) -> t.List[str]: - """Return a list of all qualified names, including aliases, for the `command`.""" - names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] - names.append(command.qualified_name) - - return names - - def get_all_commands(self) -> t.Iterator[commands.Command]: - """Yield all commands for all cogs in all extensions.""" - for module in self.walk_modules(): - for cog in self.walk_cogs(module): - for cmd in self.walk_commands(cog): - yield cmd - - def test_names_dont_shadow(self): - """Names and aliases of commands should be unique.""" - all_names = defaultdict(list) - for cmd in self.get_all_commands(): - func_name = f"{cmd.module}.{cmd.callback.__qualname__}" - - for name in self.get_qualified_names(cmd): - with self.subTest(cmd=func_name, name=name): - if name in all_names: # pragma: no cover - conflicts = ", ".join(all_names.get(name, "")) - self.fail( - f"Name '{name}' of the command {func_name} conflicts with {conflicts}." - ) - - all_names[name].append(func_name) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py deleted file mode 100644 index cfe10aebf..000000000 --- a/tests/bot/cogs/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.cogs import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.cogs.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger('bot.cogs.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/utils/__init__.py b/tests/bot/cogs/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/utils/test_jams.py b/tests/bot/cogs/utils/test_jams.py deleted file mode 100644 index 299f436ba..000000000 --- a/tests/bot/cogs/utils/test_jams.py +++ /dev/null @@ -1,173 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, create_autospec - -from discord import CategoryChannel - -from bot.cogs.utils import jams -from bot.constants import Roles -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel - - -def get_mock_category(channel_count: int, name: str) -> CategoryChannel: - """Return a mocked code jam category.""" - category = create_autospec(CategoryChannel, spec_set=True, instance=True) - category.name = name - category.channels = [MockTextChannel() for _ in range(channel_count)] - - return category - - -class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): - """Tests for `createteam` command.""" - - def setUp(self): - self.bot = MockBot() - self.admin_role = MockRole(name="Admins", id=Roles.admins) - self.command_user = MockMember([self.admin_role]) - self.guild = MockGuild([self.admin_role]) - self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) - self.cog = jams.CodeJams(self.bot) - - async def test_too_small_amount_of_team_members_passed(self): - """Should `ctx.send` and exit early when too small amount of members.""" - for case in (1, 2): - with self.subTest(amount_of_members=case): - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - self.ctx.reset_mock() - members = (MockMember() for _ in range(case)) - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_duplicate_members_provided(self): - """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - member = MockMember() - await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_result_sending(self): - """Should call `ctx.send` when everything goes right.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - members = [MockMember() for _ in range(5)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.cog.create_channels.assert_awaited_once() - self.cog.add_roles.assert_awaited_once() - self.ctx.send.assert_awaited_once() - - async def test_category_doesnt_exist(self): - """Should create a new code jam category.""" - subtests = ( - [], - [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], - [get_mock_category(jams.MAX_CHANNELS - 2, "other")], - ) - - for categories in subtests: - self.guild.reset_mock() - self.guild.categories = categories - - with self.subTest(categories=categories): - actual_category = await self.cog.get_category(self.guild) - - self.guild.create_category_channel.assert_awaited_once() - category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - - self.assertFalse(category_overwrites[self.guild.default_role].read_messages) - self.assertTrue(category_overwrites[self.guild.me].read_messages) - self.assertEqual(self.guild.create_category_channel.return_value, actual_category) - - async def test_category_channel_exist(self): - """Should not try to create category channel.""" - expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) - self.guild.categories = [ - get_mock_category(jams.MAX_CHANNELS - 2, "other"), - expected_category, - get_mock_category(0, jams.CATEGORY_NAME), - ] - - actual_category = await self.cog.get_category(self.guild) - self.assertEqual(expected_category, actual_category) - - async def test_channel_overwrites(self): - """Should have correct permission overwrites for users and roles.""" - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - overwrites = self.cog.get_overwrites(members, self.guild) - - # Leader permission overwrites - self.assertTrue(overwrites[leader].manage_messages) - self.assertTrue(overwrites[leader].read_messages) - self.assertTrue(overwrites[leader].manage_webhooks) - self.assertTrue(overwrites[leader].connect) - - # Other members permission overwrites - for member in members[1:]: - self.assertTrue(overwrites[member].read_messages) - self.assertTrue(overwrites[member].connect) - - # Everyone and verified role overwrite - self.assertFalse(overwrites[self.guild.default_role].read_messages) - self.assertFalse(overwrites[self.guild.default_role].connect) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) - - async def test_team_channels_creation(self): - """Should create new voice and text channel for team.""" - members = [MockMember() for _ in range(5)] - - self.cog.get_overwrites = MagicMock() - self.cog.get_category = AsyncMock() - self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") - actual = await self.cog.create_channels(self.guild, "my-team", members) - - self.assertEqual("foobar-channel", actual) - self.cog.get_overwrites.assert_called_once_with(members, self.guild) - self.cog.get_category.assert_awaited_once_with(self.guild) - - self.guild.create_text_channel.assert_awaited_once_with( - "my-team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - self.guild.create_voice_channel.assert_awaited_once_with( - "My Team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - - async def test_jam_roles_adding(self): - """Should add team leader role to leader and jam role to every team member.""" - leader_role = MockRole(name="Team Leader") - jam_role = MockRole(name="Jammer") - self.guild.get_role.side_effect = [leader_role, jam_role] - - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - await self.cog.add_roles(self.guild, members) - - leader.add_roles.assert_any_await(leader_role) - for member in members: - member.add_roles.assert_any_await(jam_role) - - -class CodeJamSetup(unittest.TestCase): - """Test for `setup` function of `CodeJam` cog.""" - - def test_setup(self): - """Should call `bot.add_cog`.""" - bot = MockBot() - jams.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/utils/test_snekbox.py b/tests/bot/cogs/utils/test_snekbox.py deleted file mode 100644 index 3e447f319..000000000 --- a/tests/bot/cogs/utils/test_snekbox.py +++ /dev/null @@ -1,409 +0,0 @@ -import asyncio -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch - -from discord.ext import commands - -from bot import constants -from bot.cogs.utils import snekbox -from bot.cogs.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser - - -class SnekboxTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Add mocked bot and cog to the instance.""" - self.bot = MockBot() - self.cog = Snekbox(bot=self.bot) - - async def test_post_eval(self): - """Post the eval code to the URLs.snekbox_eval_api endpoint.""" - resp = MagicMock() - resp.json = AsyncMock(return_value="return") - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual(await self.cog.post_eval("import random"), "return") - self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json={"input": "import random"}, - raise_for_status=True - ) - resp.json.assert_awaited_once() - - async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) - self.assertEqual(result, "too long to upload") - - async def test_upload_output(self): - """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True - ) - - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - log = logging.getLogger("bot.cogs.utils.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - - def test_prepare_input(self): - cases = ( - ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), - ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), - ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), - ) - for case, expected, testname in cases: - with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) - - def test_get_results_message(self): - """Return error and message according to the eval result.""" - cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - @patch('bot.cogs.utils.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_signals: Mock): - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127', '') - ) - - @patch('bot.cogs.utils.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = 'SIGTEST' - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127 (SIGTEST)', '') - ) - - def test_get_status_emoji(self): - """Return emoji according to the eval result.""" - cases = ( - (' ', -1, ':warning:'), - ('Hello world!', 0, ':white_check_mark:'), - ('Invalid beard size', -1, ':x:') - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - async def test_format_output(self): - """Test output formatting.""" - self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') - - too_many_lines = ( - '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' - '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' - ) - too_long_too_many_lines = ( - "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" - ) - - cases = ( - ('', ('[No output]', None), 'No output'), - ('My awesome output', ('My awesome output', None), 'One line output'), - ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), - ('" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, asyncio.TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): # pragma: no cover + with self.subTest(size=size): + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover + with self.subTest(msg=msg): + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py new file mode 100644 index 000000000..1b89564f2 --- /dev/null +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -0,0 +1,416 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.exts.backend import sync +from bot.exts.backend.sync._cog import Sync +from bot.exts.backend.sync._syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + self.role_syncer_patcher = mock.patch( + "bot.exts.backend.sync._syncers.RoleSyncer", + autospec=Syncer, + spec_set=True + ) + self.user_syncer_patcher = mock.patch( + "bot.exts.backend.sync._syncers.UserSyncer", + autospec=Syncer, + spec_set=True + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task = mock.MagicMock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + + self.guild_id_patcher = mock.patch("bot.exts.backend.sync._cog.constants.Guild.id", 5) + self.guild_id = self.guild_id_patcher.start() + + self.guild = helpers.MockGuild(id=self.guild_id) + self.other_guild = helpers.MockGuild(id=0) + + def tearDown(self): + self.guild_id_patcher.stop() + + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data, guild=self.guild) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + async def test_sync_cog_on_guild_role_create_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_create(role) + self.bot.api_client.post.assert_not_awaited() + + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99, guild=self.guild) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_delete(role) + self.bot.api_client.delete.assert_not_awaited() + + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data, guild=self.guild) + after_role = helpers.MockRole(**after_role_data, guild=self.guild) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + async def test_sync_cog_on_guild_role_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_update(role, role) + self.bot.api_client.put.assert_not_awaited() + + async def test_sync_cog_on_member_remove(self): + """Member should be patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember(guild=self.guild) + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + json={"in_guild": False} + ) + + async def test_sync_cog_on_member_remove_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_remove(member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles, guild=self.guild) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, json=data) + + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + async def test_sync_cog_on_member_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_update(member, member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args.args[0], after_user.id) + self.assertIn("json", call_args.kwargs) + + self.assertIn("ignore_404", call_args.kwargs) + self.assertTrue(call_args.kwargs["ignore_404"]) + + json = call_args.kwargs["json"] + self.assertIn(api_field, json) + self.assertEqual(json[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + guild=self.guild, + ) + + data = { + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_join(member) + self.bot.api_client.post.assert_not_awaited() + self.bot.api_client.put.assert_not_awaited() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py new file mode 100644 index 000000000..7b9f40cad --- /dev/null +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -0,0 +1,157 @@ +import unittest +from unittest import mock + +import discord + +from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py new file mode 100644 index 000000000..c0a1da35c --- /dev/null +++ b/tests/bot/exts/backend/sync/test_users.py @@ -0,0 +1,158 @@ +import unittest +from unittest import mock + +from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User +from tests import helpers + + +def fake_user(**kwargs): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_sets_in_guild_false_for_leaving_users(self): + """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/exts/backend/test_logging.py b/tests/bot/exts/backend/test_logging.py new file mode 100644 index 000000000..466f207d9 --- /dev/null +++ b/tests/bot/exts/backend/test_logging.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import patch + +from bot import constants +from bot.exts.backend.logging import Logging +from tests.helpers import MockBot, MockTextChannel + + +class LoggingTests(unittest.IsolatedAsyncioTestCase): + """Test cases for connected login.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Logging(self.bot) + self.dev_log = MockTextChannel(id=1234, name="dev-log") + + @patch("bot.exts.backend.logging.DEBUG_MODE", False) + async def test_debug_mode_false(self): + """Should send connected message to dev-log.""" + self.bot.get_channel.return_value = self.dev_log + + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) + self.dev_log.send.assert_awaited_once() + + @patch("bot.exts.backend.logging.DEBUG_MODE", True) + async def test_debug_mode_true(self): + """Should not send anything to dev-log.""" + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_not_called() diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py new file mode 100644 index 000000000..960894e5c --- /dev/null +++ b/tests/bot/exts/filters/test_antimalware.py @@ -0,0 +1,165 @@ +import unittest +from unittest.mock import AsyncMock, Mock + +from discord import NotFound + +from bot.constants import Channels, STAFF_ROLES +from bot.exts.filters import antimalware +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + + +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.bot.filter_list_cache = { + "FILE_FORMAT.True": { + ".first": {}, + ".second": {}, + ".third": {}, + } + } + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] + + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename="python.first") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.guild = None + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_called_once() + + async def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extension_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + + async def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + self.message.delete.assert_called_once() + + async def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (self.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py new file mode 100644 index 000000000..6a0e4fded --- /dev/null +++ b/tests/bot/exts/filters/test_antispam.py @@ -0,0 +1,35 @@ +import unittest + +from bot.exts.filters import antispam + + +class AntispamConfigurationValidationTests(unittest.TestCase): + """Tests validation of the antispam cog configuration.""" + + def test_default_antispam_config_is_valid(self): + """The default antispam configuration is valid.""" + validation_errors = antispam.validate_config() + self.assertEqual(validation_errors, {}) + + def test_unknown_rule_returns_error(self): + """Configuring an unknown rule returns an error.""" + self.assertEqual( + antispam.validate_config({'invalid-rule': {}}), + {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} + ) + + def test_missing_keys_returns_error(self): + """Not configuring required keys returns an error.""" + keys = (('interval', 'max'), ('max', 'interval')) + for configured_key, unconfigured_key in keys: + with self.subTest( + configured_key=configured_key, + unconfigured_key=unconfigured_key + ): + config = {'burst': {configured_key: 10}} + error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" + + self.assertEqual( + antispam.validate_config(config), + {'burst': error} + ) diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py new file mode 100644 index 000000000..c0c3baa42 --- /dev/null +++ b/tests/bot/exts/filters/test_security.py @@ -0,0 +1,54 @@ +import unittest +from unittest.mock import MagicMock + +from discord.ext.commands import NoPrivateMessage + +from bot.exts.filters import security +from tests.helpers import MockBot, MockContext + + +class SecurityCogTests(unittest.TestCase): + """Tests the `Security` cog.""" + + def setUp(self): + """Attach an instance of the cog to the class for tests.""" + self.bot = MockBot() + self.cog = security.Security(self.bot) + self.ctx = MockContext() + + def test_check_additions(self): + """The cog should add its checks after initialization.""" + self.bot.check.assert_any_call(self.cog.check_on_guild) + self.bot.check.assert_any_call(self.cog.check_not_bot) + + def test_check_not_bot_returns_false_for_humans(self): + """The bot check should return `True` when invoked with human authors.""" + self.ctx.author.bot = False + self.assertTrue(self.cog.check_not_bot(self.ctx)) + + def test_check_not_bot_returns_true_for_robots(self): + """The bot check should return `False` when invoked with robotic authors.""" + self.ctx.author.bot = True + self.assertFalse(self.cog.check_not_bot(self.ctx)) + + def test_check_on_guild_raises_when_outside_of_guild(self): + """When invoked outside of a guild, `check_on_guild` should cause an error.""" + self.ctx.guild = None + + with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): + self.cog.check_on_guild(self.ctx) + + def test_check_on_guild_returns_true_inside_of_guild(self): + """When invoked inside of a guild, `check_on_guild` should return `True`.""" + self.ctx.guild = "lemon's lemonade stand" + self.assertTrue(self.cog.check_on_guild(self.ctx)) + + +class SecurityCogLoadTests(unittest.TestCase): + """Tests loading the `Security` cog.""" + + def test_security_cog_load(self): + """Setup of the extension should call add_cog.""" + bot = MagicMock() + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py new file mode 100644 index 000000000..a0ff8a877 --- /dev/null +++ b/tests/bot/exts/filters/test_token_remover.py @@ -0,0 +1,310 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock + +from discord import Colour, NotFound + +from bot import constants +from bot.exts.filters import token_remover +from bot.exts.filters.token_remover import Token, TokenRemover +from bot.exts.moderation.modlog import ModLog +from tests.helpers import MockBot, MockMessage, autospec + + +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): + """Tests the `TokenRemover` cog.""" + + def setUp(self): + """Adds the cog, a bot, and a message to the instance for usage in tests.""" + self.bot = MockBot() + self.cog = TokenRemover(bot=self.bot) + + self.msg = MockMessage(id=555, content="hello world") + self.msg.channel.mention = "#lemonade-stand" + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + self.msg.author.avatar_url_as.return_value = "picture-lemon.png" + + def test_is_valid_user_id_valid(self): + """Should consider user IDs valid if they decode entirely to ASCII digits.""" + ids = ( + "NDcyMjY1OTQzMDYyNDEzMzMy", + "NDc1MDczNjI5Mzk5NTQ3OTA0", + "NDY3MjIzMjMwNjUwNzc3NjQx", + ) + + for user_id in ids: + with self.subTest(user_id=user_id): + result = TokenRemover.is_valid_user_id(user_id) + self.assertTrue(result) + + def test_is_valid_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for user_id, msg in ids: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_user_id(user_id) + self.assertFalse(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for timestamp, msg in timestamps: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertFalse(result) + + def test_mod_log_property(self): + """The `mod_log` property should ask the bot to return the `ModLog` cog.""" + self.bot.get_cog.return_value = 'lemon' + self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) + self.bot.get_cog.assert_called_once_with('ModLog') + + async def test_on_message_edit_uses_on_message(self): + """The edit listener should delegate handling of the message to the normal listener.""" + self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + + await self.cog.on_message_edit(MockMessage(), self.msg) + self.cog.on_message.assert_awaited_once_with(self.msg) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_takes_action(self, find_token_in_message, take_action): + """Should take action if a valid token is found when a message is sent.""" + cog = TokenRemover(self.bot) + found_token = "foobar" + find_token_in_message.return_value = found_token + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_awaited_once_with(cog, self.msg, found_token) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): + """Shouldn't take action if a valid token isn't found when a message is sent.""" + cog = TokenRemover(self.bot) + find_token_in_message.return_value = False + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_not_awaited() + + @autospec(TokenRemover, "find_token_in_message") + async def test_on_message_ignores_dms_bots(self, find_token_in_message): + """Shouldn't parse a message if it is a DM or authored by a bot.""" + cog = TokenRemover(self.bot) + dm_msg = MockMessage(guild=None) + bot_msg = MockMessage(author=MagicMock(bot=True)) + + for msg in (dm_msg, bot_msg): + await cog.on_message(msg) + find_token_in_message.assert_not_called() + + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") + def test_find_token_no_matches(self, token_re): + """None should be returned if the regex matches no tokens in a message.""" + token_re.finditer.return_value = () + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.exts.filters.token_remover", "Token") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") + def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """The first match with a valid user ID and timestamp should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertEqual(tokens[1], return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.exts.filters.token_remover", "Token") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") + def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """None should be returned if no matches have valid user IDs or timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + is_valid_id.return_value = False + is_valid_timestamp.return_value = False + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.findall(token) + self.assertEqual(len(results), 0) + + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, these tokens have been invalidated. + tokens = ( + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" + + results = token_remover.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual((token_1, token_2), results) + + @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") + def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + log_message.format.return_value = "Howdy" + + return_value = TokenRemover.format_log_message(self.msg, token) + + self.assertEqual(return_value, log_message.format.return_value) + log_message.format.assert_called_once_with( + author=self.msg.author, + author_id=self.msg.author.id, + channel=self.msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac="x" * len(token.hmac), + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + @autospec("bot.exts.filters.token_remover", "log") + @autospec(TokenRemover, "format_log_message") + async def test_take_action(self, format_log_message, logger, mod_log_property): + """Should delete the message and send a mod log.""" + cog = TokenRemover(self.bot) + mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) + token = mock.create_autospec(Token, spec_set=True, instance=True) + log_msg = "testing123" + + mod_log_property.return_value = mod_log + format_log_message.return_value = log_msg + + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + + format_log_message.assert_called_once_with(self.msg, token) + logger.debug.assert_called_with(log_msg) + self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + + mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) + mod_log.send_log_message.assert_called_once_with( + icon_url=constants.Icons.token_removed, + colour=Colour(constants.Colours.soft_red), + title="Token removed!", + text=log_msg, + thumbnail=self.msg.author.avatar_url_as.return_value, + channel_id=constants.Channels.mod_alerts + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + async def test_take_action_delete_failure(self, mod_log_property): + """Shouldn't send any messages if the token message can't be deleted.""" + cog = TokenRemover(self.bot) + mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) + self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) + + token = mock.create_autospec(Token, spec_set=True, instance=True) + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_not_awaited() + + +class TokenRemoverExtensionTests(unittest.TestCase): + """Tests for the token_remover extension.""" + + @autospec("bot.exts.filters.token_remover", "TokenRemover") + def test_extension_setup(self, cog): + """The TokenRemover cog should be added.""" + bot = MockBot() + token_remover.setup(bot) + + cog.assert_called_once_with(bot) + bot.add_cog.assert_called_once() + self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/info/__init__.py b/tests/bot/exts/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py new file mode 100644 index 000000000..be47d42ef --- /dev/null +++ b/tests/bot/exts/info/test_information.py @@ -0,0 +1,584 @@ +import asyncio +import textwrap +import unittest +import unittest.mock + +import discord + +from bot import constants +from bot.exts.info import information +from bot.utils.checks import InWhitelistCheckFailure +from tests import helpers + +COG_PATH = "bot.exts.info.information.Information" + + +class InformationCogTests(unittest.TestCase): + """Tests the Information cog.""" + + @classmethod + def setUpClass(cls): + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = helpers.MockBot() + + self.cog = information.Information(self.bot) + + self.ctx = helpers.MockContext() + self.ctx.author.roles.append(self.moderator_role) + + def test_roles_command_command(self): + """Test if the `role_info` command correctly returns the `moderator_role`.""" + self.ctx.guild.roles.append(self.moderator_role) + + self.cog.roles_info.can_run = unittest.mock.AsyncMock() + self.cog.roles_info.can_run.return_value = True + + coroutine = self.cog.roles_info.callback(self.cog, self.ctx) + + self.assertIsNone(asyncio.run(coroutine)) + self.ctx.send.assert_called_once() + + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + + self.assertEqual(embed.title, "Role information (Total 1 role)") + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") + + def test_role_info_command(self): + """Tests the `role info` command.""" + dummy_role = helpers.MockRole( + name="Dummy", + id=112233445566778899, + colour=discord.Colour.blurple(), + position=10, + members=[self.ctx.author], + permissions=discord.Permissions(0) + ) + + admin_role = helpers.MockRole( + name="Admins", + id=998877665544332211, + colour=discord.Colour.red(), + position=3, + members=[self.ctx.author], + permissions=discord.Permissions(0), + ) + + self.ctx.guild.roles.append([dummy_role, admin_role]) + + self.cog.role_info.can_run = unittest.mock.AsyncMock() + self.cog.role_info.can_run.return_value = True + + coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) + + self.assertIsNone(asyncio.run(coroutine)) + + self.assertEqual(self.ctx.send.call_count, 2) + + (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list + + dummy_embed = dummy_kwargs["embed"] + admin_embed = admin_kwargs["embed"] + + self.assertEqual(dummy_embed.title, "Dummy info") + self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) + + self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) + self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") + self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") + self.assertEqual(dummy_embed.fields[3].value, "1") + self.assertEqual(dummy_embed.fields[4].value, "10") + self.assertEqual(dummy_embed.fields[5].value, "0") + + self.assertEqual(admin_embed.title, "Admins info") + self.assertEqual(admin_embed.colour, discord.Colour.red()) + + @unittest.mock.patch('bot.exts.info.information.time_since') + def test_server_info_command(self, time_since_patch): + time_since_patch.return_value = '2 days ago' + + self.ctx.guild = helpers.MockGuild( + features=('lemons', 'apples'), + region="The Moon", + roles=[self.moderator_role], + channels=[ + discord.TextChannel( + state={}, + guild=self.ctx.guild, + data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} + ), + discord.CategoryChannel( + state={}, + guild=self.ctx.guild, + data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} + ), + discord.VoiceChannel( + state={}, + guild=self.ctx.guild, + data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} + ) + ], + members=[ + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), + ], + member_count=1_234, + icon_url='a-lemon.jpg', + ) + + coroutine = self.cog.server_info.callback(self.cog, self.ctx) + self.assertIsNone(asyncio.run(coroutine)) + + time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual( + embed.description, + textwrap.dedent( + f""" + **Server information** + Created: {time_since_patch.return_value} + Voice region: {self.ctx.guild.region} + Features: {', '.join(self.ctx.guild.features)} + + **Channel counts** + Category channels: 1 + Text channels: 1 + Voice channels: 1 + Staff channels: 0 + + **Member counts** + Members: {self.ctx.guild.member_count:,} + Staff members: 0 + Roles: {len(self.ctx.guild.roles)} + + **Member statuses** + {constants.Emojis.status_online} 2 + {constants.Emojis.status_idle} 1 + {constants.Emojis.status_dnd} 4 + {constants.Emojis.status_offline} 3 + """ + ) + ) + self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') + + +class UserInfractionHelperMethodTests(unittest.TestCase): + """Tests for the helper methods of the `!user` command.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + self.member = helpers.MockMember(id=1234) + + def test_user_command_helper_method_get_requests(self): + """The helper methods should form the correct get requests.""" + test_values = ( + { + "helper_method": self.cog.basic_user_infraction_counts, + "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.expanded_user_infraction_counts, + "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.user_nomination_counts, + "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), + }, + ) + + for test_value in test_values: + helper_method = test_value["helper_method"] + endpoint, params = test_value["expected_args"] + + with self.subTest(method=helper_method, endpoint=endpoint, params=params): + asyncio.run(helper_method(self.member)) + self.bot.api_client.get.assert_called_once_with(endpoint, params=params) + self.bot.api_client.get.reset_mock() + + def _method_subtests(self, method, test_values, default_header): + """Helper method that runs the subtests for the different helper methods.""" + for test_value in test_values: + api_response = test_value["api response"] + expected_lines = test_value["expected_lines"] + + with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): + self.bot.api_client.get.return_value = api_response + + expected_output = "\n".join(default_header + expected_lines) + actual_output = asyncio.run(method(self.member)) + + self.assertEqual(expected_output, actual_output) + + def test_basic_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list both the total and active number of non-hidden infractions.""" + test_values = ( + # No infractions means zero counts + { + "api response": [], + "expected_lines": ["Total: 0", "Active: 0"], + }, + # Simple, single-infraction dictionaries + { + "api response": [{"type": "ban", "active": True}], + "expected_lines": ["Total: 1", "Active: 1"], + }, + { + "api response": [{"type": "ban", "active": False}], + "expected_lines": ["Total: 1", "Active: 0"], + }, + # Multiple infractions with various `active` status + { + "api response": [ + {"type": "ban", "active": True}, + {"type": "kick", "active": False}, + {"type": "ban", "active": True}, + {"type": "ban", "active": False}, + ], + "expected_lines": ["Total: 4", "Active: 2"], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) + + def test_expanded_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list the total and active number of all infractions split by infraction type.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never received an infraction."], + }, + # Shows non-hidden inactive infraction as expected + { + "api response": [{"type": "kick", "active": False, "hidden": False}], + "expected_lines": ["Kicks: 1"], + }, + # Shows non-hidden active infraction as expected + { + "api response": [{"type": "mute", "active": True, "hidden": False}], + "expected_lines": ["Mutes: 1 (1 active)"], + }, + # Shows hidden inactive infraction as expected + { + "api response": [{"type": "superstar", "active": False, "hidden": True}], + "expected_lines": ["Superstars: 1"], + }, + # Shows hidden active infraction as expected + { + "api response": [{"type": "ban", "active": True, "hidden": True}], + "expected_lines": ["Bans: 1 (1 active)"], + }, + # Correctly displays tally of multiple infractions of mixed properties in alphabetical order + { + "api response": [ + {"type": "kick", "active": False, "hidden": True}, + {"type": "ban", "active": True, "hidden": True}, + {"type": "superstar", "active": True, "hidden": True}, + {"type": "mute", "active": True, "hidden": True}, + {"type": "ban", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + {"type": "note", "active": False, "hidden": True}, + {"type": "warn", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + ], + "expected_lines": [ + "Bans: 2 (1 active)", + "Kicks: 1", + "Mutes: 1 (1 active)", + "Notes: 3", + "Superstars: 1 (1 active)", + "Warns: 1", + ], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) + + def test_user_nomination_counts_returns_correct_strings(self): + """The method should list the number of active and historical nominations for the user.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never been nominated."], + }, + { + "api response": [{'active': True}], + "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], + }, + { + "api response": [{'active': True}, {'active': False}], + "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], + }, + { + "api response": [{'active': False}], + "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], + }, + { + "api response": [{'active': False}, {'active': False}], + "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], + }, + + ) + + header = ["**Nominations**"] + + self._method_subtests(self.cog.user_nomination_counts, test_values, header) + + +@unittest.mock.patch("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) +@unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50]) +class UserEmbedTests(unittest.TestCase): + """Tests for the creation of the `!user` embed.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): + """The embed should use the string representation of the user if they don't have a nick.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = None + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Mr. Hemlock") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_nick_in_title_if_available(self): + """The embed should use the nick if it's available.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = "Cat lover" + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_ignores_everyone_role(self): + """Created `!user` embeds should not contain mention of the @everyone-role.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + admins_role = helpers.MockRole(name='Admins') + admins_role.colour = 100 + + # A `MockMember` has the @Everyone role by default; we add the Admins to that. + user = helpers.MockMember(roles=[admins_role], top_role=admins_role) + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertIn("&Admins", embed.description) + self.assertNotIn("&Everyone", embed.description) + + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): + """The embed should contain expanded infractions and nomination info in mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "expanded infractions info" + nomination_counts.return_value = "nomination info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + nomination_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + expanded infractions info + + nomination info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): + """The embed should contain only basic infraction data outside of mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "basic infractions info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + basic infractions info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): + """The embed should be created with the colour of the top role, if a top role is available.""" + ctx = helpers.MockContext() + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): + """The embed should be created with a blurple colour if the user has no assigned roles.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour.blurple()) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): + """The embed thumbnail should be set to the user's avatar in `png` format.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + user.avatar_url_as.return_value = "avatar url" + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + user.avatar_url_as.assert_called_once_with(static_format="png") + self.assertEqual(embed.thumbnail.url, "avatar url") + + +@unittest.mock.patch("bot.exts.info.information.constants") +class UserCommandTests(unittest.TestCase): + """Tests for the `!user` command.""" + + def setUp(self): + """Set up steps executed before each test is run.""" + self.bot = helpers.MockBot() + self.cog = information.Information(self.bot) + + self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) + self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) + self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) + + self.author = helpers.MockMember(id=1, name="syntaxaire") + self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) + self.target = helpers.MockMember(id=3, name="__fluzz__") + + def test_regular_member_cannot_target_another_member(self, constants): + """A regular user should not be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.author) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") + + def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): + """A regular user should not be able to use this command outside of bot-commands.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) + + msg = "Sorry, but you may only use this command within <#50>." + with self.assertRaises(InWhitelistCheckFailure, msg=msg): + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): + """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): + """A user should target itself with `!user` when a `user` argument was not provided.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): + """Staff members should be able to bypass the bot-commands channel restriction.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.moderator) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_moderators_can_target_another_member(self, create_embed, constants): + """A moderator should be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + create_embed.assert_called_once_with(ctx, self.target) + ctx.send.assert_called_once() diff --git a/tests/bot/exts/moderation/__init__.py b/tests/bot/exts/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/moderation/infraction/__init__.py b/tests/bot/exts/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py new file mode 100644 index 000000000..be1b649e1 --- /dev/null +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.exts.moderation.infraction.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason truncation.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Infractions(self.bot) + self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) + self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) + self.guild = MockGuild(id=4567) + self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + + @patch("bot.exts.moderation.infraction._utils.get_active_infraction") + @patch("bot.exts.moderation.infraction._utils.post_infraction") + async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): + """Should truncate reason for `ctx.guild.ban`.""" + get_active_mock.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.bot.get_cog.return_value = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.ctx.guild.ban = Mock() + + await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) + self.ctx.guild.ban.assert_called_once_with( + self.target, + reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), + delete_message_days=0 + ) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value + ) + + @patch("bot.exts.moderation.infraction._utils.post_infraction") + async def test_apply_kick_reason_truncation(self, post_infraction_mock): + """Should truncate reason for `Member.kick`.""" + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.target.kick = Mock() + + await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) + self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + ) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py new file mode 100644 index 000000000..cbf7f7bcf --- /dev/null +++ b/tests/bot/exts/moderation/test_incidents.py @@ -0,0 +1,770 @@ +import asyncio +import enum +import logging +import typing as t +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +import aiohttp +import discord + +from bot.constants import Colours +from bot.exts.moderation import incidents +from tests.helpers import ( + MockAsyncWebhook, + MockAttachment, + MockBot, + MockMember, + MockMessage, + MockReaction, + MockRole, + MockTextChannel, + MockUser, +) + + +class MockAsyncIterable: + """ + Helper for mocking asynchronous for loops. + + It does not appear that the `unittest` library currently provides anything that would + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. + + We therefore write our own helper to wrap a regular synchronous iterable, and feed + its values via `__anext__` rather than `__next__`. + + This class was written for the purposes of testing the `Incidents` cog - it may not + be generic enough to be placed in the `tests.helpers` module. + """ + + def __init__(self, messages: t.Iterable): + """Take a sync iterable to be wrapped.""" + self.iter_messages = iter(messages) + + def __aiter__(self): + """Return `self` as we provide the `__anext__` method.""" + return self + + async def __anext__(self): + """ + Feed the next item, or raise `StopAsyncIteration`. + + Since we're wrapping a sync iterator, it will communicate that it has been depleted + by raising a `StopIteration`. The `async for` construct does not expect it, and we + therefore need to substitute it for the appropriate exception type. + """ + try: + return next(self.iter_messages) + except StopIteration: + raise StopAsyncIteration + + +class MockSignal(enum.Enum): + A = "A" + B = "B" + + +mock_404 = discord.NotFound( + response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response + message="Not found", +) + + +class TestDownloadFile(unittest.IsolatedAsyncioTestCase): + """Collection of tests for the `download_file` helper function.""" + + async def test_download_file_success(self): + """If `to_file` succeeds, function returns the acquired `discord.File`.""" + file = MagicMock(discord.File, filename="bigbadlemon.jpg") + attachment = MockAttachment(to_file=AsyncMock(return_value=file)) + + acquired_file = await incidents.download_file(attachment) + self.assertIs(file, acquired_file) + + async def test_download_file_404(self): + """If `to_file` encounters a 404, function handles the exception & returns None.""" + attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) + + acquired_file = await incidents.download_file(attachment) + self.assertIsNone(acquired_file) + + async def test_download_file_fail(self): + """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" + arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) + + with self.assertLogs(logger=incidents.log, level=logging.ERROR): + acquired_file = await incidents.download_file(attachment) + + self.assertIsNone(acquired_file) + + +class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): + """Collection of tests for the `make_embed` helper function.""" + + async def test_make_embed_actioned(self): + """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" + embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + + self.assertEqual(embed.colour.value, Colours.soft_green) + self.assertIn("Actioned", embed.footer.text) + + async def test_make_embed_not_actioned(self): + """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" + embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + + self.assertEqual(embed.colour.value, Colours.soft_red) + self.assertIn("Rejected", embed.footer.text) + + async def test_make_embed_content(self): + """Incident content appears as embed description.""" + incident = MockMessage(content="this is an incident") + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertEqual(incident.content, embed.description) + + async def test_make_embed_with_attachment_succeeds(self): + """Incident's attachment is downloaded and displayed in the embed's image field.""" + file = MagicMock(discord.File, filename="bigbadjoe.jpg") + attachment = MockAttachment(filename="bigbadjoe.jpg") + incident = MockMessage(content="this is an incident", attachments=[attachment]) + + # Patch `download_file` to return our `file` + with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): + embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertIs(file, returned_file) + self.assertEqual("attachment://bigbadjoe.jpg", embed.image.url) + + async def test_make_embed_with_attachment_fails(self): + """Incident's attachment fails to download, proxy url is linked instead.""" + attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") + incident = MockMessage(content="this is an incident", attachments=[attachment]) + + # Patch `download_file` to return None as if the download failed + with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): + embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertIsNone(returned_file) + + # The author name field is simply expected to have something in it, we do not assert the message + self.assertGreater(len(embed.author.name), 0) + self.assertEqual(embed.author.url, "discord.com/bigbadjoe.jpg") # However, it should link the exact url + + +@patch("bot.constants.Channels.incidents", 123) +class TestIsIncident(unittest.TestCase): + """ + Collection of tests for the `is_incident` helper function. + + In `setUp`, we will create a mock message which should qualify as an incident. Each + test case will then mutate this instance to make it **not** qualify, in various ways. + + Notice that we patch the #incidents channel id globally for this class. + """ + + def setUp(self) -> None: + """Prepare a mock message which should qualify as an incident.""" + self.incident = MockMessage( + channel=MockTextChannel(id=123), + content="this is an incident", + author=MockUser(bot=False), + pinned=False, + ) + + def test_is_incident_true(self): + """Message qualifies as an incident if unchanged.""" + self.assertTrue(incidents.is_incident(self.incident)) + + def check_false(self): + """Assert that `self.incident` does **not** qualify as an incident.""" + self.assertFalse(incidents.is_incident(self.incident)) + + def test_is_incident_false_channel(self): + """Message doesn't qualify if sent outside of #incidents.""" + self.incident.channel = MockTextChannel(id=456) + self.check_false() + + def test_is_incident_false_content(self): + """Message doesn't qualify if content begins with hash symbol.""" + self.incident.content = "# this is a comment message" + self.check_false() + + def test_is_incident_false_author(self): + """Message doesn't qualify if author is a bot.""" + self.incident.author = MockUser(bot=True) + self.check_false() + + def test_is_incident_false_pinned(self): + """Message doesn't qualify if it is pinned.""" + self.incident.pinned = True + self.check_false() + + +class TestOwnReactions(unittest.TestCase): + """Assertions for the `own_reactions` function.""" + + def test_own_reactions(self): + """Only bot's own emoji are extracted from the input incident.""" + reactions = ( + MockReaction(emoji="A", me=True), + MockReaction(emoji="B", me=True), + MockReaction(emoji="C", me=False), + ) + message = MockMessage(reactions=reactions) + self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) + + +@patch("bot.exts.moderation.incidents.ALL_SIGNALS", {"A", "B"}) +class TestHasSignals(unittest.TestCase): + """ + Assertions for the `has_signals` function. + + We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` + as appropriate. + """ + + def test_has_signals_true(self): + """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" + message = MockMessage() + own_reactions = MagicMock(return_value={"A", "B"}) + + with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): + self.assertTrue(incidents.has_signals(message)) + + def test_has_signals_false(self): + """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" + message = MockMessage() + own_reactions = MagicMock(return_value={"A", "C"}) + + with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): + self.assertFalse(incidents.has_signals(message)) + + +@patch("bot.exts.moderation.incidents.Signal", MockSignal) +class TestAddSignals(unittest.IsolatedAsyncioTestCase): + """ + Assertions for the `add_signals` coroutine. + + These are all fairly similar and could go into a single test function, but I found the + patching & sub-testing fairly awkward in that case and decided to split them up + to avoid unnecessary syntax noise. + """ + + def setUp(self): + """Prepare a mock incident message for tests to use.""" + self.incident = MockMessage() + + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value=set())) + async def test_add_signals_missing(self): + """All emoji are added when none are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) + + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) + async def test_add_signals_partial(self): + """Only missing emoji are added when some are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_has_calls([call("B")]) + + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) + async def test_add_signals_present(self): + """No emoji are added when all are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_not_called() + + +class TestIncidents(unittest.IsolatedAsyncioTestCase): + """ + Tests for bound methods of the `Incidents` cog. + + Use this as a base class for `Incidents` tests - it will prepare a fresh instance + for each test function, but not make any assertions on its own. Tests can mutate + the instance as they wish. + """ + + def setUp(self): + """ + Prepare a fresh `Incidents` instance for each test. + + Note that this will not schedule `crawl_incidents` in the background, as everything + is being mocked. The `crawl_task` attribute will end up being None. + """ + self.cog_instance = incidents.Incidents(MockBot()) + + +@patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test +class TestCrawlIncidents(TestIncidents): + """ + Tests for the `Incidents.crawl_incidents` coroutine. + + Apart from `test_crawl_incidents_waits_until_cache_ready`, all tests in this class + will patch the return values of `is_incident` and `has_signal` and then observe + whether the `AsyncMock` for `add_signals` was awaited or not. + + The `add_signals` mock is added by each test separately to ensure it is clean (has not + been awaited by another test yet). The mock can be reset, but this appears to be the + cleaner way. + + For each test, we inject a mock channel with a history of 1 message only (see: `setUp`). + """ + + def setUp(self): + """For each test, ensure `bot.get_channel` returns a channel with 1 arbitrary message.""" + super().setUp() # First ensure we get `cog_instance` from parent + + incidents_history = MagicMock(return_value=MockAsyncIterable([MockMessage()])) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(history=incidents_history)) + + async def test_crawl_incidents_waits_until_cache_ready(self): + """ + The coroutine will await the `wait_until_guild_available` event. + + Since this task is schedule in the `__init__`, it is critical that it waits for the + cache to be ready, so that it can safely get the #incidents channel. + """ + await self.cog_instance.crawl_incidents() + self.cog_instance.bot.wait_until_guild_available.assert_awaited() + + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) + async def test_crawl_incidents_noop_if_is_not_incident(self): + """Signals are not added for a non-incident message.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_not_awaited() + + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals + async def test_crawl_incidents_noop_if_message_already_has_signals(self): + """Signals are not added for messages which already have them.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_not_awaited() + + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals + async def test_crawl_incidents_add_signals_called(self): + """Message has signals added as it does not have them yet and qualifies as an incident.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_awaited_once() + + +class TestArchive(TestIncidents): + """Tests for the `Incidents.archive` coroutine.""" + + async def test_archive_webhook_not_found(self): + """ + Method recovers and returns False when the webhook is not found. + + Implicitly, this also tests that the error is handled internally and doesn't + propagate out of the method, which is just as important. + """ + self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) + self.assertFalse( + await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + ) + + async def test_archive_relays_incident(self): + """ + If webhook is found, method relays `incident` properly. + + This test will assert that the fetched webhook's `send` method is fed the correct arguments, + and that the `archive` method returns True. + """ + webhook = MockAsyncWebhook() + self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook + + # Define our own `incident` to be archived + incident = MockMessage( + content="this is an incident", + author=MockUser(name="author_name", avatar_url="author_avatar"), + id=123, + ) + built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this + + with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): + archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) + + # Now we check that the webhook was given the correct args, and that `archive` returned True + webhook.send.assert_called_once_with( + embed=built_embed, + username="author_name", + avatar_url="author_avatar", + file=None, + ) + self.assertTrue(archive_return) + + async def test_archive_clyde_username(self): + """ + The archive webhook username is cleansed using `sub_clyde`. + + Discord will reject any webhook with "clyde" in the username field, as it impersonates + the official Clyde bot. Since we do not control what the username will be (the incident + author name is used), we must ensure the name is cleansed, otherwise the relay may fail. + + This test assumes the username is passed as a kwarg. If this test fails, please review + whether the passed argument is being retrieved correctly. + """ + webhook = MockAsyncWebhook() + self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) + + message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) + + self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) + + +class TestMakeConfirmationTask(TestIncidents): + """ + Tests for the `Incidents.make_confirmation_task` method. + + Writing tests for this method is difficult, as it mostly just delegates the provided + information elsewhere. There is very little internal logic. Whether our approach + works conceptually is difficult to prove using unit tests. + """ + + def test_make_confirmation_task_check(self): + """ + The internal check will recognize the passed incident. + + This is a little tricky - we first pass a message with a specific `id` in, and then + retrieve the built check from the `call_args` of the `wait_for` method. This relies + on the check being passed as a kwarg. + + Once the check is retrieved, we assert that it gives True for our incident's `id`, + and False for any other. + + If this function begins to fail, first check that `created_check` is being retrieved + correctly. It should be the function that is built locally in the tested method. + """ + self.cog_instance.make_confirmation_task(MockMessage(id=123)) + + self.cog_instance.bot.wait_for.assert_called_once() + created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"] + + # The `message_id` matches the `id` of our incident + self.assertTrue(created_check(payload=MagicMock(message_id=123))) + + # This `message_id` does not match + self.assertFalse(created_check(payload=MagicMock(message_id=0))) + + +@patch("bot.exts.moderation.incidents.ALLOWED_ROLES", {1, 2}) +@patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable +class TestProcessEvent(TestIncidents): + """Tests for the `Incidents.process_event` coroutine.""" + + async def test_process_event_bad_role(self): + """The reaction is removed when the author lacks all allowed roles.""" + incident = MockMessage() + member = MockMember(roles=[MockRole(id=0)]) # Must have role 1 or 2 + + await self.cog_instance.process_event("reaction", incident, member) + incident.remove_reaction.assert_called_once_with("reaction", member) + + async def test_process_event_bad_emoji(self): + """ + The reaction is removed when an invalid emoji is used. + + This requires that we pass in a `member` with valid roles, as we need the role check + to succeed. + """ + incident = MockMessage() + member = MockMember(roles=[MockRole(id=1)]) # Member has allowed role + + await self.cog_instance.process_event("invalid_signal", incident, member) + incident.remove_reaction.assert_called_once_with("invalid_signal", member) + + async def test_process_event_no_archive_on_investigating(self): + """Message is not archived on `Signal.INVESTIGATING`.""" + with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: + await self.cog_instance.process_event( + reaction=incidents.Signal.INVESTIGATING.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]), + ) + + mocked_archive.assert_not_called() + + async def test_process_event_no_delete_if_archive_fails(self): + """ + Original message is not deleted when `Incidents.archive` returns False. + + This is the way of signaling that the relay failed, and we should not remove the original, + as that would result in losing the incident record. + """ + incident = MockMessage() + + with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=incident, + member=MockMember(roles=[MockRole(id=1)]) + ) + + incident.delete.assert_not_called() + + async def test_process_event_confirmation_task_is_awaited(self): + """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" + mock_task = AsyncMock() + + with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]) + ) + + mock_task.assert_awaited() + + async def test_process_event_confirmation_task_timeout_is_handled(self): + """ + Confirmation task `asyncio.TimeoutError` is handled gracefully. + + We have `make_confirmation_task` return a mock with a side effect, and then catch the + exception should it propagate out of `process_event`. This is so that we can then manually + fail the test with a more informative message than just the plain traceback. + """ + mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) + + try: + with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]) + ) + except asyncio.TimeoutError: + self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!") + + +class TestResolveMessage(TestIncidents): + """Tests for the `Incidents.resolve_message` coroutine.""" + + async def test_resolve_message_pass_message_id(self): + """Method will call `_get_message` with the passed `message_id`.""" + await self.cog_instance.resolve_message(123) + self.cog_instance.bot._connection._get_message.assert_called_once_with(123) + + async def test_resolve_message_in_cache(self): + """ + No API call is made if the queried message exists in the cache. + + We mock the `_get_message` return value regardless of input. Whether it finds the message + internally is considered d.py's responsibility, not ours. + """ + cached_message = MockMessage(id=123) + self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message) + + return_value = await self.cog_instance.resolve_message(123) + + self.assertIs(return_value, cached_message) + self.cog_instance.bot.get_channel.assert_not_called() # The `fetch_message` line was never hit + + async def test_resolve_message_not_in_cache(self): + """ + The message is retrieved from the API if it isn't cached. + + This is desired behaviour for messages which exist, but were sent before the bot's + current session. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + # API returns our message + uncached_message = MockMessage() + fetch_message = AsyncMock(return_value=uncached_message) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + retrieved_message = await self.cog_instance.resolve_message(123) + self.assertIs(retrieved_message, uncached_message) + + async def test_resolve_message_doesnt_exist(self): + """ + If the API returns a 404, the function handles it gracefully and returns None. + + This is an edge-case happening with racing events - event A will relay the message + to the archive and delete the original. Once event B acquires the `event_lock`, + it will not find the message in the cache, and will ask the API. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + fetch_message = AsyncMock(side_effect=mock_404) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + self.assertIsNone(await self.cog_instance.resolve_message(123)) + + async def test_resolve_message_fetch_fails(self): + """ + Non-404 errors are handled, logged & None is returned. + + In contrast with a 404, this should make an error-level log. We assert that at least + one such log was made - we do not make any assertions about the log's message. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + arbitrary_error = discord.HTTPException( + response=MagicMock(aiohttp.ClientResponse), + message="Arbitrary error", + ) + fetch_message = AsyncMock(side_effect=arbitrary_error) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + with self.assertLogs(logger=incidents.log, level=logging.ERROR): + self.assertIsNone(await self.cog_instance.resolve_message(123)) + + +@patch("bot.constants.Channels.incidents", 123) +class TestOnRawReactionAdd(TestIncidents): + """ + Tests for the `Incidents.on_raw_reaction_add` listener. + + Writing tests for this listener comes with additional complexity due to the listener + awaiting the `crawl_task` task. See `asyncSetUp` for further details, which attempts + to make unit testing this function possible. + """ + + def setUp(self): + """ + Prepare & assign `payload` attribute. + + This attribute represents an *ideal* payload which will not be rejected by the + listener. As each test will receive a fresh instance, it can be mutated to + observe how the listener's behaviour changes with different attributes on + the passed payload. + """ + super().setUp() # Ensure `cog_instance` is assigned + + self.payload = MagicMock( + discord.RawReactionActionEvent, + channel_id=123, # Patched at class level + message_id=456, + member=MockMember(bot=False), + emoji="reaction", + ) + + async def asyncSetUp(self): # noqa: N802 + """ + Prepare an empty task and assign it as `crawl_task`. + + It appears that the `unittest` framework does not provide anything for mocking + asyncio tasks. An `AsyncMock` instance can be called and then awaited, however, + it does not provide the `done` method or any other parts of the `asyncio.Task` + interface. + + Although we do not need to make any assertions about the task itself while + testing the listener, the code will still await it and call the `done` method, + and so we must inject something that will not fail on either action. + + Note that this is done in an `asyncSetUp`, which runs after `setUp`. + The justification is that creating an actual task requires the event + loop to be ready, which is not the case in the `setUp`. + """ + mock_task = asyncio.create_task(AsyncMock()()) # Mock async func, then a coro + self.cog_instance.crawl_task = mock_task + + async def test_on_raw_reaction_add_wrong_channel(self): + """ + Events outside of #incidents will be ignored. + + We check this by asserting that `resolve_message` was never queried. + """ + self.payload.channel_id = 0 + self.cog_instance.resolve_message = AsyncMock() + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.resolve_message.assert_not_called() + + async def test_on_raw_reaction_add_user_is_bot(self): + """ + Events dispatched by bot accounts will be ignored. + + We check this by asserting that `resolve_message` was never queried. + """ + self.payload.member = MockMember(bot=True) + self.cog_instance.resolve_message = AsyncMock() + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.resolve_message.assert_not_called() + + async def test_on_raw_reaction_add_message_doesnt_exist(self): + """ + Listener gracefully handles the case where `resolve_message` gives None. + + We check this by asserting that `process_event` was never called. + """ + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=None) + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.process_event.assert_not_called() + + async def test_on_raw_reaction_add_message_is_not_an_incident(self): + """ + The event won't be processed if the related message is not an incident. + + This is an edge-case that can happen if someone manually leaves a reaction + on a pinned message, or a comment. + + We check this by asserting that `process_event` was never called. + """ + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) + + with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)): + await self.cog_instance.on_raw_reaction_add(self.payload) + + self.cog_instance.process_event.assert_not_called() + + async def test_on_raw_reaction_add_valid_event_is_processed(self): + """ + If the reaction event is valid, it is passed to `process_event`. + + This is the case when everything goes right: + * The reaction was placed in #incidents, and not by a bot + * The message was found successfully + * The message qualifies as an incident + + Additionally, we check that all arguments were passed as expected. + """ + incident = MockMessage(id=1) + + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=incident) + + with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)): + await self.cog_instance.on_raw_reaction_add(self.payload) + + self.cog_instance.process_event.assert_called_with( + "reaction", # Defined in `self.payload` + incident, + self.payload.member, + ) + + +class TestOnMessage(TestIncidents): + """ + Tests for the `Incidents.on_message` listener. + + Notice the decorators mocking the `is_incident` return value. The `is_incidents` + function is tested in `TestIsIncident` - here we do not worry about it. + """ + + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) + async def test_on_message_incident(self): + """Messages qualifying as incidents are passed to `add_signals`.""" + incident = MockMessage() + + with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + await self.cog_instance.on_message(incident) + + mock_add_signals.assert_called_once_with(incident) + + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) + async def test_on_message_non_incident(self): + """Messages not qualifying as incidents are ignored.""" + with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + await self.cog_instance.on_message(MockMessage()) + + mock_add_signals.assert_not_called() diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py new file mode 100644 index 000000000..f8f142484 --- /dev/null +++ b/tests/bot/exts/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.exts.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): + """Tests for moderation logs.""" + + def setUp(self): + self.bot = MockBot() + self.cog = ModLog(self.bot) + self.channel = MockTextChannel() + + async def test_log_entry_description_truncation(self): + """Test that embed description for ModLog entry is truncated.""" + self.bot.get_channel.return_value = self.channel + await self.cog.send_log_message( + icon_url="foo", + colour=discord.Colour.blue(), + title="bar", + text="foo bar" * 3000 + ) + embed = self.channel.send.call_args[1]["embed"] + self.assertEqual( + embed.description, ("foo bar" * 3000)[:2045] + "..." + ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py new file mode 100644 index 000000000..8c4fb764a --- /dev/null +++ b/tests/bot/exts/moderation/test_silence.py @@ -0,0 +1,261 @@ +import unittest +from unittest import mock +from unittest.mock import MagicMock, Mock + +from discord import PermissionOverwrite + +from bot.constants import Channels, Emojis, Guild, Roles +from bot.exts.moderation.silence import Silence, SilenceNotifier +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.alert_channel = MockTextChannel() + self.notifier = SilenceNotifier(self.alert_channel) + self.notifier.stop = self.notifier_stop_mock = Mock() + self.notifier.start = self.notifier_start_mock = Mock() + + def test_add_channel_adds_channel(self): + """Channel in FirstHash with current loop is added to internal set.""" + channel = Mock() + with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: + self.notifier.add_channel(channel) + silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) + + def test_add_channel_starts_loop(self): + """Loop is started if `_silenced_channels` was empty.""" + self.notifier.add_channel(Mock()) + self.notifier_start_mock.assert_called_once() + + def test_add_channel_skips_start_with_channels(self): + """Loop start is not called when `_silenced_channels` is not empty.""" + with mock.patch.object(self.notifier, "_silenced_channels"): + self.notifier.add_channel(Mock()) + self.notifier_start_mock.assert_not_called() + + def test_remove_channel_removes_channel(self): + """Channel in FirstHash is removed from `_silenced_channels`.""" + channel = Mock() + with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: + self.notifier.remove_channel(channel) + silenced_channels.__delitem__.assert_called_with(channel) + + def test_remove_channel_stops_loop(self): + """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" + with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): + self.notifier.remove_channel(Mock()) + self.notifier_stop_mock.assert_called_once() + + def test_remove_channel_skips_stop_with_channels(self): + """Notifier loop is not stopped if `_silenced_channels` is not empty after remove.""" + self.notifier.remove_channel(Mock()) + self.notifier_stop_mock.assert_not_called() + + async def test_notifier_private_sends_alert(self): + """Alert is sent on 15 min intervals.""" + test_cases = (900, 1800, 2700) + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with mock.patch.object(self.notifier, "_current_loop", new=current_loop): + await self.notifier._notifier() + self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") + self.alert_channel.send.reset_mock() + + async def test_notifier_skips_alert(self): + """Alert is skipped on first loop or not an increment of 900.""" + test_cases = (0, 15, 5000) + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with mock.patch.object(self.notifier, "_current_loop", new=current_loop): + await self.notifier._notifier() + self.alert_channel.send.assert_not_called() + + +class SilenceTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Silence(self.bot) + self.ctx = MockContext() + self.cog._verified_role = None + # Set event so command callbacks can continue. + self.cog._get_instance_vars_event.set() + + async def test_instance_vars_got_guild(self): + """Bot got guild after it became available.""" + await self.cog._get_instance_vars() + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(Guild.id) + + async def test_instance_vars_got_role(self): + """Got `Roles.verified` role from guild.""" + await self.cog._get_instance_vars() + guild = self.bot.get_guild() + guild.get_role.assert_called_once_with(Roles.verified) + + async def test_instance_vars_got_channels(self): + """Got channels from bot.""" + await self.cog._get_instance_vars() + self.bot.get_channel.called_once_with(Channels.mod_alerts) + self.bot.get_channel.called_once_with(Channels.mod_log) + + @mock.patch("bot.exts.moderation.silence.SilenceNotifier") + async def test_instance_vars_got_notifier(self, notifier): + """Notifier was started with channel.""" + mod_log = MockTextChannel() + self.bot.get_channel.side_effect = (None, mod_log) + await self.cog._get_instance_vars() + notifier.assert_called_once_with(mod_log) + self.bot.get_channel.side_effect = None + + async def test_silence_sent_correct_discord_message(self): + """Check if proper message was sent when called with duration in channel with previous state.""" + test_cases = ( + (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), + (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), + (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), + ) + for duration, result_message, _silence_patch_return in test_cases: + with self.subTest( + silence_duration=duration, + result_message=result_message, + starting_unsilenced_state=_silence_patch_return + ): + with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): + await self.cog.silence.callback(self.cog, self.ctx, duration) + self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() + + async def test_unsilence_sent_correct_discord_message(self): + """Check if proper message was sent when unsilencing channel.""" + test_cases = ( + (True, f"{Emojis.check_mark} unsilenced current channel."), + (False, f"{Emojis.cross_mark} current channel was not silenced.") + ) + for _unsilence_patch_return, result_message in test_cases: + with self.subTest( + starting_silenced_state=_unsilence_patch_return, + result_message=result_message + ): + with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): + await self.cog.unsilence.callback(self.cog, self.ctx) + self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() + + async def test_silence_private_for_false(self): + """Permissions are not set and `False` is returned in an already silenced channel.""" + perm_overwrite = Mock(send_messages=False) + channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) + + self.assertFalse(await self.cog._silence(channel, True, None)) + channel.set_permissions.assert_not_called() + + async def test_silence_private_silenced_channel(self): + """Channel had `send_message` permissions revoked.""" + channel = MockTextChannel() + self.assertTrue(await self.cog._silence(channel, False, None)) + channel.set_permissions.assert_called_once() + self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) + + async def test_silence_private_preserves_permissions(self): + """Previous permissions were preserved when channel was silenced.""" + channel = MockTextChannel() + # Set up mock channel permission state. + mock_permissions = PermissionOverwrite() + mock_permissions_dict = dict(mock_permissions) + channel.overwrites_for.return_value = mock_permissions + await self.cog._silence(channel, False, None) + new_permissions = channel.set_permissions.call_args.kwargs + # Remove 'send_messages' key because it got changed in the method. + del new_permissions['send_messages'] + del mock_permissions_dict['send_messages'] + self.assertDictEqual(mock_permissions_dict, new_permissions) + + async def test_silence_private_notifier(self): + """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" + channel = MockTextChannel() + with mock.patch.object(self.cog, "notifier", create=True): + with self.subTest(persistent=True): + await self.cog._silence(channel, True, None) + self.cog.notifier.add_channel.assert_called_once() + + with mock.patch.object(self.cog, "notifier", create=True): + with self.subTest(persistent=False): + await self.cog._silence(channel, False, None) + self.cog.notifier.add_channel.assert_not_called() + + async def test_silence_private_added_muted_channel(self): + """Channel was added to `muted_channels` on silence.""" + channel = MockTextChannel() + with mock.patch.object(self.cog, "muted_channels") as muted_channels: + await self.cog._silence(channel, False, None) + muted_channels.add.assert_called_once_with(channel) + + async def test_unsilence_private_for_false(self): + """Permissions are not set and `False` is returned in an unsilenced channel.""" + channel = Mock() + self.assertFalse(await self.cog._unsilence(channel)) + channel.set_permissions.assert_not_called() + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_unsilenced_channel(self, _): + """Channel had `send_message` permissions restored""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + self.assertTrue(await self.cog._unsilence(channel)) + channel.set_permissions.assert_called_once() + self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_removed_notifier(self, notifier): + """Channel was removed from `notifier` on unsilence.""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + await self.cog._unsilence(channel) + notifier.remove_channel.assert_called_once_with(channel) + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_removed_muted_channel(self, _): + """Channel was removed from `muted_channels` on unsilence.""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + with mock.patch.object(self.cog, "muted_channels") as muted_channels: + await self.cog._unsilence(channel) + muted_channels.discard.assert_called_once_with(channel) + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_preserves_permissions(self, _): + """Previous permissions were preserved when channel was unsilenced.""" + channel = MockTextChannel() + # Set up mock channel permission state. + mock_permissions = PermissionOverwrite(send_messages=False) + mock_permissions_dict = dict(mock_permissions) + channel.overwrites_for.return_value = mock_permissions + await self.cog._unsilence(channel) + new_permissions = channel.set_permissions.call_args.kwargs + # Remove 'send_messages' key because it got changed in the method. + del new_permissions['send_messages'] + del mock_permissions_dict['send_messages'] + self.assertDictEqual(mock_permissions_dict, new_permissions) + + @mock.patch("bot.exts.moderation.silence.asyncio") + @mock.patch.object(Silence, "_mod_alerts_channel", create=True) + def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): + """Task for sending an alert was created with present `muted_channels`.""" + with mock.patch.object(self.cog, "muted_channels"): + self.cog.cog_unload() + alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") + asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) + + @mock.patch("bot.exts.moderation.silence.asyncio") + def test_cog_unload_skips_task_start(self, asyncio_mock): + """No task created with no channels.""" + self.cog.cog_unload() + asyncio_mock.create_task.assert_not_called() + + @mock.patch("bot.exts.moderation.silence.with_role_check") + @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py new file mode 100644 index 000000000..e90394ab9 --- /dev/null +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -0,0 +1,111 @@ +import unittest +from unittest import mock + +from dateutil.relativedelta import relativedelta + +from bot.constants import Emojis +from bot.exts.moderation.slowmode import Slowmode +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SlowmodeTests(unittest.IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Slowmode(self.bot) + self.ctx = MockContext() + + async def test_get_slowmode_no_channel(self) -> None: + """Get slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) + + await self.cog.get_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") + + async def test_get_slowmode_with_channel(self) -> None: + """Get slowmode with a given channel.""" + text_channel = MockTextChannel(name='python-language', slowmode_delay=2) + + await self.cog.get_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + + async def test_set_slowmode_no_channel(self) -> None: + """Set slowmode without a given channel.""" + test_cases = ( + ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), + ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), + ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + self.ctx.channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + + if edited: + self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + self.ctx.channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_set_slowmode_with_channel(self) -> None: + """Set slowmode with a given channel.""" + test_cases = ( + ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), + ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), + ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + text_channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + + if edited: + text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + text_channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_reset_slowmode_no_channel(self) -> None: + """Reset slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) + + await self.cog.reset_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' + ) + + async def test_reset_slowmode_with_channel(self) -> None: + """Reset slowmode with a given channel.""" + text_channel = MockTextChannel(name='meta', slowmode_delay=1) + + await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' + ) + + @mock.patch("bot.exts.moderation.slowmode.with_role_check") + @mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py new file mode 100644 index 000000000..775c40722 --- /dev/null +++ b/tests/bot/exts/test_cogs.py @@ -0,0 +1,81 @@ +"""Test suite for general tests which apply to all cogs.""" + +import importlib +import pkgutil +import typing as t +import unittest +from collections import defaultdict +from types import ModuleType +from unittest import mock + +from discord.ext import commands + +from bot import exts + + +class CommandNameTests(unittest.TestCase): + """Tests for shadowing command names and aliases.""" + + @staticmethod + def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: + """An iterator that recursively walks through `cog`'s commands and subcommands.""" + # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. + for command in cog.__cog_commands__: + if command.parent is None: + yield command + if isinstance(command, commands.GroupMixin): + # Annoyingly it returns duplicates for each alias so use a set to fix that + yield from set(command.walk_commands()) + + @staticmethod + def walk_modules() -> t.Iterator[ModuleType]: + """Yield imported modules from the bot.exts subpackage.""" + def on_error(name: str) -> t.NoReturn: + raise ImportError(name=name) # pragma: no cover + + # The mock prevents asyncio.get_event_loop() from being called. + with mock.patch("discord.ext.tasks.loop"): + prefix = f"{exts.__name__}." + for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): + if not module.ispkg: + yield importlib.import_module(module.name) + + @staticmethod + def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: + """Yield all cogs defined in an extension.""" + for obj in module.__dict__.values(): + # Check if it's a class type cause otherwise issubclass() may raise a TypeError. + is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) + if is_cog and obj.__module__ == module.__name__: + yield obj + + @staticmethod + def get_qualified_names(command: commands.Command) -> t.List[str]: + """Return a list of all qualified names, including aliases, for the `command`.""" + names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] + names.append(command.qualified_name) + + return names + + def get_all_commands(self) -> t.Iterator[commands.Command]: + """Yield all commands for all cogs in all extensions.""" + for module in self.walk_modules(): + for cog in self.walk_cogs(module): + for cmd in self.walk_commands(cog): + yield cmd + + def test_names_dont_shadow(self): + """Names and aliases of commands should be unique.""" + all_names = defaultdict(list) + for cmd in self.get_all_commands(): + func_name = f"{cmd.module}.{cmd.callback.__qualname__}" + + for name in self.get_qualified_names(cmd): + with self.subTest(cmd=func_name, name=name): + if name in all_names: # pragma: no cover + conflicts = ", ".join(all_names.get(name, "")) + self.fail( + f"Name '{name}' of the command {func_name} conflicts with {conflicts}." + ) + + all_names[name].append(func_name) diff --git a/tests/bot/exts/test_duck_pond.py b/tests/bot/exts/test_duck_pond.py new file mode 100644 index 000000000..f6d977482 --- /dev/null +++ b/tests/bot/exts/test_duck_pond.py @@ -0,0 +1,548 @@ +import asyncio +import logging +import typing +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import discord + +from bot import constants +from bot.exts import duck_pond +from tests import base +from tests import helpers + +MODULE_PATH = "bot.exts.duck_pond" + + +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) + + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) + + def setUp(self): + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) + self.cog = duck_pond.DuckPond(bot=self.bot) + + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger('bot.exts.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) + ) + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) + ) + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), + 3 + 4 + ), + ) + + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + author = MagicMock( + display_name="x", + avatar_url="https://" + ) + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), + (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), + ) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.exts.duck_pond") + + for side_effect in side_effects: # pragma: no cover + send_attachments.side_effect = side_effect + with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.exts.duck_pond") + + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message(message) + + send_webhook.assert_called_once_with( + webhook=self.cog.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), + ) + + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + relay_message.assert_called_once_with(message) + + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) + + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + for duck_count, should_re_add_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) + + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) + + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) + + if should_re_add_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" + ) + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) + + channel.fetch_message.assert_not_called() + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = helpers.MockBot() + duck_pond.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/utils/__init__.py b/tests/bot/exts/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/utils/test_jams.py new file mode 100644 index 000000000..45e7b5b51 --- /dev/null +++ b/tests/bot/exts/utils/test_jams.py @@ -0,0 +1,173 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, create_autospec + +from discord import CategoryChannel + +from bot.constants import Roles +from bot.exts.utils import jams +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel + + +def get_mock_category(channel_count: int, name: str) -> CategoryChannel: + """Return a mocked code jam category.""" + category = create_autospec(CategoryChannel, spec_set=True, instance=True) + category.name = name + category.channels = [MockTextChannel() for _ in range(channel_count)] + + return category + + +class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): + """Tests for `createteam` command.""" + + def setUp(self): + self.bot = MockBot() + self.admin_role = MockRole(name="Admins", id=Roles.admins) + self.command_user = MockMember([self.admin_role]) + self.guild = MockGuild([self.admin_role]) + self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) + self.cog = jams.CodeJams(self.bot) + + async def test_too_small_amount_of_team_members_passed(self): + """Should `ctx.send` and exit early when too small amount of members.""" + for case in (1, 2): + with self.subTest(amount_of_members=case): + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + self.ctx.reset_mock() + members = (MockMember() for _ in range(case)) + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_duplicate_members_provided(self): + """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + member = MockMember() + await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_result_sending(self): + """Should call `ctx.send` when everything goes right.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + members = [MockMember() for _ in range(5)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.cog.create_channels.assert_awaited_once() + self.cog.add_roles.assert_awaited_once() + self.ctx.send.assert_awaited_once() + + async def test_category_doesnt_exist(self): + """Should create a new code jam category.""" + subtests = ( + [], + [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], + [get_mock_category(jams.MAX_CHANNELS - 2, "other")], + ) + + for categories in subtests: + self.guild.reset_mock() + self.guild.categories = categories + + with self.subTest(categories=categories): + actual_category = await self.cog.get_category(self.guild) + + self.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + + self.assertFalse(category_overwrites[self.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.guild.me].read_messages) + self.assertEqual(self.guild.create_category_channel.return_value, actual_category) + + async def test_category_channel_exist(self): + """Should not try to create category channel.""" + expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) + self.guild.categories = [ + get_mock_category(jams.MAX_CHANNELS - 2, "other"), + expected_category, + get_mock_category(0, jams.CATEGORY_NAME), + ] + + actual_category = await self.cog.get_category(self.guild) + self.assertEqual(expected_category, actual_category) + + async def test_channel_overwrites(self): + """Should have correct permission overwrites for users and roles.""" + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + overwrites = self.cog.get_overwrites(members, self.guild) + + # Leader permission overwrites + self.assertTrue(overwrites[leader].manage_messages) + self.assertTrue(overwrites[leader].read_messages) + self.assertTrue(overwrites[leader].manage_webhooks) + self.assertTrue(overwrites[leader].connect) + + # Other members permission overwrites + for member in members[1:]: + self.assertTrue(overwrites[member].read_messages) + self.assertTrue(overwrites[member].connect) + + # Everyone and verified role overwrite + self.assertFalse(overwrites[self.guild.default_role].read_messages) + self.assertFalse(overwrites[self.guild.default_role].connect) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) + + async def test_team_channels_creation(self): + """Should create new voice and text channel for team.""" + members = [MockMember() for _ in range(5)] + + self.cog.get_overwrites = MagicMock() + self.cog.get_category = AsyncMock() + self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") + actual = await self.cog.create_channels(self.guild, "my-team", members) + + self.assertEqual("foobar-channel", actual) + self.cog.get_overwrites.assert_called_once_with(members, self.guild) + self.cog.get_category.assert_awaited_once_with(self.guild) + + self.guild.create_text_channel.assert_awaited_once_with( + "my-team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + self.guild.create_voice_channel.assert_awaited_once_with( + "My Team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + + async def test_jam_roles_adding(self): + """Should add team leader role to leader and jam role to every team member.""" + leader_role = MockRole(name="Team Leader") + jam_role = MockRole(name="Jammer") + self.guild.get_role.side_effect = [leader_role, jam_role] + + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + await self.cog.add_roles(self.guild, members) + + leader.add_roles.assert_any_await(leader_role) + for member in members: + member.add_roles.assert_any_await(jam_role) + + +class CodeJamSetup(unittest.TestCase): + """Test for `setup` function of `CodeJam` cog.""" + + def test_setup(self): + """Should call `bot.add_cog`.""" + bot = MockBot() + jams.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py new file mode 100644 index 000000000..f7b861035 --- /dev/null +++ b/tests/bot/exts/utils/test_snekbox.py @@ -0,0 +1,409 @@ +import asyncio +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch + +from discord.ext import commands + +from bot import constants +from bot.exts.utils import snekbox +from bot.exts.utils.snekbox import Snekbox +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + + async def test_post_eval(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( + constants.URLs.snekbox_eval_api, + json={"input": "import random"}, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LEN.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + self.assertEqual(result, "too long to upload") + + async def test_upload_output(self): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual( + await self.cog.upload_output("My awesome output"), + constants.URLs.paste_service.format(key=key) + ) + self.bot.http_session.post.assert_called_with( + constants.URLs.paste_service.format(key="documents"), + data="My awesome output", + raise_for_status=True + ) + + async def test_upload_output_gracefully_fallback_if_exception_during_request(self): + """Output upload gracefully fallback if the upload fail.""" + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + log = logging.getLogger("bot.exts.utils.snekbox") + with self.assertLogs(logger=log, level='ERROR'): + await self.cog.upload_output('My awesome output!') + + async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): + """Output upload gracefully fallback if there is no key entry in the response body.""" + self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + + def test_prepare_input(self): + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual(self.cog.prepare_input(case), expected) + + def test_get_results_message(self): + """Return error and message according to the eval result.""" + cases = ( + ('ERROR', None, ('Your eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), + ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) + def test_get_results_message_invalid_signal(self, mock_signals: Mock): + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127', '') + ) + + @patch('bot.exts.utils.snekbox.Signals') + def test_get_results_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = 'SIGTEST' + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127 (SIGTEST)', '') + ) + + def test_get_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... (truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + (' Date: Wed, 19 Aug 2020 13:34:34 -0700 Subject: Categorise most of the uncategorised extensions --- bot/exts/alias.py | 153 ---------- bot/exts/backend/alias.py | 153 ++++++++++ bot/exts/dm_relay.py | 124 -------- bot/exts/duck_pond.py | 166 ----------- bot/exts/fun/__init__.py | 0 bot/exts/fun/duck_pond.py | 166 +++++++++++ bot/exts/fun/off_topic_names.py | 162 +++++++++++ bot/exts/moderation/dm_relay.py | 124 ++++++++ bot/exts/off_topic_names.py | 162 ----------- tests/bot/exts/fun/__init__.py | 0 tests/bot/exts/fun/test_duck_pond.py | 548 +++++++++++++++++++++++++++++++++++ tests/bot/exts/test_duck_pond.py | 548 ----------------------------------- 12 files changed, 1153 insertions(+), 1153 deletions(-) delete mode 100644 bot/exts/alias.py create mode 100644 bot/exts/backend/alias.py delete mode 100644 bot/exts/dm_relay.py delete mode 100644 bot/exts/duck_pond.py create mode 100644 bot/exts/fun/__init__.py create mode 100644 bot/exts/fun/duck_pond.py create mode 100644 bot/exts/fun/off_topic_names.py create mode 100644 bot/exts/moderation/dm_relay.py delete mode 100644 bot/exts/off_topic_names.py create mode 100644 tests/bot/exts/fun/__init__.py create mode 100644 tests/bot/exts/fun/test_duck_pond.py delete mode 100644 tests/bot/exts/test_duck_pond.py (limited to 'tests') diff --git a/bot/exts/alias.py b/bot/exts/alias.py deleted file mode 100644 index 77867b933..000000000 --- a/bot/exts/alias.py +++ /dev/null @@ -1,153 +0,0 @@ -import inspect -import logging - -from discord import Colour, Embed -from discord.ext.commands import ( - Cog, Command, Context, Greedy, - clean_content, command, group, -) - -from bot.bot import Bot -from bot.converters import FetchedMember, TagNameConverter -from bot.exts.utils.extensions import Extension -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - - -class Alias (Cog): - """Aliases for commonly used commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: - """Invokes a command with args and kwargs.""" - log.debug(f"{cmd_name} was invoked through an alias") - cmd = self.bot.get_command(cmd_name) - if not cmd: - return log.info(f'Did not find command "{cmd_name}" to invoke.') - elif not await cmd.can_run(ctx): - return log.info( - f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' - ) - - await ctx.invoke(cmd, *args, **kwargs) - - @command(name='aliases') - async def aliases_command(self, ctx: Context) -> None: - """Show configured aliases on the bot.""" - embed = Embed( - title='Configured aliases', - colour=Colour.blue() - ) - await LinePaginator.paginate( - ( - f"• `{ctx.prefix}{value.name}` " - f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" - for name, value in inspect.getmembers(self) - if isinstance(value, Command) and name.endswith('_alias') - ), - ctx, embed, empty=False, max_lines=20 - ) - - @command(name="resources", aliases=("resource",), hidden=True) - async def site_resources_alias(self, ctx: Context) -> None: - """Alias for invoking site resources.""" - await self.invoke(ctx, "site resources") - - @command(name="tools", hidden=True) - async def site_tools_alias(self, ctx: Context) -> None: - """Alias for invoking site tools.""" - await self.invoke(ctx, "site tools") - - @command(name="watch", hidden=True) - async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother watch [user] [reason].""" - await self.invoke(ctx, "bigbrother watch", user, reason=reason) - - @command(name="unwatch", hidden=True) - async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother unwatch [user] [reason].""" - await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) - - @command(name="home", hidden=True) - async def site_home_alias(self, ctx: Context) -> None: - """Alias for invoking site home.""" - await self.invoke(ctx, "site home") - - @command(name="faq", hidden=True) - async def site_faq_alias(self, ctx: Context) -> None: - """Alias for invoking site faq.""" - await self.invoke(ctx, "site faq") - - @command(name="rules", aliases=("rule",), hidden=True) - async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: - """Alias for invoking site rules.""" - await self.invoke(ctx, "site rules", *rules) - - @command(name="reload", hidden=True) - async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: - """Alias for invoking extensions reload [extensions...].""" - await self.invoke(ctx, "extensions reload", *extensions) - - @command(name="defon", hidden=True) - async def defcon_enable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon enable.""" - await self.invoke(ctx, "defcon enable") - - @command(name="defoff", hidden=True) - async def defcon_disable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon disable.""" - await self.invoke(ctx, "defcon disable") - - @command(name="exception", hidden=True) - async def tags_get_traceback_alias(self, ctx: Context) -> None: - """Alias for invoking tags get traceback.""" - await self.invoke(ctx, "tags get", tag_name="traceback") - - @group(name="get", - aliases=("show", "g"), - hidden=True, - invoke_without_command=True) - async def get_group_alias(self, ctx: Context) -> None: - """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" - pass - - @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) - async def tags_get_alias( - self, ctx: Context, *, tag_name: TagNameConverter = None - ) -> None: - """ - Alias for invoking tags get [tag_name]. - - tag_name: str - tag to be viewed. - """ - await self.invoke(ctx, "tags get", tag_name=tag_name) - - @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) - async def docs_get_alias( - self, ctx: Context, symbol: clean_content = None - ) -> None: - """Alias for invoking docs get [symbol].""" - await self.invoke(ctx, "docs get", symbol) - - @command(name="nominate", hidden=True) - async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking talentpool add [user] [reason].""" - await self.invoke(ctx, "talentpool add", user, reason=reason) - - @command(name="unnominate", hidden=True) - async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking nomination end [user] [reason].""" - await self.invoke(ctx, "nomination end", user, reason=reason) - - @command(name="nominees", hidden=True) - async def nominees_alias(self, ctx: Context) -> None: - """Alias for invoking tp watched.""" - await self.invoke(ctx, "talentpool watched") - - -def setup(bot: Bot) -> None: - """Load the Alias cog.""" - bot.add_cog(Alias(bot)) diff --git a/bot/exts/backend/alias.py b/bot/exts/backend/alias.py new file mode 100644 index 000000000..77867b933 --- /dev/null +++ b/bot/exts/backend/alias.py @@ -0,0 +1,153 @@ +import inspect +import logging + +from discord import Colour, Embed +from discord.ext.commands import ( + Cog, Command, Context, Greedy, + clean_content, command, group, +) + +from bot.bot import Bot +from bot.converters import FetchedMember, TagNameConverter +from bot.exts.utils.extensions import Extension +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +class Alias (Cog): + """Aliases for commonly used commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: + """Invokes a command with args and kwargs.""" + log.debug(f"{cmd_name} was invoked through an alias") + cmd = self.bot.get_command(cmd_name) + if not cmd: + return log.info(f'Did not find command "{cmd_name}" to invoke.') + elif not await cmd.can_run(ctx): + return log.info( + f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' + ) + + await ctx.invoke(cmd, *args, **kwargs) + + @command(name='aliases') + async def aliases_command(self, ctx: Context) -> None: + """Show configured aliases on the bot.""" + embed = Embed( + title='Configured aliases', + colour=Colour.blue() + ) + await LinePaginator.paginate( + ( + f"• `{ctx.prefix}{value.name}` " + f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" + for name, value in inspect.getmembers(self) + if isinstance(value, Command) and name.endswith('_alias') + ), + ctx, embed, empty=False, max_lines=20 + ) + + @command(name="resources", aliases=("resource",), hidden=True) + async def site_resources_alias(self, ctx: Context) -> None: + """Alias for invoking site resources.""" + await self.invoke(ctx, "site resources") + + @command(name="tools", hidden=True) + async def site_tools_alias(self, ctx: Context) -> None: + """Alias for invoking site tools.""" + await self.invoke(ctx, "site tools") + + @command(name="watch", hidden=True) + async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother watch [user] [reason].""" + await self.invoke(ctx, "bigbrother watch", user, reason=reason) + + @command(name="unwatch", hidden=True) + async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother unwatch [user] [reason].""" + await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) + + @command(name="home", hidden=True) + async def site_home_alias(self, ctx: Context) -> None: + """Alias for invoking site home.""" + await self.invoke(ctx, "site home") + + @command(name="faq", hidden=True) + async def site_faq_alias(self, ctx: Context) -> None: + """Alias for invoking site faq.""" + await self.invoke(ctx, "site faq") + + @command(name="rules", aliases=("rule",), hidden=True) + async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: + """Alias for invoking site rules.""" + await self.invoke(ctx, "site rules", *rules) + + @command(name="reload", hidden=True) + async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: + """Alias for invoking extensions reload [extensions...].""" + await self.invoke(ctx, "extensions reload", *extensions) + + @command(name="defon", hidden=True) + async def defcon_enable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon enable.""" + await self.invoke(ctx, "defcon enable") + + @command(name="defoff", hidden=True) + async def defcon_disable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon disable.""" + await self.invoke(ctx, "defcon disable") + + @command(name="exception", hidden=True) + async def tags_get_traceback_alias(self, ctx: Context) -> None: + """Alias for invoking tags get traceback.""" + await self.invoke(ctx, "tags get", tag_name="traceback") + + @group(name="get", + aliases=("show", "g"), + hidden=True, + invoke_without_command=True) + async def get_group_alias(self, ctx: Context) -> None: + """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" + pass + + @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) + async def tags_get_alias( + self, ctx: Context, *, tag_name: TagNameConverter = None + ) -> None: + """ + Alias for invoking tags get [tag_name]. + + tag_name: str - tag to be viewed. + """ + await self.invoke(ctx, "tags get", tag_name=tag_name) + + @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) + async def docs_get_alias( + self, ctx: Context, symbol: clean_content = None + ) -> None: + """Alias for invoking docs get [symbol].""" + await self.invoke(ctx, "docs get", symbol) + + @command(name="nominate", hidden=True) + async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking talentpool add [user] [reason].""" + await self.invoke(ctx, "talentpool add", user, reason=reason) + + @command(name="unnominate", hidden=True) + async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking nomination end [user] [reason].""" + await self.invoke(ctx, "nomination end", user, reason=reason) + + @command(name="nominees", hidden=True) + async def nominees_alias(self, ctx: Context) -> None: + """Alias for invoking tp watched.""" + await self.invoke(ctx, "talentpool watched") + + +def setup(bot: Bot) -> None: + """Load the Alias cog.""" + bot.add_cog(Alias(bot)) diff --git a/bot/exts/dm_relay.py b/bot/exts/dm_relay.py deleted file mode 100644 index 0d8f340b4..000000000 --- a/bot/exts/dm_relay.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -from typing import Optional - -import discord -from discord import Color -from discord.ext import commands -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.converters import UserMentionOrID -from bot.utils import RedisCache -from bot.utils.checks import in_whitelist_check, with_role_check -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DMRelay(Cog): - """Relay direct messages to and from the bot.""" - - # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] - dm_cache = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.dm_log - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - @commands.command(aliases=("reply",)) - async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: - """ - Allows you to send a DM to a user from the bot. - - If `member` is not provided, it will send to the last user who DM'd the bot. - - This feature should be used extremely sparingly. Use ModMail if you need to have a serious - conversation with a user. This is just for responding to extraordinary DMs, having a little - fun with users, and telling people they are DMing the wrong bot. - - NOTE: This feature will be removed if it is overused. - """ - if not member: - user_id = await self.dm_cache.get("last_user") - member = ctx.guild.get_member(user_id) if user_id else None - - # If we still don't have a Member at this point, give up - if not member: - log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") - await ctx.message.add_reaction("❌") - return - - try: - await member.send(message) - except discord.errors.Forbidden: - log.debug("User has disabled DMs.") - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("✅") - self.bot.stats.incr("dm_relay.dm_sent") - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Relays the message's content and attachments to the dm_log channel.""" - # Only relay DMs from humans - if message.author.bot or message.guild or self.webhook is None: - return - - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - await self.dm_cache.set("last_user", message.author.id) - self.bot.stats.incr("dm_relay.dm_received") - - # Handle any attachments - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (discord.errors.Forbidden, discord.errors.NotFound): - e = discord.Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - def cog_check(self, ctx: commands.Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=[constants.Channels.dm_log], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - -def setup(bot: Bot) -> None: - """Load the DMRelay cog.""" - bot.add_cog(DMRelay(bot)) diff --git a/bot/exts/duck_pond.py b/bot/exts/duck_pond.py deleted file mode 100644 index 7021069fa..000000000 --- a/bot/exts/duck_pond.py +++ /dev/null @@ -1,166 +0,0 @@ -import logging -from typing import Union - -import discord -from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DuckPond(Cog): - """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.duck_pond - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @staticmethod - def is_staff(member: Union[User, Member]) -> bool: - """Check if a specific member or user is staff.""" - if hasattr(member, "roles"): - for role in member.roles: - if role.id in constants.STAFF_ROLES: - return True - return False - - async def has_green_checkmark(self, message: Message) -> bool: - """Check if the message has a green checkmark reaction.""" - for reaction in message.reactions: - if reaction.emoji == "✅": - async for user in reaction.users(): - if user == self.bot.user: - return True - return False - - async def count_ducks(self, message: Message) -> int: - """ - Count the number of ducks in the reactions of a specific message. - - Only counts ducks added by staff members. - """ - duck_count = 0 - duck_reactors = [] - - for reaction in message.reactions: - async for user in reaction.users(): - - # Is the user a staff member and not already counted as reactor? - if not self.is_staff(user) or user.id in duck_reactors: - continue - - # Is the emoji a duck? - if hasattr(reaction.emoji, "id"): - if reaction.emoji.id in constants.DuckPond.custom_emojis: - duck_count += 1 - duck_reactors.append(user.id) - elif isinstance(reaction.emoji, str): - if reaction.emoji == "🦆": - duck_count += 1 - duck_reactors.append(user.id) - return duck_count - - async def relay_message(self, message: Message) -> None: - """Relays the message's content and attachments to the duck pond channel.""" - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - await message.add_reaction("✅") - - @staticmethod - def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: - """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" - if payload.emoji.is_custom_emoji(): - if payload.emoji.id in constants.DuckPond.custom_emojis: - return True - elif payload.emoji.name == "🦆": - return True - - return False - - @Cog.listener() - async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: - """ - Determine if a message should be sent to the duck pond. - - This will count the number of duck reactions on the message, and if this amount meets the - amount of ducks specified in the config under duck_pond/threshold, it will - send the message off to the duck pond. - """ - # Is the emoji in the reaction a duck? - if not self._payload_has_duckpond_emoji(payload): - return - - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - message = await channel.fetch_message(payload.message_id) - member = discord.utils.get(message.guild.members, id=payload.user_id) - - # Is the member a human and a staff member? - if not self.is_staff(member) or member.bot: - return - - # Does the message already have a green checkmark? - if await self.has_green_checkmark(message): - return - - # Time to count our ducks! - duck_count = await self.count_ducks(message) - - # If we've got more than the required amount of ducks, send the message to the duck_pond. - if duck_count >= constants.DuckPond.threshold: - await self.relay_message(message) - - @Cog.listener() - async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: - """Ensure that people don't remove the green checkmark from duck ponded messages.""" - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - - # Prevent the green checkmark from being removed - if payload.emoji.name == "✅": - message = await channel.fetch_message(payload.message_id) - duck_count = await self.count_ducks(message) - if duck_count >= constants.DuckPond.threshold: - await message.add_reaction("✅") - - -def setup(bot: Bot) -> None: - """Load the DuckPond cog.""" - bot.add_cog(DuckPond(bot)) diff --git a/bot/exts/fun/__init__.py b/bot/exts/fun/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py new file mode 100644 index 000000000..7021069fa --- /dev/null +++ b/bot/exts/fun/duck_pond.py @@ -0,0 +1,166 @@ +import logging +from typing import Union + +import discord +from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DuckPond(Cog): + """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.duck_pond + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @staticmethod + def is_staff(member: Union[User, Member]) -> bool: + """Check if a specific member or user is staff.""" + if hasattr(member, "roles"): + for role in member.roles: + if role.id in constants.STAFF_ROLES: + return True + return False + + async def has_green_checkmark(self, message: Message) -> bool: + """Check if the message has a green checkmark reaction.""" + for reaction in message.reactions: + if reaction.emoji == "✅": + async for user in reaction.users(): + if user == self.bot.user: + return True + return False + + async def count_ducks(self, message: Message) -> int: + """ + Count the number of ducks in the reactions of a specific message. + + Only counts ducks added by staff members. + """ + duck_count = 0 + duck_reactors = [] + + for reaction in message.reactions: + async for user in reaction.users(): + + # Is the user a staff member and not already counted as reactor? + if not self.is_staff(user) or user.id in duck_reactors: + continue + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + return duck_count + + async def relay_message(self, message: Message) -> None: + """Relays the message's content and attachments to the duck pond channel.""" + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + await message.add_reaction("✅") + + @staticmethod + def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" + if payload.emoji.is_custom_emoji(): + if payload.emoji.id in constants.DuckPond.custom_emojis: + return True + elif payload.emoji.name == "🦆": + return True + + return False + + @Cog.listener() + async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: + """ + Determine if a message should be sent to the duck pond. + + This will count the number of duck reactions on the message, and if this amount meets the + amount of ducks specified in the config under duck_pond/threshold, it will + send the message off to the duck pond. + """ + # Is the emoji in the reaction a duck? + if not self._payload_has_duckpond_emoji(payload): + return + + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) + + # Is the member a human and a staff member? + if not self.is_staff(member) or member.bot: + return + + # Does the message already have a green checkmark? + if await self.has_green_checkmark(message): + return + + # Time to count our ducks! + duck_count = await self.count_ducks(message) + + # If we've got more than the required amount of ducks, send the message to the duck_pond. + if duck_count >= constants.DuckPond.threshold: + await self.relay_message(message) + + @Cog.listener() + async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: + """Ensure that people don't remove the green checkmark from duck ponded messages.""" + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + + # Prevent the green checkmark from being removed + if payload.emoji.name == "✅": + message = await channel.fetch_message(payload.message_id) + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.threshold: + await message.add_reaction("✅") + + +def setup(bot: Bot) -> None: + """Load the DuckPond cog.""" + bot.add_cog(DuckPond(bot)) diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py new file mode 100644 index 000000000..ce95450e0 --- /dev/null +++ b/bot/exts/fun/off_topic_names.py @@ -0,0 +1,162 @@ +import asyncio +import difflib +import logging +from datetime import datetime, timedelta + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES +from bot.converters import OffTopicName +from bot.decorators import with_role +from bot.pagination import LinePaginator + +CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) +log = logging.getLogger(__name__) + + +async def update_names(bot: Bot) -> None: + """Background updater task that performs the daily channel name update.""" + while True: + # Since we truncate the compute timedelta to seconds, we add one second to ensure + # we go past midnight in the `seconds_to_sleep` set below. + today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) + next_midnight = today_at_midnight + timedelta(days=1) + seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 + await asyncio.sleep(seconds_to_sleep) + + try: + channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( + 'bot/off-topic-channel-names', params={'random_items': 3} + ) + except ResponseCodeError as e: + log.error(f"Failed to get new off topic channel names: code {e.response.status}") + continue + channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) + + await channel_0.edit(name=f'ot0-{channel_0_name}') + await channel_1.edit(name=f'ot1-{channel_1_name}') + await channel_2.edit(name=f'ot2-{channel_2_name}') + log.debug( + "Updated off-topic channel names to" + f" {channel_0_name}, {channel_1_name} and {channel_2_name}" + ) + + +class OffTopicNames(Cog): + """Commands related to managing the off-topic category channel names.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.updater_task = None + + self.bot.loop.create_task(self.init_offtopic_updater()) + + def cog_unload(self) -> None: + """Cancel any running updater tasks on cog unload.""" + if self.updater_task is not None: + self.updater_task.cancel() + + async def init_offtopic_updater(self) -> None: + """Start off-topic channel updating event loop if it hasn't already started.""" + await self.bot.wait_until_guild_available() + if self.updater_task is None: + coro = update_names(self.bot) + self.updater_task = self.bot.loop.create_task(coro) + + @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def otname_group(self, ctx: Context) -> None: + """Add or list items from the off-topic channel name rotation.""" + await ctx.send_help(ctx.command) + + @otname_group.command(name='add', aliases=('a',)) + @with_role(*MODERATION_ROLES) + async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """ + Adds a new off-topic name to the rotation. + + The name is not added if it is too similar to an existing name. + """ + existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') + close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) + + if close_match: + match = close_match[0] + log.info( + f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" + ) + await ctx.send( + f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " + "Use `!otn forceadd` to override this check." + ) + else: + await self._add_name(ctx, name) + + @otname_group.command(name='forceadd', aliases=('fa',)) + @with_role(*MODERATION_ROLES) + async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Forcefully adds a new off-topic name to the rotation.""" + await self._add_name(ctx, name) + + async def _add_name(self, ctx: Context, name: str) -> None: + """Adds an off-topic channel name to the site storage.""" + await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) + + log.info(f"{ctx.author} added the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Added `{name}` to the names list.") + + @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Removes a off-topic name from the rotation.""" + await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') + + log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Removed `{name}` from the names list.") + + @otname_group.command(name='list', aliases=('l',)) + @with_role(*MODERATION_ROLES) + async def list_command(self, ctx: Context) -> None: + """ + Lists all currently known off-topic channel names in a paginator. + + Restricted to Moderator and above to not spoil the surprise. + """ + result = await self.bot.api_client.get('bot/off-topic-channel-names') + lines = sorted(f"• {name}" for name in result) + embed = Embed( + title=f"Known off-topic names (`{len(result)}` total)", + colour=Colour.blue() + ) + if result: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + + @otname_group.command(name='search', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: + """Search for an off-topic name.""" + result = await self.bot.api_client.get('bot/off-topic-channel-names') + in_matches = {name for name in result if query in name} + close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) + lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) + embed = Embed( + title="Query results", + colour=Colour.blue() + ) + + if lines: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Nothing found." + await ctx.send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the OffTopicNames cog.""" + bot.add_cog(OffTopicNames(bot)) diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py new file mode 100644 index 000000000..0d8f340b4 --- /dev/null +++ b/bot/exts/moderation/dm_relay.py @@ -0,0 +1,124 @@ +import logging +from typing import Optional + +import discord +from discord import Color +from discord.ext import commands +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.converters import UserMentionOrID +from bot.utils import RedisCache +from bot.utils.checks import in_whitelist_check, with_role_check +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DMRelay(Cog): + """Relay direct messages to and from the bot.""" + + # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] + dm_cache = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.dm_log + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + @commands.command(aliases=("reply",)) + async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: + """ + Allows you to send a DM to a user from the bot. + + If `member` is not provided, it will send to the last user who DM'd the bot. + + This feature should be used extremely sparingly. Use ModMail if you need to have a serious + conversation with a user. This is just for responding to extraordinary DMs, having a little + fun with users, and telling people they are DMing the wrong bot. + + NOTE: This feature will be removed if it is overused. + """ + if not member: + user_id = await self.dm_cache.get("last_user") + member = ctx.guild.get_member(user_id) if user_id else None + + # If we still don't have a Member at this point, give up + if not member: + log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") + await ctx.message.add_reaction("❌") + return + + try: + await member.send(message) + except discord.errors.Forbidden: + log.debug("User has disabled DMs.") + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("✅") + self.bot.stats.incr("dm_relay.dm_sent") + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Relays the message's content and attachments to the dm_log channel.""" + # Only relay DMs from humans + if message.author.bot or message.guild or self.webhook is None: + return + + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + await self.dm_cache.set("last_user", message.author.id) + self.bot.stats.incr("dm_relay.dm_received") + + # Handle any attachments + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (discord.errors.Forbidden, discord.errors.NotFound): + e = discord.Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + def cog_check(self, ctx: commands.Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=[constants.Channels.dm_log], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + +def setup(bot: Bot) -> None: + """Load the DMRelay cog.""" + bot.add_cog(DMRelay(bot)) diff --git a/bot/exts/off_topic_names.py b/bot/exts/off_topic_names.py deleted file mode 100644 index ce95450e0..000000000 --- a/bot/exts/off_topic_names.py +++ /dev/null @@ -1,162 +0,0 @@ -import asyncio -import difflib -import logging -from datetime import datetime, timedelta - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES -from bot.converters import OffTopicName -from bot.decorators import with_role -from bot.pagination import LinePaginator - -CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) -log = logging.getLogger(__name__) - - -async def update_names(bot: Bot) -> None: - """Background updater task that performs the daily channel name update.""" - while True: - # Since we truncate the compute timedelta to seconds, we add one second to ensure - # we go past midnight in the `seconds_to_sleep` set below. - today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) - next_midnight = today_at_midnight + timedelta(days=1) - seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 - await asyncio.sleep(seconds_to_sleep) - - try: - channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( - 'bot/off-topic-channel-names', params={'random_items': 3} - ) - except ResponseCodeError as e: - log.error(f"Failed to get new off topic channel names: code {e.response.status}") - continue - channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) - - await channel_0.edit(name=f'ot0-{channel_0_name}') - await channel_1.edit(name=f'ot1-{channel_1_name}') - await channel_2.edit(name=f'ot2-{channel_2_name}') - log.debug( - "Updated off-topic channel names to" - f" {channel_0_name}, {channel_1_name} and {channel_2_name}" - ) - - -class OffTopicNames(Cog): - """Commands related to managing the off-topic category channel names.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.updater_task = None - - self.bot.loop.create_task(self.init_offtopic_updater()) - - def cog_unload(self) -> None: - """Cancel any running updater tasks on cog unload.""" - if self.updater_task is not None: - self.updater_task.cancel() - - async def init_offtopic_updater(self) -> None: - """Start off-topic channel updating event loop if it hasn't already started.""" - await self.bot.wait_until_guild_available() - if self.updater_task is None: - coro = update_names(self.bot) - self.updater_task = self.bot.loop.create_task(coro) - - @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def otname_group(self, ctx: Context) -> None: - """Add or list items from the off-topic channel name rotation.""" - await ctx.send_help(ctx.command) - - @otname_group.command(name='add', aliases=('a',)) - @with_role(*MODERATION_ROLES) - async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """ - Adds a new off-topic name to the rotation. - - The name is not added if it is too similar to an existing name. - """ - existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') - close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) - - if close_match: - match = close_match[0] - log.info( - f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" - ) - await ctx.send( - f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " - "Use `!otn forceadd` to override this check." - ) - else: - await self._add_name(ctx, name) - - @otname_group.command(name='forceadd', aliases=('fa',)) - @with_role(*MODERATION_ROLES) - async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Forcefully adds a new off-topic name to the rotation.""" - await self._add_name(ctx, name) - - async def _add_name(self, ctx: Context, name: str) -> None: - """Adds an off-topic channel name to the site storage.""" - await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) - - log.info(f"{ctx.author} added the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Added `{name}` to the names list.") - - @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Removes a off-topic name from the rotation.""" - await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') - - log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Removed `{name}` from the names list.") - - @otname_group.command(name='list', aliases=('l',)) - @with_role(*MODERATION_ROLES) - async def list_command(self, ctx: Context) -> None: - """ - Lists all currently known off-topic channel names in a paginator. - - Restricted to Moderator and above to not spoil the surprise. - """ - result = await self.bot.api_client.get('bot/off-topic-channel-names') - lines = sorted(f"• {name}" for name in result) - embed = Embed( - title=f"Known off-topic names (`{len(result)}` total)", - colour=Colour.blue() - ) - if result: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - - @otname_group.command(name='search', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: - """Search for an off-topic name.""" - result = await self.bot.api_client.get('bot/off-topic-channel-names') - in_matches = {name for name in result if query in name} - close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) - lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) - embed = Embed( - title="Query results", - colour=Colour.blue() - ) - - if lines: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Nothing found." - await ctx.send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the OffTopicNames cog.""" - bot.add_cog(OffTopicNames(bot)) diff --git a/tests/bot/exts/fun/__init__.py b/tests/bot/exts/fun/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/fun/test_duck_pond.py b/tests/bot/exts/fun/test_duck_pond.py new file mode 100644 index 000000000..704b08066 --- /dev/null +++ b/tests/bot/exts/fun/test_duck_pond.py @@ -0,0 +1,548 @@ +import asyncio +import logging +import typing +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import discord + +from bot import constants +from bot.exts.fun import duck_pond +from tests import base +from tests import helpers + +MODULE_PATH = "bot.exts.fun.duck_pond" + + +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) + + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) + + def setUp(self): + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) + self.cog = duck_pond.DuckPond(bot=self.bot) + + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger(MODULE_PATH) + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) + ) + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) + ) + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), + 3 + 4 + ), + ) + + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + author = MagicMock( + display_name="x", + avatar_url="https://" + ) + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), + (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), + ) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger(MODULE_PATH) + + for side_effect in side_effects: # pragma: no cover + send_attachments.side_effect = side_effect + with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger(MODULE_PATH) + + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message(message) + + send_webhook.assert_called_once_with( + webhook=self.cog.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), + ) + + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + relay_message.assert_called_once_with(message) + + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) + + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + for duck_count, should_re_add_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) + + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) + + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) + + if should_re_add_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" + ) + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) + + channel.fetch_message.assert_not_called() + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = helpers.MockBot() + duck_pond.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/test_duck_pond.py b/tests/bot/exts/test_duck_pond.py deleted file mode 100644 index f6d977482..000000000 --- a/tests/bot/exts/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.exts import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.exts.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger('bot.exts.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.exts.duck_pond") - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.exts.duck_pond") - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() -- cgit v1.2.3 From c22561d2f527666def2e201e655f5ac767d95212 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sun, 20 Sep 2020 22:08:16 +0300 Subject: Try to fix location from where post infraction test get ID --- tests/bot/cogs/moderation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index c9a4e4040..02a18bbca 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -306,7 +306,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): """Should return response from POST request if there are no errors.""" now = datetime.now() payload = { - "actor": self.ctx.message.author.id, + "actor": self.ctx.author.id, "hidden": True, "reason": "Test reason", "type": "ban", @@ -344,7 +344,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" payload = { - "actor": self.ctx.message.author.id, + "actor": self.ctx.author.id, "hidden": False, "reason": "Test reason", "type": "mute", -- cgit v1.2.3 From a8b1c72d379d187b6266b6b38f9e85e594f39b11 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sun, 20 Sep 2020 22:22:37 +0300 Subject: Apply recent changes of notify infraction to test --- tests/bot/cogs/moderation/test_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/cogs/moderation/test_utils.py index 02a18bbca..5f649e136 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/cogs/moderation/test_utils.py @@ -1,4 +1,3 @@ -import textwrap import unittest from collections import namedtuple from datetime import datetime @@ -211,8 +210,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( type="Mute", expires="N/A", - reason=textwrap.shorten("foo bar" * 4000, 1000, placeholder="...") - ), + reason="foo bar" * 4000 + )[:2045] + "...", colour=Colours.soft_red, url=utils.RULES_URL ).set_author( -- cgit v1.2.3 From e68fad590415479f7b53545bf942d9f3b25ad1d3 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Mon, 21 Sep 2020 20:41:47 +0300 Subject: Fix end of file of mod utils tests --- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 674993862..5f649e136 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -356,4 +356,4 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertEqual(actual, "foo") self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) - post_user_mock.assert_awaited_once_with(self.ctx, self.user) \ No newline at end of file + post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From cebee6c45f54fab1ab965cc0c764d5f478fc4cdd Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Mon, 21 Sep 2020 20:52:02 +0300 Subject: Fix import path of mod utils --- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5f649e136..412f4398e 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, call, patch from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError -from bot.cogs.moderation import utils +from bot.exts.moderation.infraction import _utils as utils from bot.constants import Colours, Icons from tests.helpers import MockBot, MockContext, MockMember, MockUser -- cgit v1.2.3 From 82b6af9ef458cea71704c2aff0d2ef10e8b623be Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Mon, 21 Sep 2020 20:59:00 +0300 Subject: Fix import order of mod utils tests --- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 412f4398e..fbbe112de 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -6,8 +6,8 @@ from unittest.mock import AsyncMock, MagicMock, call, patch from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError -from bot.exts.moderation.infraction import _utils as utils from bot.constants import Colours, Icons +from bot.exts.moderation.infraction import _utils as utils from tests.helpers import MockBot, MockContext, MockMember, MockUser -- cgit v1.2.3 From 5c36823f8b47589cd5ed6e7bcde72ac6977921eb Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Mon, 21 Sep 2020 21:45:27 +0300 Subject: Fix mod utils tests patch locations --- tests/bot/exts/moderation/infraction/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index fbbe112de..5b62463e0 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -123,7 +123,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): else: self.ctx.send.assert_not_awaited() - @patch("bot.cogs.moderation.utils.send_private_embed") + @patch("bot.exts.moderation.infraction._utils.send_private_embed") async def test_notify_infraction(self, send_private_embed_mock): """ Should send an embed of a certain format as a DM and return `True` if DM successful. @@ -238,7 +238,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.assert_awaited_once_with(case["args"][0], embed) - @patch("bot.cogs.moderation.utils.send_private_embed") + @patch("bot.exts.moderation.infraction._utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): """Should send an embed of a certain format as a DM and return `True` if DM successful.""" test_case = namedtuple("test_case", ["args", "icon", "send_result"]) @@ -330,7 +330,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.assertTrue("500" in self.ctx.send.call_args[0][0]) - @patch("bot.cogs.moderation.utils.post_user", return_value=None) + @patch("bot.exts.moderation.infraction._utils.post_user", return_value=None) async def test_user_not_found_none_post_infraction(self, post_user_mock): """Should abort and return `None` when a new user fails to be posted.""" self.bot.api_client.post.side_effect = ResponseCodeError(MagicMock(status=400), {"user": "foo"}) @@ -339,7 +339,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.assertIsNone(actual) post_user_mock.assert_awaited_once_with(self.ctx, self.user) - @patch("bot.cogs.moderation.utils.post_user", return_value="bar") + @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar") async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" payload = { -- cgit v1.2.3