From 4890cc5ba43ad73229ce4d2fe240acaf39194edb Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Tue, 14 Apr 2020 08:30:18 +0300 Subject: Created tests for `bot.cogs.logging` connected message. --- tests/bot/cogs/test_logging.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/bot/cogs/test_logging.py (limited to 'tests') diff --git a/tests/bot/cogs/test_logging.py b/tests/bot/cogs/test_logging.py new file mode 100644 index 000000000..ba98a5a56 --- /dev/null +++ b/tests/bot/cogs/test_logging.py @@ -0,0 +1,42 @@ +import unittest +from unittest.mock import patch + +from bot import constants +from bot.cogs.logging import Logging +from tests.helpers import MockBot, MockTextChannel + + +class LoggingTests(unittest.IsolatedAsyncioTestCase): + """Test cases for connected login.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Logging(self.bot) + self.dev_log = MockTextChannel(id=1234, name="dev-log") + + @patch("bot.cogs.logging.DEBUG_MODE", False) + async def test_debug_mode_false(self): + """Should send connected message to dev-log.""" + self.bot.get_channel.return_value = self.dev_log + + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) + + embed = self.dev_log.send.call_args[1]['embed'] + self.dev_log.send.assert_awaited_once_with(embed=embed) + + self.assertEqual(embed.description, "Connected!") + self.assertEqual(embed.author.name, "Python Bot") + self.assertEqual(embed.author.url, "https://github.com/python-discord/bot") + self.assertEqual( + embed.author.icon_url, + "https://raw.githubusercontent.com/python-discord/branding/master/logos/logo_circle/logo_circle_large.png" + ) + + @patch("bot.cogs.logging.DEBUG_MODE", True) + async def test_debug_mode_true(self): + """Should not send anything to dev-log.""" + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_not_called() -- cgit v1.2.3 From 5b11b248b945cd2a732c6d8d430d117fc062cc8d Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Thu, 7 May 2020 16:46:32 +0200 Subject: Remove tests from moved function. --- tests/bot/cogs/test_snekbox.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..d32d80ead 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -1,5 +1,4 @@ import asyncio -import logging import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch @@ -53,20 +52,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): raise_for_status=True ) - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - self.bot.http_session.post().__aenter__.return_value = resp - - log = logging.getLogger("bot.cogs.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - def test_prepare_input(self): cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), -- cgit v1.2.3 From 14c670dfa87e142e24c027e2976fa02b07c4d7ac Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Thu, 7 May 2020 17:11:56 +0200 Subject: Adjust behaviour for new func usage. --- tests/bot/cogs/test_snekbox.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index d32d80ead..f4c13fc43 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -35,21 +35,12 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) self.assertEqual(result, "too long to upload") - async def test_upload_output(self): + @patch("bot.cogs.snekbox.send_to_paste_service") + async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - self.bot.http_session.post().__aenter__.return_value = resp - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True + await self.cog.upload_output("Test output.") + mock_paste_util.assert_called_once_with( + self.bot.http_session, "Test output.", extension="txt" ) def test_prepare_input(self): -- cgit v1.2.3 From 5d96e96a2e8982ec57c1a19d1a085ceccd35a6d7 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Fri, 8 May 2020 01:38:14 +0200 Subject: Add tests for `send_to_paste_service`. --- tests/bot/utils/test_init.py | 74 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/bot/utils/test_init.py (limited to 'tests') diff --git a/tests/bot/utils/test_init.py b/tests/bot/utils/test_init.py new file mode 100644 index 000000000..f3a8f5939 --- /dev/null +++ b/tests/bot/utils/test_init.py @@ -0,0 +1,74 @@ +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +from aiohttp import ClientConnectorError + +from bot.utils import FAILED_REQUEST_ATTEMPTS, send_to_paste_service + + +class PasteTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.http_session = MagicMock() + + @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") + async def test_url_and_sent_contents(self): + """Correct url was used and post was called with expected data.""" + response = MagicMock( + json=AsyncMock(return_value={"key": ""}) + ) + self.http_session.post().__aenter__.return_value = response + self.http_session.post.reset_mock() + await send_to_paste_service(self.http_session, "Content") + self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + + @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") + async def test_paste_returns_correct_url_on_success(self): + """Url with specified extension is returned on successful requests.""" + key = "paste_key" + test_cases = ( + (f"https://paste_service.com/{key}.txt", "txt"), + (f"https://paste_service.com/{key}.py", "py"), + (f"https://paste_service.com/{key}", ""), + ) + response = MagicMock( + json=AsyncMock(return_value={"key": key}) + ) + self.http_session.post().__aenter__.return_value = response + + for expected_output, extension in test_cases: + with self.subTest(msg=f"Send contents with extension {repr(extension)}"): + self.assertEqual( + await send_to_paste_service(self.http_session, "", extension=extension), + expected_output + ) + + async def test_request_repeated_on_json_errors(self): + """Json with error message and invalid json are handled as errors and requests repeated.""" + test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) + self.http_session.post().__aenter__.return_value = response = MagicMock() + self.http_session.post.reset_mock() + + for error_json in test_cases: + with self.subTest(error_json=error_json): + response.json = AsyncMock(return_value=error_json) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + self.http_session.post.reset_mock() + + async def test_request_repeated_on_connection_errors(self): + """Requests are repeated in the case of connection errors.""" + self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + async def test_general_error_handled_and_request_repeated(self): + """All `Exception`s are handled, logged and request repeated.""" + self.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertLogs("bot.utils", logging.ERROR) + self.assertIsNone(result) -- cgit v1.2.3 From d193a93828582965eb361dc6f3185291fff649a7 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:11:39 -0700 Subject: Test on_message_edit of token remover uses on_message --- tests/bot/cogs/test_token_remover.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 33d1ec170..e7b5a9bea 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,6 +1,7 @@ import asyncio import logging import unittest +from unittest import mock from unittest.mock import AsyncMock, MagicMock from discord import Colour @@ -14,7 +15,7 @@ from bot.constants import Channels, Colours, Event, Icons from tests.helpers import MockBot, MockMessage -class TokenRemoverTests(unittest.TestCase): +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): """Tests the `TokenRemover` cog.""" def setUp(self): @@ -58,6 +59,13 @@ class TokenRemoverTests(unittest.TestCase): self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) self.bot.get_cog.assert_called_once_with('ModLog') + async def test_on_message_edit_uses_on_message(self): + """The edit listener should delegate handling of the message to the normal listener.""" + self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + + await self.cog.on_message_edit(MockMessage(), self.msg) + self.cog.on_message.assert_awaited_once_with(self.msg) + def test_ignores_bot_messages(self): """When the message event handler is called with a bot message, nothing is done.""" self.msg.author.bot = True @@ -77,7 +85,7 @@ class TokenRemoverTests(unittest.TestCase): for content in ('foo.bar.baz', 'x.y.'): with self.subTest(content=content): self.msg.content = content - coroutine = self.cog.on_message(self.msg) + coroutine = self.cog.is_maybe_token(self.msg) self.assertIsNone(asyncio.run(coroutine)) def test_censors_valid_tokens(self): -- cgit v1.2.3 From 0bfd003dbfc5919220129f984dc043421e535f8c Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:38:12 -0700 Subject: Add a test helper function to patch multiple attributes with autospecs This helper reduces redundancy/boilerplate by setting default values. It also has the consequence of shortening the length of the invocation, which makes it faster to use and easier to read. --- tests/helpers.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 2b79a6c2a..d444cc49d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -23,6 +23,15 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) +def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: + """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" + # Caller's kwargs should take priority and overwrite the defaults. + kwargs = {'spec_set': True, 'autospec': True, **kwargs} + attributes = {attribute: unittest.mock.DEFAULT for attribute in attributes} + + return unittest.mock.patch.multiple(target, **attributes, **kwargs) + + class HashableMixin(discord.mixins.EqualityComparable): """ Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. -- cgit v1.2.3 From e8bd69a6c556d78eca1a1eb2adfa26248273a1cd Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:42:07 -0700 Subject: Test token remover takes action if a token is found --- tests/bot/cogs/test_token_remover.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index e7b5a9bea..e0ec67684 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -12,7 +12,7 @@ from bot.cogs.token_remover import ( setup as setup_cog, ) from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import MockBot, MockMessage +from tests.helpers import MockBot, MockMessage, autospec class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @@ -66,6 +66,18 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): await self.cog.on_message_edit(MockMessage(), self.msg) self.cog.on_message.assert_awaited_once_with(self.msg) + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_takes_action(self, find_token_in_message, take_action): + """Should take action if a valid token is found when a message is sent.""" + cog = TokenRemover(self.bot) + found_token = "foobar" + find_token_in_message.return_value = found_token + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_awaited_once_with(cog, self.msg, found_token) + def test_ignores_bot_messages(self): """When the message event handler is called with a bot message, nothing is done.""" self.msg.author.bot = True -- cgit v1.2.3 From 4cf7996a1d4630ccb05f57569ca62b1798dc7a93 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:44:54 -0700 Subject: Test token remover skips messages without tokens --- tests/bot/cogs/test_token_remover.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index e0ec67684..2b377e221 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -78,6 +78,17 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): find_token_in_message.assert_called_once_with(self.msg) take_action.assert_awaited_once_with(cog, self.msg, found_token) + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): + """Shouldn't take action if a valid token isn't found when a message is sent.""" + cog = TokenRemover(self.bot) + find_token_in_message.return_value = False + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_not_awaited() + def test_ignores_bot_messages(self): """When the message event handler is called with a bot message, nothing is done.""" self.msg.author.bot = True -- cgit v1.2.3 From 593e09299c6e4115d41bfd5b074785a5e304a8d0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 15:41:14 -0700 Subject: Allow using arbitrary parameter names with the autospec decorator This gives the caller more flexibility. Sometimes attribute names are too long or they don't follow a naming scheme accepted by the linter. --- tests/helpers.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index d444cc49d..1ab8b455f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -24,12 +24,25 @@ for logger in logging.Logger.manager.loggerDict.values(): def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: - """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" + """ + Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + + To allow for arbitrary parameter names to be used by the decorated function, the patchers have + no attribute names associated with them. As a consequence, it will not be possible to retrieve + mocks by their attribute names when using this as a context manager, + """ # Caller's kwargs should take priority and overwrite the defaults. kwargs = {'spec_set': True, 'autospec': True, **kwargs} attributes = {attribute: unittest.mock.DEFAULT for attribute in attributes} - return unittest.mock.patch.multiple(target, **attributes, **kwargs) + patcher = unittest.mock.patch.multiple(target, **attributes, **kwargs) + + # Unset attribute names to allow arbitrary parameter names for the decorator function. + patcher.attribute_name = None + for additional_patcher in patcher.additional_patchers: + additional_patcher.attribute_name = None + + return patcher class HashableMixin(discord.mixins.EqualityComparable): -- cgit v1.2.3 From b0dd290710799c342240d066abaebbe9e6940b54 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 15:09:22 -0700 Subject: Fix test for token remover ignoring bot messages It's not possible to test this via asserting the return value of `on_message` since it never returns anything. Instead, the actual relevant unit, `find_token_in_message,` should be tested. --- tests/bot/cogs/test_token_remover.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 2b377e221..e8b641101 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -89,11 +89,16 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): find_token_in_message.assert_called_once_with(self.msg) take_action.assert_not_awaited() - def test_ignores_bot_messages(self): - """When the message event handler is called with a bot message, nothing is done.""" + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_ignores_bot_messages(self, token_re): + """The token finder should ignore messages authored by bots.""" + cog = TokenRemover(self.bot) self.msg.author.bot = True - coroutine = self.cog.on_message(self.msg) - self.assertIsNone(asyncio.run(coroutine)) + + return_value = cog.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.findall.assert_not_called() def test_ignores_messages_without_tokens(self): """Messages without anything looking like a token are ignored.""" -- cgit v1.2.3 From 52f0f8a29d7f239c961beaa81881bf4b09da4749 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 15:53:06 -0700 Subject: Test `find_token_in_message` returns None if no matches found --- tests/bot/cogs/test_token_remover.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index e8b641101..5932cf4f0 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -100,6 +100,20 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(return_value) token_re.findall.assert_not_called() + @autospec(TokenRemover, "is_maybe_token") + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): + """None should be returned if the regex matches no tokens in a message.""" + cog = TokenRemover(self.bot) + token_re.findall.return_value = () + self.msg.content = "foobar" + + return_value = cog.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.findall.assert_called_once_with(self.msg.content) + is_maybe_token.assert_not_called() + def test_ignores_messages_without_tokens(self): """Messages without anything looking like a token are ignored.""" for content in ('', 'lemon wins'): -- cgit v1.2.3 From cf658bd58559b2683527443f2908257f197ef0bb Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 16:06:47 -0700 Subject: Test `find_token_in_message` returns the found token --- tests/bot/cogs/test_token_remover.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5932cf4f0..2b946778b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -114,6 +114,30 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): token_re.findall.assert_called_once_with(self.msg.content) is_maybe_token.assert_not_called() + @autospec(TokenRemover, "is_maybe_token") + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_returns_found_token(self, token_re, is_maybe_token): + """The found token should be returned.""" + true_index = 1 + matches = ("foo", "bar", "baz") + side_effects = [False] * len(matches) + side_effects[true_index] = True + + cog = TokenRemover(self.bot) + self.msg.content = "foobar" + token_re.findall.return_value = matches + is_maybe_token.side_effect = side_effects + + return_value = cog.find_token_in_message(self.msg) + + self.assertEqual(return_value, matches[true_index]) + token_re.findall.assert_called_once_with(self.msg.content) + + # assert_has_calls isn't used cause it'd allow for extra calls before or after. + # The function should short-circuit, so nothing past true_index should have been used. + calls = [mock.call(match) for match in matches[:true_index + 1]] + self.assertEqual(is_maybe_token.mock_calls, calls) + def test_ignores_messages_without_tokens(self): """Messages without anything looking like a token are ignored.""" for content in ('', 'lemon wins'): -- cgit v1.2.3 From f92bc80d6bddb5c57c190187adaa528ae44536f6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 16:25:14 -0700 Subject: Test token regex doesn't match invalid tokens --- tests/bot/cogs/test_token_remover.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 2b946778b..b67602eb9 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -8,6 +8,7 @@ from discord import Colour from bot.cogs.token_remover import ( DELETION_MESSAGE_TEMPLATE, + TOKEN_RE, TokenRemover, setup as setup_cog, ) @@ -138,13 +139,30 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): calls = [mock.call(match) for match in matches[:true_index + 1]] self.assertEqual(is_maybe_token.mock_calls, calls) - def test_ignores_messages_without_tokens(self): - """Messages without anything looking like a token are ignored.""" - for content in ('', 'lemon wins'): - with self.subTest(content=content): - self.msg.content = content - coroutine = self.cog.on_message(self.msg) - self.assertIsNone(asyncio.run(coroutine)) + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "'.'.'", + '"."."', + "(.(.(", + ").).)" + ) + + for token in tokens: + with self.subTest(token=token): + results = TOKEN_RE.findall(token) + self.assertEquals(len(results), 0) def test_ignores_messages_with_invalid_tokens(self): """Messages with values that are invalid tokens are ignored.""" -- cgit v1.2.3 From 34b836a8eba0f006c77a7b3f48f7ab14c37d31ee Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 17:47:09 -0700 Subject: Fix autospec decorator when used with multiple attributes The original approach of messing with the `attribute_name` didn't work for reasons I won't discuss here (would require knowledge of patcher internals). The new approach doesn't use patch.multiple but mimics it by applying multiple patch decorators to the function. As a consequence, this can no longer be used as a context manager. --- tests/helpers.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 1ab8b455f..dfbe539ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -24,25 +24,21 @@ for logger in logging.Logger.manager.loggerDict.values(): def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: - """ - Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. - - To allow for arbitrary parameter names to be used by the decorated function, the patchers have - no attribute names associated with them. As a consequence, it will not be possible to retrieve - mocks by their attribute names when using this as a context manager, - """ + """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" # Caller's kwargs should take priority and overwrite the defaults. kwargs = {'spec_set': True, 'autospec': True, **kwargs} - attributes = {attribute: unittest.mock.DEFAULT for attribute in attributes} - - patcher = unittest.mock.patch.multiple(target, **attributes, **kwargs) - - # Unset attribute names to allow arbitrary parameter names for the decorator function. - patcher.attribute_name = None - for additional_patcher in patcher.additional_patchers: - additional_patcher.attribute_name = None - return patcher + # Import the target if it's a string. + # This is to support both object and string targets like patch.multiple. + if type(target) is str: + target = unittest.mock._importer(target) + + def decorator(func): + for attribute in attributes: + patcher = unittest.mock.patch.object(target, attribute, **kwargs) + func = patcher(func) + return func + return decorator class HashableMixin(discord.mixins.EqualityComparable): -- cgit v1.2.3 From 834bd543d1d301bb853e713560a7447dc75f1ab8 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 17:53:40 -0700 Subject: Test `is_maybe_token` returns False for missing parts In practice, this won't ever happen since the regex wouldn't match strings with missing parts. However, the function does check it so may as well test it. It's not necessarily bound to always use inputs from the regex either I suppose. --- tests/bot/cogs/test_token_remover.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index b67602eb9..9e1d96a37 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -164,6 +164,16 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = TOKEN_RE.findall(token) self.assertEquals(len(results), 0) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): + """False should be returned for tokens which do not have all 3 parts.""" + cog = TokenRemover(self.bot) + return_value = cog.is_maybe_token("x.y") + + self.assertFalse(return_value) + valid_user.assert_not_called() + valid_time.assert_not_called() + def test_ignores_messages_with_invalid_tokens(self): """Messages with values that are invalid tokens are ignored.""" for content in ('foo.bar.baz', 'x.y.'): -- cgit v1.2.3 From ab5d194b90a7e068c8ab7171939f471e252ee073 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 18:11:31 -0700 Subject: Test is_maybe_token --- tests/bot/cogs/test_token_remover.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 9e1d96a37..85bbbdf6b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -174,13 +174,30 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): valid_user.assert_not_called() valid_time.assert_not_called() - def test_ignores_messages_with_invalid_tokens(self): - """Messages with values that are invalid tokens are ignored.""" - for content in ('foo.bar.baz', 'x.y.'): - with self.subTest(content=content): - self.msg.content = content - coroutine = self.cog.is_maybe_token(self.msg) - self.assertIsNone(asyncio.run(coroutine)) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + def test_is_maybe_token(self, valid_user, valid_time): + """Should return True if the user ID and timestamp are valid or return False otherwise.""" + cog = TokenRemover(self.bot) + subtests = ( + (False, True, False), + (True, False, False), + (True, True, True), + ) + + for user_return, time_return, expected in subtests: + valid_user.reset_mock() + valid_time.reset_mock() + + with self.subTest(user_return=user_return, time_return=time_return, expected=expected): + valid_user.return_value = user_return + valid_time.return_value = time_return + + actual = cog.is_maybe_token("x.y.z") + self.assertIs(actual, expected) + + valid_user.assert_called_once_with("x") + if user_return: + valid_time.assert_called_once_with("y") def test_censors_valid_tokens(self): """Valid tokens are censored.""" -- cgit v1.2.3 From 4b6fde69a7e193382701dccf80a5471ea7ccea72 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 18:22:31 -0700 Subject: Test token regex matches valid tokens --- tests/bot/cogs/test_token_remover.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 85bbbdf6b..7310b4637 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -164,6 +164,27 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = TOKEN_RE.findall(token) self.assertEquals(len(results), 0) + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, the token's been invalidated. + tokens = ( + "x1.y2.z_3", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8" + ) + + for token in tokens: + with self.subTest(token=token): + results = TOKEN_RE.findall(token) + self.assertIn(token, results) + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + tokens = ["x.y.z", "a.b.c"] + message = f"garbage {tokens[0]} hello {tokens[1]} world" + + results = TOKEN_RE.findall(message) + self.assertEquals(tokens, results) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): """False should be returned for tokens which do not have all 3 parts.""" -- cgit v1.2.3 From d8d8e144adfe4c2de15dbbf4346e2eec548a9f67 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 18:28:06 -0700 Subject: Correct the return type annotation for the autospec decorator --- tests/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index dfbe539ec..3cd8a63c0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,7 +4,7 @@ import collections import itertools import logging import unittest.mock -from typing import Iterable, Optional +from typing import Callable, Iterable, Optional import discord from discord.ext.commands import Context @@ -23,7 +23,7 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: +def autospec(target, *attributes: str, **kwargs) -> Callable: """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" # Caller's kwargs should take priority and overwrite the defaults. kwargs = {'spec_set': True, 'autospec': True, **kwargs} -- cgit v1.2.3 From 5b9bf9aba686f570322cb9996dd35d3ab669a162 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:26:16 -0700 Subject: Avoid instantiating the cog when testing static/class methods --- tests/bot/cogs/test_token_remover.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 7310b4637..6a8247070 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -93,10 +93,9 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_ignores_bot_messages(self, token_re): """The token finder should ignore messages authored by bots.""" - cog = TokenRemover(self.bot) self.msg.author.bot = True - return_value = cog.find_token_in_message(self.msg) + return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) token_re.findall.assert_not_called() @@ -105,11 +104,10 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): """None should be returned if the regex matches no tokens in a message.""" - cog = TokenRemover(self.bot) token_re.findall.return_value = () self.msg.content = "foobar" - return_value = cog.find_token_in_message(self.msg) + return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) token_re.findall.assert_called_once_with(self.msg.content) @@ -124,12 +122,11 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): side_effects = [False] * len(matches) side_effects[true_index] = True - cog = TokenRemover(self.bot) self.msg.content = "foobar" token_re.findall.return_value = matches is_maybe_token.side_effect = side_effects - return_value = cog.find_token_in_message(self.msg) + return_value = TokenRemover.find_token_in_message(self.msg) self.assertEqual(return_value, matches[true_index]) token_re.findall.assert_called_once_with(self.msg.content) @@ -188,8 +185,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): """False should be returned for tokens which do not have all 3 parts.""" - cog = TokenRemover(self.bot) - return_value = cog.is_maybe_token("x.y") + return_value = TokenRemover.is_maybe_token("x.y") self.assertFalse(return_value) valid_user.assert_not_called() @@ -198,7 +194,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token(self, valid_user, valid_time): """Should return True if the user ID and timestamp are valid or return False otherwise.""" - cog = TokenRemover(self.bot) subtests = ( (False, True, False), (True, False, False), @@ -213,7 +208,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): valid_user.return_value = user_return valid_time.return_value = time_return - actual = cog.is_maybe_token("x.y.z") + actual = TokenRemover.is_maybe_token("x.y.z") self.assertIs(actual, expected) valid_user.assert_called_once_with("x") -- cgit v1.2.3 From 2127239840085ba523d411899e0b7a188530df07 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:33:05 -0700 Subject: Simplify token remover's message mock * Rely on default values for the author * Set the content to a non-empty string --- tests/bot/cogs/test_token_remover.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 6a8247070..5ca863926 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -26,14 +26,10 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.bot.get_cog.return_value.send_log_message = AsyncMock() self.cog = TokenRemover(bot=self.bot) - self.msg = MockMessage(id=555, content='') - self.msg.author.__str__ = MagicMock() - self.msg.author.__str__.return_value = 'lemon' - self.msg.author.bot = False - self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' - self.msg.author.id = 42 - self.msg.author.mention = '@lemon' + self.msg = MockMessage(id=555, content="hello world") self.msg.channel.mention = "#lemonade-stand" + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + self.msg.author.avatar_url_as.return_value = "picture-lemon.png" def test_is_valid_user_id_is_true_for_numeric_content(self): """A string decoding to numeric characters is a valid user ID.""" @@ -105,7 +101,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): """None should be returned if the regex matches no tokens in a message.""" token_re.findall.return_value = () - self.msg.content = "foobar" return_value = TokenRemover.find_token_in_message(self.msg) @@ -122,7 +117,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): side_effects = [False] * len(matches) side_effects[true_index] = True - self.msg.content = "foobar" token_re.findall.return_value = matches is_maybe_token.side_effect = side_effects -- cgit v1.2.3 From e4790b330da1605573b5d23615bfe62b481e1e04 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:37:59 -0700 Subject: Test token remover's message deletion --- tests/bot/cogs/test_token_remover.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5ca863926..d65ce2ce5 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -209,6 +209,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): if user_return: valid_time.assert_called_once_with("y") + async def test_delete_message(self): + """The message should be deleted, and a message should be sent to the same channel.""" + await TokenRemover.delete_message(self.msg) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + def test_censors_valid_tokens(self): """Valid tokens are censored.""" cases = ( -- cgit v1.2.3 From 567a5f9242912d6a3340c088c0ae1a62977a141e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:46:02 -0700 Subject: Test TokenRemover.format_log_message --- tests/bot/cogs/test_token_remover.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index d65ce2ce5..f5412e692 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -218,6 +218,22 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) ) + @autospec("bot.cogs.token_remover", "LOG_MESSAGE") + async def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + log_message.format.return_value = "Howdy" + return_value = TokenRemover.format_log_message(self.msg, "MTIz.DN9R_A.xyz") + + self.assertEqual(return_value, log_message.format.return_value) + log_message.format.assert_called_once_with( + author=self.msg.author, + author_id=self.msg.author.id, + channel=self.msg.channel.mention, + user_id="MTIz", + timestamp="DN9R_A", + hmac="xxx", + ) + def test_censors_valid_tokens(self): """Valid tokens are censored.""" cases = ( -- cgit v1.2.3 From f47cbef0b47ef11b8c1fd63076105e4cb7d73601 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:29:28 -0700 Subject: Test TokenRemover.take_action * Remove `bot.get_cog` mocks in `setUp` * Mock the logger cause it's easier to assert logs * Remove subtests * Assert helper functions were called * Create an autospec for ModLog --- tests/bot/cogs/test_token_remover.py | 73 +++++++++++++++--------------------- 1 file changed, 30 insertions(+), 43 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index f5412e692..3546e7964 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,11 +1,10 @@ -import asyncio -import logging import unittest from unittest import mock -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock from discord import Colour +from bot.cogs.moderation import ModLog from bot.cogs.token_remover import ( DELETION_MESSAGE_TEMPLATE, TOKEN_RE, @@ -22,8 +21,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Adds the cog, a bot, and a message to the instance for usage in tests.""" self.bot = MockBot() - self.bot.get_cog.return_value = MagicMock() - self.bot.get_cog.return_value.send_log_message = AsyncMock() self.cog = TokenRemover(bot=self.bot) self.msg = MockMessage(id=555, content="hello world") @@ -234,46 +231,36 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): hmac="xxx", ) - def test_censors_valid_tokens(self): - """Valid tokens are censored.""" - cases = ( - # (content, censored_token) - ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + @autospec("bot.cogs.token_remover", "log") + @autospec(TokenRemover, "delete_message", "format_log_message") + async def test_take_action(self, delete_message, format_log_message, logger, mod_log_property): + """Should delete the message and send a mod log.""" + cog = TokenRemover(self.bot) + mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) + token = "MTIz.DN9R_A.xyz" + log_msg = "testing123" + + mod_log_property.return_value = mod_log + format_log_message.return_value = log_msg + + await cog.take_action(self.msg, token) + + delete_message.assert_awaited_once_with(self.msg) + format_log_message.assert_called_once_with(self.msg, token) + logger.debug.assert_called_with(log_msg) + self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + + mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) + mod_log.send_log_message.assert_called_once_with( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=log_msg, + thumbnail=self.msg.author.avatar_url_as.return_value, + channel_id=Channels.mod_alerts ) - for content, censored_token in cases: - with self.subTest(content=content, censored_token=censored_token): - self.msg.content = content - coroutine = self.cog.on_message(self.msg) - with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm: - self.assertIsNone(asyncio.run(coroutine)) # no return value - - [line] = cm.output - log_message = ( - "Censored a seemingly valid token sent by " - "lemon (`42`) in #lemonade-stand, " - f"token was `{censored_token}`" - ) - self.assertIn(log_message, line) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') - ) - self.bot.get_cog.assert_called_with('ModLog') - self.msg.author.avatar_url_as.assert_called_once_with(static_format='png') - - mod_log = self.bot.get_cog.return_value - mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) - mod_log.send_log_message.assert_called_once_with( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=log_message, - thumbnail='picture-lemon.png', - channel_id=Channels.mod_alerts - ) - class TokenRemoverSetupTests(unittest.TestCase): """Tests setup of the `TokenRemover` cog.""" -- cgit v1.2.3 From 5734a4d84922a9497014dfeb3eba31ad3c57536f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:44:08 -0700 Subject: Refactor `TokenRemoverSetupTests` and add a more thorough test The test now ensures the cog is instantiated and that the instance is passed as an argument to `add_cog`. --- tests/bot/cogs/test_token_remover.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3546e7964..c377de7b2 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -262,11 +262,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): ) -class TokenRemoverSetupTests(unittest.TestCase): - """Tests setup of the `TokenRemover` cog.""" +class TokenRemoverExtensionTests(unittest.TestCase): + """Tests for the token_remover extension.""" - def test_setup(self): - """Setup of the extension should call add_cog.""" + @autospec("bot.cogs.token_remover", "TokenRemover") + def test_extension_setup(self, cog): + """The TokenRemover cog should be added.""" bot = MockBot() setup_cog(bot) + + cog.assert_called_once_with(bot) bot.add_cog.assert_called_once() + self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) -- cgit v1.2.3 From d0303d715d485842a2d5c906099d767d74cf8bd8 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:45:50 -0700 Subject: Replace deprecated assertion methods --- tests/bot/cogs/test_token_remover.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index c377de7b2..aecb51403 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -150,7 +150,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): results = TOKEN_RE.findall(token) - self.assertEquals(len(results), 0) + self.assertEqual(len(results), 0) def test_regex_valid_tokens(self): """Messages that look like tokens should be matched.""" @@ -171,7 +171,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): message = f"garbage {tokens[0]} hello {tokens[1]} world" results = TOKEN_RE.findall(message) - self.assertEquals(tokens, results) + self.assertEqual(tokens, results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): -- cgit v1.2.3 From 862153f2e4ab5b1408719fb2c1abc5143cfb15ce Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:47:40 -0700 Subject: Clean up token remover test imports --- tests/bot/cogs/test_token_remover.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index aecb51403..5cc8c7ad1 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -4,14 +4,10 @@ from unittest.mock import MagicMock from discord import Colour +from bot import constants +from bot.cogs import token_remover from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import ( - DELETION_MESSAGE_TEMPLATE, - TOKEN_RE, - TokenRemover, - setup as setup_cog, -) -from bot.constants import Channels, Colours, Event, Icons +from bot.cogs.token_remover import TokenRemover from tests.helpers import MockBot, MockMessage, autospec @@ -149,7 +145,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): - results = TOKEN_RE.findall(token) + results = token_remover.TOKEN_RE.findall(token) self.assertEqual(len(results), 0) def test_regex_valid_tokens(self): @@ -162,7 +158,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): - results = TOKEN_RE.findall(token) + results = token_remover.TOKEN_RE.findall(token) self.assertIn(token, results) def test_regex_matches_multiple_valid(self): @@ -170,7 +166,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): tokens = ["x.y.z", "a.b.c"] message = f"garbage {tokens[0]} hello {tokens[1]} world" - results = TOKEN_RE.findall(message) + results = token_remover.TOKEN_RE.findall(message) self.assertEqual(tokens, results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") @@ -212,7 +208,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.delete.assert_called_once_with() self.msg.channel.send.assert_called_once_with( - DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) ) @autospec("bot.cogs.token_remover", "LOG_MESSAGE") @@ -251,14 +247,14 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): logger.debug.assert_called_with(log_msg) self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) + mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) mod_log.send_log_message.assert_called_once_with( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), + icon_url=constants.Icons.token_removed, + colour=Colour(constants.Colours.soft_red), title="Token removed!", text=log_msg, thumbnail=self.msg.author.avatar_url_as.return_value, - channel_id=Channels.mod_alerts + channel_id=constants.Channels.mod_alerts ) @@ -269,7 +265,7 @@ class TokenRemoverExtensionTests(unittest.TestCase): def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() - setup_cog(bot) + token_remover.setup(bot) cog.assert_called_once_with(bot) bot.add_cog.assert_called_once() -- cgit v1.2.3 From 4701b0da36c7f42792c0af258b785076237fd661 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:56:15 -0700 Subject: Use subtests for valid ID/timestamp tests and test non-ASCII inputs --- tests/bot/cogs/test_token_remover.py | 43 +++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5cc8c7ad1..f1a56c235 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -24,24 +24,31 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) self.msg.author.avatar_url_as.return_value = "picture-lemon.png" - def test_is_valid_user_id_is_true_for_numeric_content(self): - """A string decoding to numeric characters is a valid user ID.""" - # MTIz = base64(123) - self.assertTrue(TokenRemover.is_valid_user_id('MTIz')) - - def test_is_valid_user_id_is_false_for_alphabetic_content(self): - """A string decoding to alphabetic characters is not a valid user ID.""" - # YWJj = base64(abc) - self.assertFalse(TokenRemover.is_valid_user_id('YWJj')) - - def test_is_valid_timestamp_is_true_for_valid_timestamps(self): - """A string decoding to a valid timestamp should be recognized as such.""" - self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A')) - - def test_is_valid_timestamp_is_false_for_invalid_values(self): - """A string not decoding to a valid timestamp should not be recognized as such.""" - # MTIz = base64(123) - self.assertFalse(TokenRemover.is_valid_timestamp('MTIz')) + def test_is_valid_user_id(self): + """Should correctly discern valid user IDs and ignore non-numeric and non-ASCII IDs.""" + subtests = ( + ("MTIz", True), # base64(123) + ("YWJj", False), # base64(abc) + ("λδµ", False), + ) + + for user_id, is_valid in subtests: + with self.subTest(user_id=user_id, is_valid=is_valid): + result = TokenRemover.is_valid_user_id(user_id) + self.assertIs(result, is_valid) + + def test_is_valid_timestamp(self): + """Should correctly discern valid timestamps.""" + subtests = ( + ("DN9r_A", True), + ("MTIz", False), # base64(123) + ("λδµ", False), + ) + + for timestamp, is_valid in subtests: + with self.subTest(timestamp=timestamp, is_valid=is_valid): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertIs(result, is_valid) def test_mod_log_property(self): """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -- cgit v1.2.3 From 31aff51655d3783bc70f04628f189cf3c3591028 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 13 May 2020 18:58:43 -0700 Subject: Fix a test needlessly being a coroutine --- tests/bot/cogs/test_token_remover.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index f1a56c235..8e743a715 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -219,7 +219,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): ) @autospec("bot.cogs.token_remover", "LOG_MESSAGE") - async def test_format_log_message(self, log_message): + def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" log_message.format.return_value = "Howdy" return_value = TokenRemover.format_log_message(self.msg, "MTIz.DN9R_A.xyz") -- cgit v1.2.3 From ededd1879cfb914445342b202d4c66aed23ee94b Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Fri, 22 May 2020 08:43:10 +0300 Subject: Logging Tests: Simplify `DEBUG_MODE` `False` test - Remove embed attributes checks - Replace `self.dev_log.assert_awaited_once_with` with `self.dev_log.assert_awaited_once`. --- tests/bot/cogs/test_logging.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_logging.py b/tests/bot/cogs/test_logging.py index ba98a5a56..8a18fdcd6 100644 --- a/tests/bot/cogs/test_logging.py +++ b/tests/bot/cogs/test_logging.py @@ -22,17 +22,7 @@ class LoggingTests(unittest.IsolatedAsyncioTestCase): await self.cog.startup_greeting() self.bot.wait_until_guild_available.assert_awaited_once_with() self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) - - embed = self.dev_log.send.call_args[1]['embed'] - self.dev_log.send.assert_awaited_once_with(embed=embed) - - self.assertEqual(embed.description, "Connected!") - self.assertEqual(embed.author.name, "Python Bot") - self.assertEqual(embed.author.url, "https://github.com/python-discord/bot") - self.assertEqual( - embed.author.icon_url, - "https://raw.githubusercontent.com/python-discord/branding/master/logos/logo_circle/logo_circle_large.png" - ) + self.dev_log.send.assert_awaited_once() @patch("bot.cogs.logging.DEBUG_MODE", True) async def test_debug_mode_true(self): -- cgit v1.2.3 From 47886501fb7d030f1cb91c69413058e3ffcb76bf Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 20:47:32 -0700 Subject: Test token regex won't match non-base64 characters --- tests/bot/cogs/test_token_remover.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 8e743a715..dbea5ad1b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -144,10 +144,9 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): "x..z", " . . ", "\n.\n.\n", - "'.'.'", - '"."."', - "(.(.(", - ").).)" + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", ) for token in tokens: -- cgit v1.2.3 From e76099d48b9a895c48e58c5f5489886f4191eeb6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 20:50:30 -0700 Subject: Add more valid tokens to test the regex with --- tests/bot/cogs/test_token_remover.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index dbea5ad1b..6a280f358 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -156,10 +156,12 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def test_regex_valid_tokens(self): """Messages that look like tokens should be matched.""" - # Don't worry, the token's been invalidated. + # Don't worry, these tokens have been invalidated. tokens = ( - "x1.y2.z_3", - "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8" + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", ) for token in tokens: -- cgit v1.2.3 From a8a216d0803b67a330ae092a17bea563f5012275 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 21:02:24 -0700 Subject: Fix valid token regex test It was broken due to the addition of groups. Rather than returning the full match, `findall` returns groups if any exist. The test was comparing a tuple of groups to the token string, which was of course failing. Now `fullmatch` is used cause it's simpler - just check for `None` and don't worry about iterating matches to search. --- tests/bot/cogs/test_token_remover.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 6a280f358..518bf91ca 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -166,8 +166,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): - results = token_remover.TOKEN_RE.findall(token) - self.assertIn(token, results) + results = token_remover.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") def test_regex_matches_multiple_valid(self): """Should support multiple matches in the middle of a string.""" -- cgit v1.2.3 From 19cc849d4c70bc3e792460ad712aa308fa500462 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 21:07:21 -0700 Subject: Fix multiple match text for token regex It has to account for the addition of groups. It's easiest to compare the entire string so `finditer` is used to return re.Match objects; the tuples of `findall` would be cumbersome. Also threw in a change to use `assertCountEqual` cause the order doesn't really matter. --- tests/bot/cogs/test_token_remover.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 518bf91ca..2ecfae2bd 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -174,8 +174,9 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): tokens = ["x.y.z", "a.b.c"] message = f"garbage {tokens[0]} hello {tokens[1]} world" - results = token_remover.TOKEN_RE.findall(message) - self.assertEqual(tokens, results) + results = token_remover.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual(tokens, results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): -- cgit v1.2.3 From 300f8c093edea03855d94be179c64c328ec842ac Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 21:09:04 -0700 Subject: Use real token values for testing multiple matches in regex --- tests/bot/cogs/test_token_remover.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 2ecfae2bd..971bc93fc 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -171,12 +171,13 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def test_regex_matches_multiple_valid(self): """Should support multiple matches in the middle of a string.""" - tokens = ["x.y.z", "a.b.c"] - message = f"garbage {tokens[0]} hello {tokens[1]} world" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" results = token_remover.TOKEN_RE.finditer(message) results = [match[0] for match in results] - self.assertCountEqual(tokens, results) + self.assertCountEqual((token_1, token_2), results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): -- cgit v1.2.3 From 96db6087254c957fcb8fb45aad7ffcddb46ee839 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 27 May 2020 17:08:18 -0700 Subject: Switch findall to finditer in assertions `find_token_in_message` now uses the latter so the tests should adjust accordingly. --- tests/bot/cogs/test_token_remover.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 971bc93fc..4fff3ab33 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -94,18 +94,18 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) - token_re.findall.assert_not_called() + token_re.finditer.assert_not_called() @autospec(TokenRemover, "is_maybe_token") @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): """None should be returned if the regex matches no tokens in a message.""" - token_re.findall.return_value = () + token_re.finditer.return_value = () return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) - token_re.findall.assert_called_once_with(self.msg.content) + token_re.finditer.assert_called_once_with(self.msg.content) is_maybe_token.assert_not_called() @autospec(TokenRemover, "is_maybe_token") @@ -123,7 +123,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): return_value = TokenRemover.find_token_in_message(self.msg) self.assertEqual(return_value, matches[true_index]) - token_re.findall.assert_called_once_with(self.msg.content) + token_re.finditer.assert_called_once_with(self.msg.content) # assert_has_calls isn't used cause it'd allow for extra calls before or after. # The function should short-circuit, so nothing past true_index should have been used. -- cgit v1.2.3 From f937032466a4124bacf217d1bfd0af097fc3395d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 27 May 2020 19:31:55 -0700 Subject: Adjust token remover tests to use the Token NamedTuple --- tests/bot/cogs/test_token_remover.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 4fff3ab33..65bc1ee58 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -7,7 +7,7 @@ from discord import Colour from bot import constants from bot.cogs import token_remover from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import TokenRemover +from bot.cogs.token_remover import Token, TokenRemover from tests.helpers import MockBot, MockMessage, autospec @@ -224,17 +224,19 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec("bot.cogs.token_remover", "LOG_MESSAGE") def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") log_message.format.return_value = "Howdy" - return_value = TokenRemover.format_log_message(self.msg, "MTIz.DN9R_A.xyz") + + return_value = TokenRemover.format_log_message(self.msg, token) self.assertEqual(return_value, log_message.format.return_value) log_message.format.assert_called_once_with( author=self.msg.author, author_id=self.msg.author.id, channel=self.msg.channel.mention, - user_id="MTIz", - timestamp="DN9R_A", - hmac="xxx", + user_id=token.user_id, + timestamp=token.timestamp, + hmac="x" * len(token.hmac), ) @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) @@ -244,7 +246,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): """Should delete the message and send a mod log.""" cog = TokenRemover(self.bot) mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) - token = "MTIz.DN9R_A.xyz" + token = mock.create_autospec(Token, spec_set=True, instance=True) log_msg = "testing123" mod_log_property.return_value = mod_log -- cgit v1.2.3 From 12b8f5002807144451a313180c639bf6b4925f2e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 27 May 2020 20:00:33 -0700 Subject: Add more thorough and realistic inputs for token ID and timestamp tests The tests for valid inputs and invalid inputs were split to make them more readable. --- tests/bot/cogs/test_token_remover.py | 70 ++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 65bc1ee58..ffe76865a 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -24,31 +24,65 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) self.msg.author.avatar_url_as.return_value = "picture-lemon.png" - def test_is_valid_user_id(self): - """Should correctly discern valid user IDs and ignore non-numeric and non-ASCII IDs.""" - subtests = ( - ("MTIz", True), # base64(123) - ("YWJj", False), # base64(abc) - ("λδµ", False), + def test_is_valid_user_id_valid(self): + """Should consider user IDs valid if they decode entirely to ASCII digits.""" + ids = ( + "NDcyMjY1OTQzMDYyNDEzMzMy", + "NDc1MDczNjI5Mzk5NTQ3OTA0", + "NDY3MjIzMjMwNjUwNzc3NjQx", ) - for user_id, is_valid in subtests: - with self.subTest(user_id=user_id, is_valid=is_valid): + for user_id in ids: + with self.subTest(user_id=user_id): result = TokenRemover.is_valid_user_id(user_id) - self.assertIs(result, is_valid) + self.assertTrue(result) + + def test_is_valid_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) - def test_is_valid_timestamp(self): - """Should correctly discern valid timestamps.""" - subtests = ( - ("DN9r_A", True), - ("MTIz", False), # base64(123) - ("λδµ", False), + for user_id, msg in ids: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_user_id(user_id) + self.assertFalse(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), ) - for timestamp, is_valid in subtests: - with self.subTest(timestamp=timestamp, is_valid=is_valid): + for timestamp, msg in timestamps: + with self.subTest(msg=msg): result = TokenRemover.is_valid_timestamp(timestamp) - self.assertIs(result, is_valid) + self.assertFalse(result) def test_mod_log_property(self): """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -- cgit v1.2.3 From 67472080fef5c38b21d74daa2178c3f35081b58f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 28 May 2020 19:52:41 -0700 Subject: Remove is_maybe_token tests The function was removed due to redundancy. Therefore, its tests are obsolete. --- tests/bot/cogs/test_token_remover.py | 33 --------------------------------- 1 file changed, 33 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index ffe76865a..5dd12636c 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -213,39 +213,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = [match[0] for match in results] self.assertCountEqual((token_1, token_2), results) - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): - """False should be returned for tokens which do not have all 3 parts.""" - return_value = TokenRemover.is_maybe_token("x.y") - - self.assertFalse(return_value) - valid_user.assert_not_called() - valid_time.assert_not_called() - - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - def test_is_maybe_token(self, valid_user, valid_time): - """Should return True if the user ID and timestamp are valid or return False otherwise.""" - subtests = ( - (False, True, False), - (True, False, False), - (True, True, True), - ) - - for user_return, time_return, expected in subtests: - valid_user.reset_mock() - valid_time.reset_mock() - - with self.subTest(user_return=user_return, time_return=time_return, expected=expected): - valid_user.return_value = user_return - valid_time.return_value = time_return - - actual = TokenRemover.is_maybe_token("x.y.z") - self.assertIs(actual, expected) - - valid_user.assert_called_once_with("x") - if user_return: - valid_time.assert_called_once_with("y") - async def test_delete_message(self): """The message should be deleted, and a message should be sent to the same channel.""" await TokenRemover.delete_message(self.msg) -- cgit v1.2.3 From 84cd8235863acc80b7f140309424c33180cc34ea Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 28 May 2020 20:32:48 -0700 Subject: Adjust find_token_in_message tests for the recent cog changes It now supports the changes that switched to finditer, added match groups, and added the Token NamedTuple. It also accounts for the is_maybe_token function being removed. For the sake of simplicity, call assertions on is_valid_user_id and is_valid_timestamp were not made. --- tests/bot/cogs/test_token_remover.py | 39 ++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5dd12636c..8238e235a 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,4 +1,5 @@ import unittest +from re import Match from unittest import mock from unittest.mock import MagicMock @@ -130,9 +131,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(return_value) token_re.finditer.assert_not_called() - @autospec(TokenRemover, "is_maybe_token") @autospec("bot.cogs.token_remover", "TOKEN_RE") - def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): + def test_find_token_no_matches(self, token_re): """None should be returned if the regex matches no tokens in a message.""" token_re.finditer.return_value = () @@ -140,30 +140,31 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(return_value) token_re.finditer.assert_called_once_with(self.msg.content) - is_maybe_token.assert_not_called() - @autospec(TokenRemover, "is_maybe_token") + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.token_remover", "Token") @autospec("bot.cogs.token_remover", "TOKEN_RE") - def test_find_token_returns_found_token(self, token_re, is_maybe_token): - """The found token should be returned.""" - true_index = 1 - matches = ("foo", "bar", "baz") - side_effects = [False] * len(matches) - side_effects[true_index] = True - - token_re.findall.return_value = matches - is_maybe_token.side_effect = side_effects + def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """The first match with a valid user ID and timestamp should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True return_value = TokenRemover.find_token_in_message(self.msg) - self.assertEqual(return_value, matches[true_index]) + self.assertEqual(tokens[1], return_value) token_re.finditer.assert_called_once_with(self.msg.content) - # assert_has_calls isn't used cause it'd allow for extra calls before or after. - # The function should short-circuit, so nothing past true_index should have been used. - calls = [mock.call(match) for match in matches[:true_index + 1]] - self.assertEqual(is_maybe_token.mock_calls, calls) - def test_regex_invalid_tokens(self): """Messages without anything looking like a token are not matched.""" tokens = ( -- cgit v1.2.3 From 5930a044b8347019d474a809fc86f89263574ad0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 28 May 2020 20:33:34 -0700 Subject: Test find_token_in_message returns None for invalid matches This covers the case when a token is matched, but its user ID and timestamp turn out to be invalid. --- tests/bot/cogs/test_token_remover.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 8238e235a..9b4b04ecd 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -165,6 +165,21 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(tokens[1], return_value) token_re.finditer.assert_called_once_with(self.msg.content) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.token_remover", "Token") + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """None should be returned if no matches have valid user IDs or timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + is_valid_id.return_value = False + is_valid_timestamp.return_value = False + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + def test_regex_invalid_tokens(self): """Messages without anything looking like a token are not matched.""" tokens = ( -- cgit v1.2.3 From 876b4846f612fe0011cc2e0b498b4df9e54d74cb Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 31 May 2020 19:17:07 +0200 Subject: Add support for bool values in RedisCache We're gonna need this for the help channel handling, and it seems like a reasonable type to support anyway. It requires a tiny bit of special handling, but nothing outrageous. --- bot/utils/redis_cache.py | 14 ++++++++++++-- tests/bot/utils/test_redis_cache.py | 4 +++- 2 files changed, 15 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index de80cee84..2926e7a89 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import logging +from distutils.util import strtobool from functools import partialmethod from typing import Any, Dict, ItemsView, Optional, Tuple, Union @@ -11,7 +12,7 @@ log = logging.getLogger(__name__) # Type aliases RedisKeyType = Union[str, int] -RedisValueType = Union[str, int, float] +RedisValueType = Union[str, int, float, bool] RedisKeyOrValue = Union[RedisKeyType, RedisValueType] # Prefix tuples @@ -20,6 +21,7 @@ _VALUE_PREFIXES = ( ("f|", float), ("i|", int), ("s|", str), + ("b|", bool), ) _KEY_PREFIXES = ( ("i|", int), @@ -117,7 +119,8 @@ class RedisCache: def _to_typestring(key_or_value: RedisKeyOrValue, prefixes: _PrefixTuple) -> str: """Turn a valid Redis type into a typestring.""" for prefix, _type in prefixes: - if isinstance(key_or_value, _type): + # isinstance is a bad idea here, because isintance(False, int) == True. + if type(key_or_value) is _type: return f"{prefix}{key_or_value}" raise TypeError(f"RedisCache._to_typestring only supports the following: {prefixes}.") @@ -131,6 +134,13 @@ class RedisCache: # Now we convert our unicode string back into the type it originally was. for prefix, _type in prefixes: if key_or_value.startswith(prefix): + + # For booleans, we need special handling because bool("False") is True. + if prefix == "b|": + value = key_or_value[len(prefix):] + return bool(strtobool(value)) + + # Otherwise we can just convert normally. return _type(key_or_value[len(prefix):]) raise TypeError(f"RedisCache._from_typestring only supports the following: {prefixes}.") diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 8c1a40640..62c411681 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -59,7 +59,9 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): test_cases = ( ('favorite_fruit', 'melon'), ('favorite_number', 86), - ('favorite_fraction', 86.54) + ('favorite_fraction', 86.54), + ('favorite_boolean', False), + ('other_boolean', True), ) # Test that we can get and set different types. -- cgit v1.2.3 From ebbaa6274cfc278c772593b193356aa8bf066de4 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 31 May 2020 14:17:20 -0700 Subject: Remove redis namespace collision test --- tests/bot/utils/test_redis_cache.py | 10 ---------- 1 file changed, 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 8c1a40640..e5d6e4078 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -44,16 +44,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): with self.assertRaises(RuntimeError): await bad_cache.set("test", "me_up_deadman") - def test_namespace_collision(self): - """Test that we prevent colliding namespaces.""" - bob_cache_1 = RedisCache() - bob_cache_1._set_namespace("BobRoss") - self.assertEqual(bob_cache_1._namespace, "BobRoss") - - bob_cache_2 = RedisCache() - bob_cache_2._set_namespace("BobRoss") - self.assertEqual(bob_cache_2._namespace, "BobRoss_") - async def test_set_get_item(self): """Test that users can set and get items from the RedisDict.""" test_cases = ( -- cgit v1.2.3 From 9b3ab7df5ae1ecf95705f2fab7d99fdb36eb98ea Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 2 Jun 2020 19:22:49 -0700 Subject: Token remover: remove the `delete_message` function It's redundant; there's no benefit here in abstracting two lines of code into a function. --- bot/cogs/token_remover.py | 9 ++------- tests/bot/cogs/test_token_remover.py | 19 +++++++------------ 2 files changed, 9 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 46329e207..d55e079e9 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -79,7 +79,8 @@ class TokenRemover(Cog): async def take_action(self, msg: Message, found_token: Token) -> None: """Remove the `msg` containing the `found_token` and send a mod log message.""" self.mod_log.ignore(Event.message_delete, msg.id) - await self.delete_message(msg) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) log_message = self.format_log_message(msg, found_token) log.debug(log_message) @@ -96,12 +97,6 @@ class TokenRemover(Cog): self.bot.stats.incr("tokens.removed_tokens") - @staticmethod - async def delete_message(msg: Message) -> None: - """Remove a `msg` containing a token and send an explanatory message in the same channel.""" - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - @staticmethod def format_log_message(msg: Message, token: Token) -> str: """Return the log message to send for `token` being censored in `msg`.""" diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 9b4b04ecd..a10124d2d 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -229,15 +229,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = [match[0] for match in results] self.assertCountEqual((token_1, token_2), results) - async def test_delete_message(self): - """The message should be deleted, and a message should be sent to the same channel.""" - await TokenRemover.delete_message(self.msg) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) - ) - @autospec("bot.cogs.token_remover", "LOG_MESSAGE") def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" @@ -258,8 +249,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) @autospec("bot.cogs.token_remover", "log") - @autospec(TokenRemover, "delete_message", "format_log_message") - async def test_take_action(self, delete_message, format_log_message, logger, mod_log_property): + @autospec(TokenRemover, "format_log_message") + async def test_take_action(self, format_log_message, logger, mod_log_property): """Should delete the message and send a mod log.""" cog = TokenRemover(self.bot) mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) @@ -271,7 +262,11 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): await cog.take_action(self.msg, token) - delete_message.assert_awaited_once_with(self.msg) + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + format_log_message.assert_called_once_with(self.msg, token) logger.debug.assert_called_with(log_msg) self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") -- cgit v1.2.3 From 5f5a51b1715228ac5b401ef6bed8a83491e313de Mon Sep 17 00:00:00 2001 From: Kyle Stanley Date: Thu, 4 Jun 2020 03:17:11 -0400 Subject: Improve LinePaginator to support long lines --- bot/cogs/moderation/management.py | 8 ++--- bot/pagination.py | 66 +++++++++++++++++++++++++++++++++++---- tests/bot/test_pagination.py | 41 +++++++++++++++++++----- 3 files changed, 98 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 250a24247..ad17a90b0 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -83,14 +83,14 @@ class ModManagement(commands.Cog): "actor__id": ctx.author.id, "ordering": "-inserted_at" } - infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + infractions = await self.bot.api_client.get("bot/infractions", params=params) if infractions: old_infraction = infractions[0] infraction_id = old_infraction["id"] else: await ctx.send( - f":x: Couldn't find most recent infraction; you have never given an infraction." + ":x: Couldn't find most recent infraction; you have never given an infraction." ) return else: @@ -224,7 +224,7 @@ class ModManagement(commands.Cog): ) -> None: """Send a paginated embed of infractions for the specified user.""" if not infractions: - await ctx.send(f":warning: No infractions could be found for that query.") + await ctx.send(":warning: No infractions could be found for that query.") return lines = tuple( @@ -268,12 +268,12 @@ class ModManagement(commands.Cog): User: {self.bot.get_user(user_id)} (`{user_id}`) Type: **{infraction["type"]}** Shadow: {hidden} - Reason: {infraction["reason"] or "*None*"} Created: {created} Expires: {expires} Remaining: {remaining} Actor: {actor.mention if actor else actor_id} ID: `{infraction["id"]}` + Reason: {infraction["reason"] or "*None*"} {"**===============**" if active else "==============="} """) diff --git a/bot/pagination.py b/bot/pagination.py index 90c8f849c..5c7be564d 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -37,12 +37,19 @@ class LinePaginator(Paginator): The suffix appended at the end of every page. e.g. three backticks. * max_size: `int` The maximum amount of codepoints allowed in a page. + * scale_to_size: `int` + The maximum amount of characters a single line can scale up to. * max_lines: `int` The maximum amount of lines allowed in a page. """ def __init__( - self, prefix: str = '```', suffix: str = '```', max_size: int = 2000, max_lines: int = None + self, + prefix: str = '```', + suffix: str = '```', + max_size: int = 2000, + scale_to_size: int = 2000, + max_lines: t.Optional[int] = None ) -> None: """ This function overrides the Paginator.__init__ from inside discord.ext.commands. @@ -52,6 +59,10 @@ class LinePaginator(Paginator): self.prefix = prefix self.suffix = suffix self.max_size = max_size - len(suffix) + if scale_to_size < max_size: + raise ValueError("scale_to_size must be >= max_size.") + + self.scale_to_size = scale_to_size self.max_lines = max_lines self._current_page = [prefix] self._linecount = 0 @@ -62,14 +73,26 @@ class LinePaginator(Paginator): """ Adds a line to the current page. - If the line exceeds the `self.max_size` then an exception is raised. + If the line exceeds `self.max_size`, then `self.max_size` will go up to `scale_to_size` for + a single line before creating a new page. If it is still exceeded, the excess characters + are stored and placed on the next pages until there are none remaining (by word boundary). + + Raises a RuntimeError if `self.max_size` is still exceeded after attempting to continue + onto the next page. This function overrides the `Paginator.add_line` from inside `discord.ext.commands`. It overrides in order to allow us to configure the maximum number of lines per page. """ - if len(line) > self.max_size - len(self.prefix) - 2: - raise RuntimeError('Line exceeds maximum page size %s' % (self.max_size - len(self.prefix) - 2)) + remaining_words = None + if len(line) > (max_chars := self.max_size - len(self.prefix) - 2): + if len(line) > self.scale_to_size: + line, remaining_words = self._split_remaining_words(line, max_chars) + # If line still exceeds scale_to_size, we were unable to split into a second + # page without truncating. + if len(line) > self.scale_to_size: + raise RuntimeError(f'Line exceeds maximum scale_to_size {self.scale_to_size}' + ' and could not be split.') if self.max_lines is not None: if self._linecount >= self.max_lines: @@ -87,6 +110,36 @@ class LinePaginator(Paginator): self._current_page.append('') self._count += 1 + if remaining_words: + self.add_line(remaining_words) + + def _split_remaining_words(self, line: str, max_chars: int) -> t.Tuple[str, t.Optional[str]]: + """Internal: split a line into two strings; one that fits within *max_chars* characters + (reduced_words) and another for the remaining (remaining_words), rounding down to the + nearest word. + + Return a tuple in the format (reduced_words, remaining_words). + """ + reduced_words = [] + # "(Continued)" is used on a line by itself to indicate the continuation of last page + remaining_words = ["(Continued)\n", "---------------\n"] + reduced_char_count = 0 + is_full = False + + for word in line.split(" "): + if not is_full: + if len(word) + reduced_char_count <= max_chars: + reduced_words.append(word) + reduced_char_count += len(word) + else: + is_full = True + remaining_words.append(word) + else: + remaining_words.append(word) + + return " ".join(reduced_words), " ".join(remaining_words) if len(remaining_words) > 2 \ + else None + @classmethod async def paginate( cls, @@ -97,6 +150,7 @@ class LinePaginator(Paginator): suffix: str = "", max_lines: t.Optional[int] = None, max_size: int = 500, + scale_to_size: int = 2000, empty: bool = True, restrict_to_user: User = None, timeout: int = 300, @@ -147,7 +201,7 @@ class LinePaginator(Paginator): if not lines: if exception_on_empty_embed: - log.exception(f"Pagination asked for empty lines iterable") + log.exception("Pagination asked for empty lines iterable") raise EmptyPaginatorEmbed("No lines to paginate") log.debug("No lines to add to paginator, adding '(nothing to display)' message") @@ -357,7 +411,7 @@ class ImagePaginator(Paginator): if not pages: if exception_on_empty_embed: - log.exception(f"Pagination asked for empty image list") + log.exception("Pagination asked for empty image list") raise EmptyPaginatorEmbed("No images to paginate") log.debug("No images to add to paginator, adding '(no images to display)' message") diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index 0a734b505..f2e2c27ce 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -8,17 +8,44 @@ class LinePaginatorTests(TestCase): def setUp(self): """Create a paginator for the test method.""" - self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30) - - def test_add_line_raises_on_too_long_lines(self): - """`add_line` should raise a `RuntimeError` for too long lines.""" - message = f"Line exceeds maximum page size {self.paginator.max_size - 2}" - with self.assertRaises(RuntimeError, msg=message): - self.paginator.add_line('x' * self.paginator.max_size) + self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30, + scale_to_size=50) def test_add_line_works_on_small_lines(self): """`add_line` should allow small lines to be added.""" self.paginator.add_line('x' * (self.paginator.max_size - 3)) + # Note that the page isn't added to _pages until it's full. + self.assertEqual(len(self.paginator._pages), 0) + + def test_add_line_works_on_long_lines(self): + """`add_line` should scale long lines up to `scale_to_size`.""" + self.paginator.add_line('x' * self.paginator.scale_to_size) + self.assertEqual(len(self.paginator._pages), 1) + + # Any additional lines should start a new page after `max_size` is exceeded. + self.paginator.add_line('x') + self.assertEqual(len(self.paginator._pages), 2) + + def test_add_line_continuation(self): + """When `scale_to_size` is exceeded, remaining words should be split onto the next page.""" + self.paginator.add_line('zyz ' * (self.paginator.scale_to_size//4 + 1)) + self.assertEqual(len(self.paginator._pages), 2) + + def test_add_line_no_continuation(self): + """If adding a new line to an existing page would exceed `max_size`, it should start a new + page rather than using continuation. + """ + self.paginator.add_line('z' * (self.paginator.max_size - 3)) + self.paginator.add_line('z') + self.assertEqual(len(self.paginator._pages), 1) + + def test_add_line_raises_on_very_long_words(self): + """`add_line` should raise if a single long word is added that exceeds `scale_to_size`. + + Note: truncation is also a potential option, but this should not occur from normal usage. + """ + with self.assertRaises(RuntimeError): + self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) class ImagePaginatorTests(TestCase): -- cgit v1.2.3 From be4902cbd66c2f7223608ddbfee4aa4f0e1a011a Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Sat, 6 Jun 2020 22:11:53 +0200 Subject: Test for channel not silenced message --- tests/bot/cogs/moderation/test_silence.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 3fd149f04..ab3d0742a 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -127,10 +127,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.ctx.reset_mock() async def test_unsilence_sent_correct_discord_message(self): - """Proper reply after a successful unsilence.""" - with mock.patch.object(self.cog, "_unsilence", return_value=True): - await self.cog.unsilence.callback(self.cog, self.ctx) - self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") + """Check if proper message was sent when unsilencing channel.""" + test_cases = ( + (True, f"{Emojis.check_mark} unsilenced current channel."), + (False, f"{Emojis.cross_mark} current channel was not silenced.") + ) + for _unsilence_patch_return, result_message in test_cases: + with self.subTest( + starting_silenced_state=_unsilence_patch_return, + result_message=result_message + ): + with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): + await self.cog.unsilence.callback(self.cog, self.ctx) + self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() async def test_silence_private_for_false(self): """Permissions are not set and `False` is returned in an already silenced channel.""" -- cgit v1.2.3 From 78782868040d1b2ca0b655efc4123b3d9b6bfda3 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 10:23:45 +0300 Subject: Jam Tests: Created base test layout --- tests/bot/cogs/test_jams.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/bot/cogs/test_jams.py (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py new file mode 100644 index 000000000..33dee593e --- /dev/null +++ b/tests/bot/cogs/test_jams.py @@ -0,0 +1,14 @@ +import unittest + +from bot.constants import Roles +from tests.helpers import MockBot, MockContext, MockMember, MockRole + + +class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): + """Tests for `createteam` command.""" + + def setUp(self): + self.bot = MockBot() + self.admin_role = MockRole(name="Admins", id=Roles.admins) + self.command_user = MockMember([self.admin_role]) + self.context = MockContext(bot=self.bot, author=self.command_user) -- cgit v1.2.3 From 6242fbdce8935c681fa575b1c208642fe9d2635b Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 10:38:41 +0300 Subject: Jam Tests: Created tests for case when too small amount of members given --- tests/bot/cogs/test_jams.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 33dee593e..3e71370c2 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -1,5 +1,7 @@ import unittest +from unittest.mock import patch +from bot.cogs.jams import CodeJams from bot.constants import Roles from tests.helpers import MockBot, MockContext, MockMember, MockRole @@ -11,4 +13,18 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.admin_role = MockRole(name="Admins", id=Roles.admins) self.command_user = MockMember([self.admin_role]) - self.context = MockContext(bot=self.bot, author=self.command_user) + self.ctx = MockContext(bot=self.bot, author=self.command_user) + self.cog = CodeJams(self.bot) + + @patch("bot.cogs.jams.utils") + async def test_too_small_amount_of_team_members_passed(self, utils_mock): + """Should `ctx.send` and exit early when too small amount of members.""" + for case in (1, 2): + with self.subTest(amount_of_members=case): + self.ctx.reset_mock() + utils_mock.reset_mock() + await self.cog.createteam( + self.cog, self.ctx, team_name="foo", members=(MockMember() for _ in range(case)) + ) + self.ctx.send.assert_awaited_once() + utils_mock.get.assert_not_called() -- cgit v1.2.3 From a9122b781191f93f5dd375b5c1d9e7744943b464 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 10:46:08 +0300 Subject: Jam Tests: Created tests for removing duplicate team members --- tests/bot/cogs/test_jams.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 3e71370c2..1cface1c1 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -28,3 +28,12 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): ) self.ctx.send.assert_awaited_once() utils_mock.get.assert_not_called() + + @patch("bot.cogs.jams.utils") + async def test_duplicate_members_provided(self, utils_mock): + """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" + self.ctx.reset_mock() + member = MockMember() + await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + self.ctx.send.assert_awaited_once() + utils_mock.get.assert_not_called() -- cgit v1.2.3 From ebaac5988d7ff1558595008540eab5368312d170 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 17:57:26 +0300 Subject: Jam Tests: Created test for category creating when not exist --- tests/bot/cogs/test_jams.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 1cface1c1..2153178c3 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -3,7 +3,7 @@ from unittest.mock import patch from bot.cogs.jams import CodeJams from bot.constants import Roles -from tests.helpers import MockBot, MockContext, MockMember, MockRole +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): @@ -13,7 +13,8 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.admin_role = MockRole(name="Admins", id=Roles.admins) self.command_user = MockMember([self.admin_role]) - self.ctx = MockContext(bot=self.bot, author=self.command_user) + self.guild = MockGuild([self.admin_role]) + self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) self.cog = CodeJams(self.bot) @patch("bot.cogs.jams.utils") @@ -37,3 +38,14 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) self.ctx.send.assert_awaited_once() utils_mock.get.assert_not_called() + + @patch("bot.cogs.jams.utils") + async def test_category_dont_exist(self, utils_mock): + """Should create code jam category.""" + utils_mock.get.return_value = None + await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) + self.ctx.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.ctx.guild.create_category_channel.call_args[1]["overwrites"] + + self.assertFalse(category_overwrites[self.ctx.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.ctx.guild.me].read_messages) -- cgit v1.2.3 From 14d4eda8b1e7839b286402091ac060d3c869f447 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 17:58:38 +0300 Subject: Jam Tests: Added utils.get assert to category creating test --- tests/bot/cogs/test_jams.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 2153178c3..f5f87761b 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -44,6 +44,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should create code jam category.""" utils_mock.get.return_value = None await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) + utils_mock.get.assert_called_once() self.ctx.guild.create_category_channel.assert_awaited_once() category_overwrites = self.ctx.guild.create_category_channel.call_args[1]["overwrites"] -- cgit v1.2.3 From 464c4bbb53101d4456314bf7a40243337525d514 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:03:17 +0300 Subject: Jam Tests: Created test that make sure when category exist, don't create --- tests/bot/cogs/test_jams.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index f5f87761b..1ce71a942 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -50,3 +50,11 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.assertFalse(category_overwrites[self.ctx.guild.default_role].read_messages) self.assertTrue(category_overwrites[self.ctx.guild.me].read_messages) + + @patch("bot.cogs.jams.utils") + async def test_category_channel_exist(self, utils_mock): + """Should not try to create category channel.""" + utils_mock.return_value = "foo" + await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) + utils_mock.get.assert_called_once() + self.ctx.guild.create_category_channel.assert_not_awaited() -- cgit v1.2.3 From a63545510f392cf3e36e310b68792177a178b769 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:08:29 +0300 Subject: Jam Tests: Created test for creating text channel for team --- tests/bot/cogs/test_jams.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 1ce71a942..9d26628ff 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -58,3 +58,8 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) utils_mock.get.assert_called_once() self.ctx.guild.create_category_channel.assert_not_awaited() + + async def test_team_text_channel_creation(self): + """Should create text channel for team.""" + await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) + self.ctx.guild.create_text_channel.assert_awaited_once() -- cgit v1.2.3 From 3df28c1b2a64bee3a52442fe42decaa960c45fde Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:18:09 +0300 Subject: Jam Tests: Created test for channel overwrites --- tests/bot/cogs/test_jams.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 9d26628ff..d21c5ea29 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -63,3 +63,27 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should create text channel for team.""" await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) self.ctx.guild.create_text_channel.assert_awaited_once() + + async def test_channel_overwrites(self): + """Should have correct permission overwrites for users and roles.""" + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + overwrites = self.ctx.guild.create_text_channel.call_args[1]["overwrites"] + + # Leader permission overwrites + self.assertTrue(overwrites[leader].manage_messages) + self.assertTrue(overwrites[leader].read_messages) + self.assertTrue(overwrites[leader].manage_webhooks) + self.assertTrue(overwrites[leader].connect) + + # Other members permission overwrites + for member in members[1:]: + self.assertTrue(overwrites[member].read_messages) + self.assertTrue(overwrites[member].connect) + + # Everyone and verified role overwrite + self.assertFalse(overwrites[self.ctx.guild.default_role].read_messages) + self.assertFalse(overwrites[self.ctx.guild.default_role].connect) + self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].read_messages) + self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].connect) -- cgit v1.2.3 From 6476d3ba6dfc28441d097aaa15a7c9e13f53f646 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:22:02 +0300 Subject: Jam Tests: Make text channel creation test more specific --- tests/bot/cogs/test_jams.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index d21c5ea29..94c48b995 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -59,11 +59,18 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): utils_mock.get.assert_called_once() self.ctx.guild.create_category_channel.assert_not_awaited() - async def test_team_text_channel_creation(self): + @patch("bot.cogs.jams.utils") + async def test_team_text_channel_creation(self, utils_mock): """Should create text channel for team.""" + utils_mock.get.return_value = "foo" await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) + # Make sure that we awaited function before getting call arguments self.ctx.guild.create_text_channel.assert_awaited_once() + # All other arguments is possible to get somewhere else except this + overwrites = self.ctx.guild.create_text_channel.call_args[1]["overwrites"] + self.ctx.guild.create_text_channel.assert_awaited_once_with("bar", overwrites=overwrites, category="foo") + async def test_channel_overwrites(self): """Should have correct permission overwrites for users and roles.""" leader = MockMember() -- cgit v1.2.3 From b1359f0ed37cdbbb6bae9dbbe92e3bf0db660636 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:26:41 +0300 Subject: Jam Tests: Create test for team voice channel creating --- tests/bot/cogs/test_jams.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 94c48b995..2e1419f8e 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -94,3 +94,15 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.assertFalse(overwrites[self.ctx.guild.default_role].connect) self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].read_messages) self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].connect) + + @patch("bot.cogs.jams.utils") + async def test_team_voice_channel_creation(self, utils_mock): + """Should create new voice channel for team.""" + utils_mock.get.return_value = "foo" + await self.cog.createteam(self.cog, self.ctx, "my-team", (MockMember() for _ in range(5))) + # Make sure that we awaited function before getting call arguments + self.ctx.guild.create_voice_channel.assert_awaited_once() + + # All other arguments is possible to get somewhere else except this + overwrites = self.ctx.guild.create_voice_channel.call_args[1]["overwrites"] + self.ctx.guild.create_voice_channel.assert_awaited_once_with("My Team", overwrites=overwrites, category="foo") -- cgit v1.2.3 From b5b05adc41e55dd58810608f4ac7ade6281cdf84 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:37:27 +0300 Subject: Jam Tests: Create test for team jam roles adding --- tests/bot/cogs/test_jams.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 2e1419f8e..16caa98c6 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -106,3 +106,17 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): # All other arguments is possible to get somewhere else except this overwrites = self.ctx.guild.create_voice_channel.call_args[1]["overwrites"] self.ctx.guild.create_voice_channel.assert_awaited_once_with("My Team", overwrites=overwrites, category="foo") + + async def test_jam_roles_adding(self): + """Should add team leader role to leader and jam role to every team member.""" + leader_role = MockRole(name="Team Leader") + jam_role = MockRole(name="Jammer") + self.ctx.guild.get_role.side_effect = [MockRole(), leader_role, jam_role] + + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + leader.add_roles.assert_any_await(leader_role) + for member in members: + member.add_roles.assert_any_await(jam_role) -- cgit v1.2.3 From 76ad4d141027f6351e2feedc466c8acc805f671d Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:39:13 +0300 Subject: Jam Tests: Create test for successful `ctx.send` calling --- tests/bot/cogs/test_jams.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 16caa98c6..7db66ff11 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -120,3 +120,9 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): leader.add_roles.assert_any_await(leader_role) for member in members: member.add_roles.assert_any_await(jam_role) + + async def test_result_sending(self): + """Should call `ctx.send` when everything go right.""" + self.ctx.reset_mock() + await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) + self.ctx.send.assert_awaited_once() -- cgit v1.2.3 From bbe4f137bd583d66a6bcb03102327bc6c586af86 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 23 May 2020 18:42:03 +0300 Subject: Jam Tests: Create test for `setup` function --- tests/bot/cogs/test_jams.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 7db66ff11..2c5cef835 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch -from bot.cogs.jams import CodeJams +from bot.cogs.jams import CodeJams, setup from bot.constants import Roles from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -126,3 +126,13 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.ctx.reset_mock() await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) self.ctx.send.assert_awaited_once() + + +class CodeJamSetup(unittest.TestCase): + """Test for `setup` function of `CodeJam` cog.""" + + def test_setup(self): + """Should call `bot.add_cog`.""" + bot = MockBot() + setup(bot) + bot.add_cog.assert_called_once() -- cgit v1.2.3 From 5ca860fb3b2bcb77ab8574d83e8159df471f0faf Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 08:54:07 +0300 Subject: Jam Tests: Fix `test_result_sending` docstring Co-authored-by: Mark --- tests/bot/cogs/test_jams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 2c5cef835..51720d957 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -122,7 +122,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): member.add_roles.assert_any_await(jam_role) async def test_result_sending(self): - """Should call `ctx.send` when everything go right.""" + """Should call `ctx.send` when everything goes right.""" self.ctx.reset_mock() await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) self.ctx.send.assert_awaited_once() -- cgit v1.2.3 From 28f33584b65b1f9d7e7254b4822d8896c7f19284 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 09:26:43 +0300 Subject: Jam Tests: Use class member of patch instead decorator on most of tests --- tests/bot/cogs/test_jams.py | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 51720d957..bf542458b 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -16,53 +16,52 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.guild = MockGuild([self.admin_role]) self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) self.cog = CodeJams(self.bot) + self.utils_mock = patch("bot.cogs.jams.utils").start() - @patch("bot.cogs.jams.utils") - async def test_too_small_amount_of_team_members_passed(self, utils_mock): + def tearDown(self): + self.utils_mock.stop() + + async def test_too_small_amount_of_team_members_passed(self): """Should `ctx.send` and exit early when too small amount of members.""" for case in (1, 2): with self.subTest(amount_of_members=case): self.ctx.reset_mock() - utils_mock.reset_mock() + self.utils_mock.reset_mock() await self.cog.createteam( self.cog, self.ctx, team_name="foo", members=(MockMember() for _ in range(case)) ) self.ctx.send.assert_awaited_once() - utils_mock.get.assert_not_called() + self.utils_mock.get.assert_not_called() - @patch("bot.cogs.jams.utils") - async def test_duplicate_members_provided(self, utils_mock): + async def test_duplicate_members_provided(self): """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" self.ctx.reset_mock() member = MockMember() await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) self.ctx.send.assert_awaited_once() - utils_mock.get.assert_not_called() + self.utils_mock.get.assert_not_called() - @patch("bot.cogs.jams.utils") - async def test_category_dont_exist(self, utils_mock): + async def test_category_dont_exist(self): """Should create code jam category.""" - utils_mock.get.return_value = None + self.utils_mock.get.return_value = None await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) - utils_mock.get.assert_called_once() + self.utils_mock.get.assert_called_once() self.ctx.guild.create_category_channel.assert_awaited_once() category_overwrites = self.ctx.guild.create_category_channel.call_args[1]["overwrites"] self.assertFalse(category_overwrites[self.ctx.guild.default_role].read_messages) self.assertTrue(category_overwrites[self.ctx.guild.me].read_messages) - @patch("bot.cogs.jams.utils") - async def test_category_channel_exist(self, utils_mock): + async def test_category_channel_exist(self): """Should not try to create category channel.""" - utils_mock.return_value = "foo" + self.utils_mock.return_value = "foo" await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) - utils_mock.get.assert_called_once() + self.utils_mock.get.assert_called_once() self.ctx.guild.create_category_channel.assert_not_awaited() - @patch("bot.cogs.jams.utils") - async def test_team_text_channel_creation(self, utils_mock): + async def test_team_text_channel_creation(self): """Should create text channel for team.""" - utils_mock.get.return_value = "foo" + self.utils_mock.get.return_value = "foo" await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) # Make sure that we awaited function before getting call arguments self.ctx.guild.create_text_channel.assert_awaited_once() @@ -95,10 +94,9 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].read_messages) self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].connect) - @patch("bot.cogs.jams.utils") - async def test_team_voice_channel_creation(self, utils_mock): + async def test_team_voice_channel_creation(self): """Should create new voice channel for team.""" - utils_mock.get.return_value = "foo" + self.utils_mock.get.return_value = "foo" await self.cog.createteam(self.cog, self.ctx, "my-team", (MockMember() for _ in range(5))) # Make sure that we awaited function before getting call arguments self.ctx.guild.create_voice_channel.assert_awaited_once() -- cgit v1.2.3 From 1c860606a122ff1378cb55e228312acb2bb2d49e Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:16:50 +0300 Subject: Jam Tests: Make early exiting test more secure --- tests/bot/cogs/test_jams.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index bf542458b..98fa12f66 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from bot.cogs.jams import CodeJams, setup from bot.constants import Roles @@ -25,13 +25,17 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should `ctx.send` and exit early when too small amount of members.""" for case in (1, 2): with self.subTest(amount_of_members=case): + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + self.ctx.reset_mock() self.utils_mock.reset_mock() await self.cog.createteam( self.cog, self.ctx, team_name="foo", members=(MockMember() for _ in range(case)) ) self.ctx.send.assert_awaited_once() - self.utils_mock.get.assert_not_called() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() async def test_duplicate_members_provided(self): """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" -- cgit v1.2.3 From fd05997c1aa9054024ad62dc0cbf19c1a296f4b7 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:22:52 +0300 Subject: Jam Tests: Add more assertions to result message sending test --- tests/bot/cogs/test_jams.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 98fa12f66..4307d7deb 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -126,8 +126,14 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_result_sending(self): """Should call `ctx.send` when everything goes right.""" self.ctx.reset_mock() - await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) + members = [MockMember() for _ in range(5)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) self.ctx.send.assert_awaited_once() + sent_string = self.ctx.send.call_args[0][0] + + self.assertIn(str(self.ctx.guild.create_text_channel.return_value.mention), sent_string) + self.assertIn(members[0].mention, sent_string) + self.assertIn(" ".join(member.mention for member in members[1:]), sent_string) class CodeJamSetup(unittest.TestCase): -- cgit v1.2.3 From fa4783c5e15709625e21d6a1aa766664eb2423e2 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:24:27 +0300 Subject: Jam Tests: Apply recent changes to overwrites test --- tests/bot/cogs/test_jams.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 4307d7deb..1cbff2674 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -78,8 +78,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should have correct permission overwrites for users and roles.""" leader = MockMember() members = [leader] + [MockMember() for _ in range(4)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) - overwrites = self.ctx.guild.create_text_channel.call_args[1]["overwrites"] + overwrites = self.cog.get_overwrites(members, self.ctx) # Leader permission overwrites self.assertTrue(overwrites[leader].manage_messages) -- cgit v1.2.3 From 0d2b61fd72f7b44d0534901c8f2e6ee3ccaad3f7 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:36:11 +0300 Subject: Jam Tests: Merge text and voice channel creation tests --- tests/bot/cogs/test_jams.py | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 1cbff2674..54f906ed9 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -1,9 +1,9 @@ import unittest -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from bot.cogs.jams import CodeJams, setup from bot.constants import Roles -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): @@ -63,17 +63,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.utils_mock.get.assert_called_once() self.ctx.guild.create_category_channel.assert_not_awaited() - async def test_team_text_channel_creation(self): - """Should create text channel for team.""" - self.utils_mock.get.return_value = "foo" - await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) - # Make sure that we awaited function before getting call arguments - self.ctx.guild.create_text_channel.assert_awaited_once() - - # All other arguments is possible to get somewhere else except this - overwrites = self.ctx.guild.create_text_channel.call_args[1]["overwrites"] - self.ctx.guild.create_text_channel.assert_awaited_once_with("bar", overwrites=overwrites, category="foo") - async def test_channel_overwrites(self): """Should have correct permission overwrites for users and roles.""" leader = MockMember() @@ -97,16 +86,30 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].read_messages) self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].connect) - async def test_team_voice_channel_creation(self): - """Should create new voice channel for team.""" + async def test_team_channels_creation(self): + """Should create new voice and text channel for team.""" self.utils_mock.get.return_value = "foo" - await self.cog.createteam(self.cog, self.ctx, "my-team", (MockMember() for _ in range(5))) - # Make sure that we awaited function before getting call arguments - self.ctx.guild.create_voice_channel.assert_awaited_once() + members = [MockMember() for _ in range(5)] - # All other arguments is possible to get somewhere else except this - overwrites = self.ctx.guild.create_voice_channel.call_args[1]["overwrites"] - self.ctx.guild.create_voice_channel.assert_awaited_once_with("My Team", overwrites=overwrites, category="foo") + self.cog.get_overwrites = MagicMock() + self.cog.get_category = AsyncMock() + self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") + actual = await self.cog.create_channels(self.ctx, "my-team", members) + + self.assertEqual("foobar-channel", actual) + self.cog.get_overwrites.assert_called_once_with(members, self.ctx) + self.cog.get_category.assert_awaited_once_with(self.ctx) + + self.ctx.guild.create_text_channel.assert_awaited_once_with( + "my-team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + self.ctx.guild.create_voice_channel.assert_awaited_once_with( + "My Team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) async def test_jam_roles_adding(self): """Should add team leader role to leader and jam role to every team member.""" -- cgit v1.2.3 From 4af2be7310141ab3ddc34a2184366c0d8212cdd5 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:39:46 +0300 Subject: Jam Tests: Simplify and update `test_category_channel_exist` --- tests/bot/cogs/test_jams.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 54f906ed9..ae3e35dbb 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -58,9 +58,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_channel_exist(self): """Should not try to create category channel.""" - self.utils_mock.return_value = "foo" - await self.cog.createteam(self.cog, self.ctx, "bar", (MockMember() for _ in range(5))) - self.utils_mock.get.assert_called_once() + await self.cog.get_category(self.ctx) self.ctx.guild.create_category_channel.assert_not_awaited() async def test_channel_overwrites(self): -- cgit v1.2.3 From ea91aefe55bf52fca6714897347bb24d4a4efb5b Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:41:25 +0300 Subject: Jam Tests: Apply recent changes to `test_category_dont_exist` --- tests/bot/cogs/test_jams.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index ae3e35dbb..ecd06179f 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -48,8 +48,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_dont_exist(self): """Should create code jam category.""" self.utils_mock.get.return_value = None - await self.cog.createteam(self.cog, self.ctx, "foo", (MockMember() for _ in range(5))) - self.utils_mock.get.assert_called_once() + await self.cog.get_category(self.ctx) self.ctx.guild.create_category_channel.assert_awaited_once() category_overwrites = self.ctx.guild.create_category_channel.call_args[1]["overwrites"] -- cgit v1.2.3 From 6e070a43f616f898e328bfc4581ed48551e73b12 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:47:13 +0300 Subject: Jam Tests: Implement default arguments To avoid repeating same arguments, added default arguments that is unpacked on function call. --- tests/bot/cogs/test_jams.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index ecd06179f..94be8dd03 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -17,6 +17,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) self.cog = CodeJams(self.bot) self.utils_mock = patch("bot.cogs.jams.utils").start() + self.default_args = [self.cog, self.ctx, "foo"] def tearDown(self): self.utils_mock.stop() @@ -30,9 +31,8 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.ctx.reset_mock() self.utils_mock.reset_mock() - await self.cog.createteam( - self.cog, self.ctx, team_name="foo", members=(MockMember() for _ in range(case)) - ) + await self.cog.createteam(*self.default_args, (MockMember() for _ in range(case))) + self.ctx.send.assert_awaited_once() self.cog.create_channels.assert_not_awaited() self.cog.add_roles.assert_not_awaited() @@ -41,7 +41,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" self.ctx.reset_mock() member = MockMember() - await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + await self.cog.createteam(*self.default_args, (member for _ in range(5))) self.ctx.send.assert_awaited_once() self.utils_mock.get.assert_not_called() -- cgit v1.2.3 From b129658bf260d458d5fad5925e945c78f881388a Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:49:00 +0300 Subject: Jam Tests: Remove unnecessary `Context` mock resets --- tests/bot/cogs/test_jams.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 94be8dd03..0f8ba3574 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -39,7 +39,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_duplicate_members_provided(self): """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" - self.ctx.reset_mock() member = MockMember() await self.cog.createteam(*self.default_args, (member for _ in range(5))) self.ctx.send.assert_awaited_once() @@ -124,7 +123,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_result_sending(self): """Should call `ctx.send` when everything goes right.""" - self.ctx.reset_mock() members = [MockMember() for _ in range(5)] await self.cog.createteam(self.cog, self.ctx, "foo", members) self.ctx.send.assert_awaited_once() -- cgit v1.2.3 From 0481bcc1d99dd9d7fe9d41276599437b11670b27 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 11 Jun 2020 11:50:40 +0300 Subject: Jam Tests: Apply recent command splitting to `test_jam_roles_adding` --- tests/bot/cogs/test_jams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 0f8ba3574..54fe0b5f2 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -111,11 +111,11 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should add team leader role to leader and jam role to every team member.""" leader_role = MockRole(name="Team Leader") jam_role = MockRole(name="Jammer") - self.ctx.guild.get_role.side_effect = [MockRole(), leader_role, jam_role] + self.ctx.guild.get_role.side_effect = [leader_role, jam_role] leader = MockMember() members = [leader] + [MockMember() for _ in range(4)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) + await self.cog.add_roles(self.ctx, members) leader.add_roles.assert_any_await(leader_role) for member in members: -- cgit v1.2.3 From 5c70a7dad3ee59e865df08affe7905a843a823ce Mon Sep 17 00:00:00 2001 From: kwzrd Date: Fri, 12 Jun 2020 22:05:15 +0200 Subject: Incidents tests: create new test module --- tests/bot/cogs/moderation/test_incidents.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/bot/cogs/moderation/test_incidents.py (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py new file mode 100644 index 000000000..e69de29bb -- cgit v1.2.3 From ae5028d5966ba126f902783db8ad685646f45f37 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Fri, 12 Jun 2020 23:14:41 +0200 Subject: Incidents tests: write tests for module-level helpers --- tests/bot/cogs/moderation/test_incidents.py | 135 ++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index e69de29bb..4c1f9bc07 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -0,0 +1,135 @@ +import enum +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +import discord + +from bot.cogs.moderation import incidents + + +@patch("bot.constants.Channels.incidents", 123) +class TestIsIncident(unittest.TestCase): + """ + Collection of tests for the `is_incident` helper function. + + In `setUp`, we will create a mock message which should qualify as an incident. Each + test case will then mutate this instance to make it **not** qualify, in various ways. + + Notice that we patch the #incidents channel id globally for this class. + """ + + def setUp(self) -> None: + """Prepare a mock message which should qualify as an incident.""" + self.incident = MagicMock( + discord.Message, + channel=MagicMock(discord.TextChannel, id=123), + content="this is an incident", + author=MagicMock(discord.User, bot=False), + pinned=False, + ) + + def test_is_incident_true(self): + """Message qualifies as an incident if unchanged.""" + self.assertTrue(incidents.is_incident(self.incident)) + + def check_false(self): + """Assert that `self.incident` does **not** qualify as an incident.""" + self.assertFalse(incidents.is_incident(self.incident)) + + def test_is_incident_false_channel(self): + """Message doesn't qualify if sent outside of #incidents.""" + self.incident.channel = MagicMock(discord.TextChannel, id=456) + self.check_false() + + def test_is_incident_false_content(self): + """Message doesn't qualify if content begins with hash symbol.""" + self.incident.content = "# this is a comment message" + self.check_false() + + def test_is_incident_false_author(self): + """Message doesn't qualify if author is a bot.""" + self.incident.author = MagicMock(discord.User, bot=True) + self.check_false() + + def test_is_incident_false_pinned(self): + """Message doesn't qualify if it is pinned.""" + self.incident.pinned = True + self.check_false() + + +class TestOwnReactions(unittest.TestCase): + """Assertions for the `own_reactions` function.""" + + def test_own_reactions(self): + """Only bot's own emoji are extracted from the input incident.""" + reactions = ( + MagicMock(discord.Reaction, emoji="A", me=True), + MagicMock(discord.Reaction, emoji="B", me=True), + MagicMock(discord.Reaction, emoji="C", me=False), + ) + message = MagicMock(discord.Message, reactions=reactions) + self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) + + +@patch("bot.cogs.moderation.incidents.ALLOWED_EMOJI", {"A", "B"}) +class TestHasSignals(unittest.TestCase): + """ + Assertions for the `has_signals` function. + + We patch `ALLOWED_EMOJI` globally. Each test function then patches `own_reactions` + as appropriate. + """ + + def test_has_signals_true(self): + """True when `own_reactions` returns all emoji in `ALLOWED_EMOJI`.""" + message = MagicMock(discord.Message) + own_reactions = MagicMock(return_value={"A", "B"}) + + with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): + self.assertTrue(incidents.has_signals(message)) + + def test_has_signals_false(self): + """False when `own_reactions` does not return all emoji in `ALLOWED_EMOJI`.""" + message = MagicMock(discord.Message) + own_reactions = MagicMock(return_value={"A", "C"}) + + with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): + self.assertFalse(incidents.has_signals(message)) + + +class Signal(enum.Enum): + A = "A" + B = "B" + + +@patch("bot.cogs.moderation.incidents.Signal", Signal) +class TestAddSignals(unittest.IsolatedAsyncioTestCase): + """ + Assertions for the `add_signals` coroutine. + + These are all fairly similar and could go into a single test function, but I found the + patching & sub-testing fairly awkward in that case and decided to split them up + to avoid unnecessary syntax noise. + """ + + def setUp(self): + """Prepare a mock incident message for tests to use.""" + self.incident = MagicMock(discord.Message, add_reaction=AsyncMock()) + + @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value=set())) + async def test_add_signals_missing(self): + """All emoji are added when none are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) + + @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) + async def test_add_signals_partial(self): + """Only missing emoji are added when some are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_has_calls([call("B")]) + + @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) + async def test_add_signals_present(self): + """No emoji are added when all are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_not_called() -- cgit v1.2.3 From 9dbfe7da4cbc4d1820507e25ce56929b7fb55327 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 13 Jun 2020 08:26:19 +0300 Subject: Jam Tests: Update `Context` to `Guild` for tests too --- tests/bot/cogs/test_jams.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 54fe0b5f2..17b86601f 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -47,23 +47,23 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_dont_exist(self): """Should create code jam category.""" self.utils_mock.get.return_value = None - await self.cog.get_category(self.ctx) - self.ctx.guild.create_category_channel.assert_awaited_once() - category_overwrites = self.ctx.guild.create_category_channel.call_args[1]["overwrites"] + await self.cog.get_category(self.guild) + self.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - self.assertFalse(category_overwrites[self.ctx.guild.default_role].read_messages) - self.assertTrue(category_overwrites[self.ctx.guild.me].read_messages) + self.assertFalse(category_overwrites[self.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.guild.me].read_messages) async def test_category_channel_exist(self): """Should not try to create category channel.""" - await self.cog.get_category(self.ctx) - self.ctx.guild.create_category_channel.assert_not_awaited() + await self.cog.get_category(self.guild) + self.guild.create_category_channel.assert_not_awaited() async def test_channel_overwrites(self): """Should have correct permission overwrites for users and roles.""" leader = MockMember() members = [leader] + [MockMember() for _ in range(4)] - overwrites = self.cog.get_overwrites(members, self.ctx) + overwrites = self.cog.get_overwrites(members, self.guild) # Leader permission overwrites self.assertTrue(overwrites[leader].manage_messages) @@ -77,10 +77,10 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(overwrites[member].connect) # Everyone and verified role overwrite - self.assertFalse(overwrites[self.ctx.guild.default_role].read_messages) - self.assertFalse(overwrites[self.ctx.guild.default_role].connect) - self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].read_messages) - self.assertFalse(overwrites[self.ctx.guild.get_role(Roles.verified)].connect) + self.assertFalse(overwrites[self.guild.default_role].read_messages) + self.assertFalse(overwrites[self.guild.default_role].connect) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) async def test_team_channels_creation(self): """Should create new voice and text channel for team.""" @@ -90,18 +90,18 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.cog.get_overwrites = MagicMock() self.cog.get_category = AsyncMock() self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") - actual = await self.cog.create_channels(self.ctx, "my-team", members) + actual = await self.cog.create_channels(self.guild, "my-team", members) self.assertEqual("foobar-channel", actual) - self.cog.get_overwrites.assert_called_once_with(members, self.ctx) - self.cog.get_category.assert_awaited_once_with(self.ctx) + self.cog.get_overwrites.assert_called_once_with(members, self.guild) + self.cog.get_category.assert_awaited_once_with(self.guild) - self.ctx.guild.create_text_channel.assert_awaited_once_with( + self.guild.create_text_channel.assert_awaited_once_with( "my-team", overwrites=self.cog.get_overwrites.return_value, category=self.cog.get_category.return_value ) - self.ctx.guild.create_voice_channel.assert_awaited_once_with( + self.guild.create_voice_channel.assert_awaited_once_with( "My Team", overwrites=self.cog.get_overwrites.return_value, category=self.cog.get_category.return_value @@ -111,11 +111,11 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should add team leader role to leader and jam role to every team member.""" leader_role = MockRole(name="Team Leader") jam_role = MockRole(name="Jammer") - self.ctx.guild.get_role.side_effect = [leader_role, jam_role] + self.guild.get_role.side_effect = [leader_role, jam_role] leader = MockMember() members = [leader] + [MockMember() for _ in range(4)] - await self.cog.add_roles(self.ctx, members) + await self.cog.add_roles(self.guild, members) leader.add_roles.assert_any_await(leader_role) for member in members: -- cgit v1.2.3 From 2489b144b5bf131ec8b1b42e2ae1dd249cce4d3f Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 13 Jun 2020 08:35:35 +0300 Subject: Jam Tests: Simplify and make tests more secure --- tests/bot/cogs/test_jams.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 17b86601f..2d2eebabf 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -42,7 +42,8 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): member = MockMember() await self.cog.createteam(*self.default_args, (member for _ in range(5))) self.ctx.send.assert_awaited_once() - self.utils_mock.get.assert_not_called() + self.cog.create_channels.assert_now_awaited() + self.cog.add_roles.assert_not_awaited() async def test_category_dont_exist(self): """Should create code jam category.""" @@ -125,12 +126,9 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should call `ctx.send` when everything goes right.""" members = [MockMember() for _ in range(5)] await self.cog.createteam(self.cog, self.ctx, "foo", members) + self.cog.create_channel.assert_awaited_once() + self.cog.add_roles.assert_awaited_once() self.ctx.send.assert_awaited_once() - sent_string = self.ctx.send.call_args[0][0] - - self.assertIn(str(self.ctx.guild.create_text_channel.return_value.mention), sent_string) - self.assertIn(members[0].mention, sent_string) - self.assertIn(" ".join(member.mention for member in members[1:]), sent_string) class CodeJamSetup(unittest.TestCase): -- cgit v1.2.3 From 95ae613173bb87719155a95494fe448a45a2d6bc Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 13 Jun 2020 08:39:35 +0300 Subject: Jam Tests: Fix wrong function name and convert them to mocks --- tests/bot/cogs/test_jams.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 2d2eebabf..a66658134 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -124,9 +124,11 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_result_sending(self): """Should call `ctx.send` when everything goes right.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() members = [MockMember() for _ in range(5)] await self.cog.createteam(self.cog, self.ctx, "foo", members) - self.cog.create_channel.assert_awaited_once() + self.cog.create_channels.assert_awaited_once() self.cog.add_roles.assert_awaited_once() self.ctx.send.assert_awaited_once() -- cgit v1.2.3 From ef67747e59892d1307246bcad4d32e245098ff58 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 13 Jun 2020 08:44:56 +0300 Subject: Jam Tests: Fix `test_duplicate_member_provided` assertions --- tests/bot/cogs/test_jams.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index a66658134..2f2cb4695 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -39,10 +39,12 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_duplicate_members_provided(self): """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() member = MockMember() await self.cog.createteam(*self.default_args, (member for _ in range(5))) self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_now_awaited() + self.cog.create_channels.assert_not_awaited() self.cog.add_roles.assert_not_awaited() async def test_category_dont_exist(self): -- cgit v1.2.3 From e9724dad79e7dab3bb801f50770bb06cf8461019 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 15:12:38 +0200 Subject: Incidents tests: use our own helper mocks No reason to build own MagicMocks as we already have helpers that more accurately mimic the mocked behaviour. --- tests/bot/cogs/moderation/test_incidents.py | 30 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 4c1f9bc07..d7cc84734 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -1,10 +1,9 @@ import enum import unittest -from unittest.mock import AsyncMock, MagicMock, call, patch - -import discord +from unittest.mock import MagicMock, call, patch from bot.cogs.moderation import incidents +from tests.helpers import MockMessage, MockReaction, MockTextChannel, MockUser @patch("bot.constants.Channels.incidents", 123) @@ -20,11 +19,10 @@ class TestIsIncident(unittest.TestCase): def setUp(self) -> None: """Prepare a mock message which should qualify as an incident.""" - self.incident = MagicMock( - discord.Message, - channel=MagicMock(discord.TextChannel, id=123), + self.incident = MockMessage( + channel=MockTextChannel(id=123), content="this is an incident", - author=MagicMock(discord.User, bot=False), + author=MockUser(bot=False), pinned=False, ) @@ -38,7 +36,7 @@ class TestIsIncident(unittest.TestCase): def test_is_incident_false_channel(self): """Message doesn't qualify if sent outside of #incidents.""" - self.incident.channel = MagicMock(discord.TextChannel, id=456) + self.incident.channel = MockTextChannel(id=456) self.check_false() def test_is_incident_false_content(self): @@ -48,7 +46,7 @@ class TestIsIncident(unittest.TestCase): def test_is_incident_false_author(self): """Message doesn't qualify if author is a bot.""" - self.incident.author = MagicMock(discord.User, bot=True) + self.incident.author = MockUser(bot=True) self.check_false() def test_is_incident_false_pinned(self): @@ -63,11 +61,11 @@ class TestOwnReactions(unittest.TestCase): def test_own_reactions(self): """Only bot's own emoji are extracted from the input incident.""" reactions = ( - MagicMock(discord.Reaction, emoji="A", me=True), - MagicMock(discord.Reaction, emoji="B", me=True), - MagicMock(discord.Reaction, emoji="C", me=False), + MockReaction(emoji="A", me=True), + MockReaction(emoji="B", me=True), + MockReaction(emoji="C", me=False), ) - message = MagicMock(discord.Message, reactions=reactions) + message = MockMessage(reactions=reactions) self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) @@ -82,7 +80,7 @@ class TestHasSignals(unittest.TestCase): def test_has_signals_true(self): """True when `own_reactions` returns all emoji in `ALLOWED_EMOJI`.""" - message = MagicMock(discord.Message) + message = MockMessage() own_reactions = MagicMock(return_value={"A", "B"}) with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): @@ -90,7 +88,7 @@ class TestHasSignals(unittest.TestCase): def test_has_signals_false(self): """False when `own_reactions` does not return all emoji in `ALLOWED_EMOJI`.""" - message = MagicMock(discord.Message) + message = MockMessage() own_reactions = MagicMock(return_value={"A", "C"}) with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): @@ -114,7 +112,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): def setUp(self): """Prepare a mock incident message for tests to use.""" - self.incident = MagicMock(discord.Message, add_reaction=AsyncMock()) + self.incident = MockMessage() @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value=set())) async def test_add_signals_missing(self): -- cgit v1.2.3 From 00a44226cb659319b9df5f568b0f67f9a0ed3360 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 15:51:34 +0200 Subject: Incidents tests: improve mock `Signal` name & move def Let's make it clear that this is our own mock. We also move the definition to the top of the module. --- tests/bot/cogs/moderation/test_incidents.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index d7cc84734..a349c1cb7 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -6,6 +6,11 @@ from bot.cogs.moderation import incidents from tests.helpers import MockMessage, MockReaction, MockTextChannel, MockUser +class MockSignal(enum.Enum): + A = "A" + B = "B" + + @patch("bot.constants.Channels.incidents", 123) class TestIsIncident(unittest.TestCase): """ @@ -95,12 +100,7 @@ class TestHasSignals(unittest.TestCase): self.assertFalse(incidents.has_signals(message)) -class Signal(enum.Enum): - A = "A" - B = "B" - - -@patch("bot.cogs.moderation.incidents.Signal", Signal) +@patch("bot.cogs.moderation.incidents.Signal", MockSignal) class TestAddSignals(unittest.IsolatedAsyncioTestCase): """ Assertions for the `add_signals` coroutine. -- cgit v1.2.3 From c66b4a618503352803f73e9272a1d27b6e0a4d52 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 17:24:31 +0200 Subject: Incidents tests: set up base class for `Incidents` For cleanliness, I've decided to make a separate class for each method. Since most tests will want to have an `Incident` instance ready, they can inherit the `setUp` from `TestIncidents`, which does not make any assertions on its own. --- tests/bot/cogs/moderation/test_incidents.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index a349c1cb7..d52932e0a 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -2,8 +2,8 @@ import enum import unittest from unittest.mock import MagicMock, call, patch -from bot.cogs.moderation import incidents -from tests.helpers import MockMessage, MockReaction, MockTextChannel, MockUser +from bot.cogs.moderation import Incidents, incidents +from tests.helpers import MockBot, MockMessage, MockReaction, MockTextChannel, MockUser class MockSignal(enum.Enum): @@ -131,3 +131,22 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): """No emoji are added when all are present.""" await incidents.add_signals(self.incident) self.incident.add_reaction.assert_not_called() + + +class TestIncidents(unittest.IsolatedAsyncioTestCase): + """ + Tests for bound methods of the `Incidents` cog. + + Use this as a base class for `Incidents` tests - it will prepare a fresh instance + for each test function, but not make any assertions on its own. Tests can mutate + the instance as they wish. + """ + + def setUp(self): + """ + Prepare a fresh `Incidents` instance for each test. + + Note that this will not schedule `crawl_incidents` in the background, as everything + is being mocked. The `crawl_task` attribute will end up being None. + """ + self.cog_instance = Incidents(MockBot()) -- cgit v1.2.3 From 3c2d227cd067466668e3089f63a6548736edf8ab Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 17:56:31 +0200 Subject: Incidents tests: write tests for `archive` --- tests/bot/cogs/moderation/test_incidents.py | 65 ++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index d52932e0a..7500235cf 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -1,9 +1,12 @@ import enum import unittest -from unittest.mock import MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, patch + +import aiohttp +import discord from bot.cogs.moderation import Incidents, incidents -from tests.helpers import MockBot, MockMessage, MockReaction, MockTextChannel, MockUser +from tests.helpers import MockAsyncWebhook, MockBot, MockMessage, MockReaction, MockTextChannel, MockUser class MockSignal(enum.Enum): @@ -150,3 +153,61 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): is being mocked. The `crawl_task` attribute will end up being None. """ self.cog_instance = Incidents(MockBot()) + + +class TestArchive(TestIncidents): + """Tests for the `Incidents.archive` coroutine.""" + + async def test_archive_webhook_not_found(self): + """ + Method recovers and returns False when the webhook is not found. + + Implicitly, this also tests that the error is handled internally and doesn't + propagate out of the method, which is just as important. + """ + mock_404 = discord.NotFound( + response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response + message="Webhook not found", + ) + + self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) + self.assertFalse(await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock())) + + async def test_archive_relays_incident(self): + """ + If webhook is found, method relays `incident` properly. + + This test will assert the following: + * The fetched webhook's `send` method is fed the correct arguments + * The message returned by `send` will have `outcome` reaction added + * Finally, the `archive` method returns True + + Assertions are made specifically in this order. + """ + webhook_message = MockMessage() # The message that will be returned by the webhook's `send` method + webhook = MockAsyncWebhook(send=AsyncMock(return_value=webhook_message)) + + self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook + + # Now we'll pas our own `incident` to `archive` and capture the return value + incident = MockMessage( + clean_content="pingless message", + content="pingful message", + author=MockUser(name="author_name", avatar_url="author_avatar"), + id=123, + ) + archive_return = await self.cog_instance.archive(incident, outcome=MagicMock(value="A")) + + # Check that the webhook was dispatched correctly + webhook.send.assert_called_once_with( + content="pingless message", + username="author_name", + avatar_url="author_avatar", + wait=True, + ) + + # Now check that the correct emoji was added to the relayed message + webhook_message.add_reaction.assert_called_once_with("A") + + # Finally check that the method returned True + self.assertTrue(archive_return) -- cgit v1.2.3 From 39dc3cd229888acac2782237db4b9389c0788478 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 21:52:20 +0200 Subject: Incidents tests: move `mock_404` into module namespace This will be useful for others tests as well. --- tests/bot/cogs/moderation/test_incidents.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 7500235cf..e51bda114 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -14,6 +14,12 @@ class MockSignal(enum.Enum): B = "B" +mock_404 = discord.NotFound( + response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response + message="Not found", +) + + @patch("bot.constants.Channels.incidents", 123) class TestIsIncident(unittest.TestCase): """ @@ -165,11 +171,6 @@ class TestArchive(TestIncidents): Implicitly, this also tests that the error is handled internally and doesn't propagate out of the method, which is just as important. """ - mock_404 = discord.NotFound( - response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response - message="Webhook not found", - ) - self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) self.assertFalse(await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock())) -- cgit v1.2.3 From 8ed5cc7ef5e38885a8e439602b59e56449d3633c Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 21:52:34 +0200 Subject: Incidents tests: write tests for `resolve_message` --- tests/bot/cogs/moderation/test_incidents.py | 56 +++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index e51bda114..b3beec3ab 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -212,3 +212,59 @@ class TestArchive(TestIncidents): # Finally check that the method returned True self.assertTrue(archive_return) + + +class TestResolveMessage(TestIncidents): + """Tests for the `Incidents.resolve_message` coroutine.""" + + async def test_resolve_message_pass_message_id(self): + """Method will call `_get_message` with the passed `message_id`.""" + await self.cog_instance.resolve_message(123) + self.cog_instance.bot._connection._get_message.assert_called_once_with(123) + + async def test_resolve_message_in_cache(self): + """ + No API call is made if the queried message exists in the cache. + + We mock the `_get_message` return value regardless of input. Whether it finds the message + internally is considered d.py's responsibility, not ours. + """ + cached_message = MockMessage(id=123) + self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message) + + return_value = await self.cog_instance.resolve_message(123) + + self.assertIs(return_value, cached_message) + self.cog_instance.bot.get_channel.assert_not_called() # The `fetch_message` line was never hit + + async def test_resolve_message_not_in_cache(self): + """ + The message is retrieved from the API if it isn't cached. + + This is desired behaviour for messages which exist, but were sent before the bot's + current session. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + # API returns our message + uncached_message = MockMessage() + fetch_message = AsyncMock(return_value=uncached_message) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + retrieved_message = await self.cog_instance.resolve_message(123) + self.assertIs(retrieved_message, uncached_message) + + async def test_resolve_message_doesnt_exist(self): + """ + If the API returns a 404, the function handles it gracefully and returns None. + + This is an edge-case happening with racing events - event A will relay the message + to the archive and delete the original. Once event B acquires the `event_lock`, + it will not find the message in the cache, and will ask the API. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + fetch_message = AsyncMock(side_effect=mock_404) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + self.assertIsNone(await self.cog_instance.resolve_message(123)) -- cgit v1.2.3 From bbedcb377c4c31973f43f076c3f62646f25733b3 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 22:15:38 +0200 Subject: Incidents tests: test non-404 error response --- tests/bot/cogs/moderation/test_incidents.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index b3beec3ab..cbeb3342c 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -1,4 +1,5 @@ import enum +import logging import unittest from unittest.mock import AsyncMock, MagicMock, call, patch @@ -268,3 +269,22 @@ class TestResolveMessage(TestIncidents): self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) self.assertIsNone(await self.cog_instance.resolve_message(123)) + + async def test_resolve_message_fetch_fails(self): + """ + Non-404 errors are handled, logged & None is returned. + + In contrast with a 404, this should make an error-level log. We assert that at least + one such log was made - we do not make any assertions about the log's message. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + arbitrary_error = discord.HTTPException( + response=MagicMock(aiohttp.ClientResponse), + message="Arbitrary error", + ) + fetch_message = AsyncMock(side_effect=arbitrary_error) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + with self.assertLogs(logger=incidents.log, level=logging.ERROR): + self.assertIsNone(await self.cog_instance.resolve_message(123)) -- cgit v1.2.3 From 14b7fee42ddf6a2cc75526506ef2028bdc742c9a Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 22:42:58 +0200 Subject: Incidents tests: write tests for `on_message` --- tests/bot/cogs/moderation/test_incidents.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index cbeb3342c..0eb13df70 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -288,3 +288,30 @@ class TestResolveMessage(TestIncidents): with self.assertLogs(logger=incidents.log, level=logging.ERROR): self.assertIsNone(await self.cog_instance.resolve_message(123)) + + +class TestOnMessage(TestIncidents): + """ + Tests for the `Incidents.on_message` listener. + + Notice the decorators mocking the `is_incident` return value. The `is_incidents` + function is tested in `TestIsIncident` - here we do not worry about it. + """ + + @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) + async def test_on_message_incident(self): + """Messages qualifying as incidents are passed to `add_signals`.""" + incident = MockMessage() + + with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + await self.cog_instance.on_message(incident) + + mock_add_signals.assert_called_once_with(incident) + + @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) + async def test_on_message_non_incident(self): + """Messages not qualifying as incidents are ignored.""" + with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + await self.cog_instance.on_message(MockMessage()) + + mock_add_signals.assert_not_called() -- cgit v1.2.3 From 9d35846a67c2bf9ed9e935f8b5e3500ae4b49327 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 13 Jun 2020 23:24:14 +0200 Subject: Incidents tests: write tests for `make_confirmation_task` --- tests/bot/cogs/moderation/test_incidents.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 0eb13df70..c093afc8a 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -215,6 +215,41 @@ class TestArchive(TestIncidents): self.assertTrue(archive_return) +class TestMakeConfirmationTask(TestIncidents): + """ + Tests for the `Incidents.make_confirmation_task` method. + + Writing tests for this method is difficult, as it mostly just delegates the provided + information elsewhere. There is very little internal logic. Whether our approach + works conceptually is difficult to prove using unit tests. + """ + + def test_make_confirmation_task_check(self): + """ + The internal check will recognize the passed incident. + + This is a little tricky - we first pass a message with a specific `id` in, and then + retrieve the built check from the `call_args` of the `wait_for` method. This relies + on the check being passed as a kwarg. + + Once the check is retrieved, we assert that it gives True for our incident's `id`, + and False for any other. + + If this function begins to fail, first check that `created_check` is being retrieved + correctly. It should be the function that is built locally in the tested method. + """ + self.cog_instance.make_confirmation_task(MockMessage(id=123)) + + self.cog_instance.bot.wait_for.assert_called_once() + created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"] + + # The `message_id` matches the `id` of our incident + self.assertTrue(created_check(payload=MagicMock(message_id=123))) + + # This `message_id` does not match + self.assertFalse(created_check(payload=MagicMock(message_id=0))) + + class TestResolveMessage(TestIncidents): """Tests for the `Incidents.resolve_message` coroutine.""" -- cgit v1.2.3 From 81e50cb2c970fc5c203e135434f897b6a3f7e52a Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 14 Jun 2020 23:34:28 -0700 Subject: Sync tests: test listeners ignore events from other guilds --- tests/bot/cogs/sync/test_cog.py | 64 ++++++++++++++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 14fd909c4..d7d60e961 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -131,6 +131,12 @@ class SyncCogListenerTests(SyncCogTestCase): super().setUp() self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + self.guild_id_patcher = mock.patch("bot.cogs.sync.cog.constants.Guild.id", 5) + self.guild_id = self.guild_id_patcher.start() + + def tearDown(self): + self.guild_id_patcher.stop() + async def test_sync_cog_on_guild_role_create(self): """A POST request should be sent with the new role's data.""" self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -142,20 +148,32 @@ class SyncCogListenerTests(SyncCogTestCase): "permissions": 8, "position": 23, } - role = helpers.MockRole(**role_data) + role = helpers.MockRole(**role_data, guild=self.guild_id) await self.cog.on_guild_role_create(role) self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + async def test_sync_cog_on_guild_role_create_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=0) + await self.cog.on_guild_role_create(role) + self.bot.api_client.post.assert_not_awaited() + async def test_sync_cog_on_guild_role_delete(self): """A DELETE request should be sent.""" self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - role = helpers.MockRole(id=99) + role = helpers.MockRole(id=99, guild=self.guild_id) await self.cog.on_guild_role_delete(role) self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=0) + await self.cog.on_guild_role_delete(role) + self.bot.api_client.delete.assert_not_awaited() + async def test_sync_cog_on_guild_role_update(self): """A PUT request should be sent if the colour, name, permissions, or position changes.""" self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -180,8 +198,8 @@ class SyncCogListenerTests(SyncCogTestCase): after_role_data = role_data.copy() after_role_data[attribute] = 876 - before_role = helpers.MockRole(**role_data) - after_role = helpers.MockRole(**after_role_data) + before_role = helpers.MockRole(**role_data, guild=self.guild_id) + after_role = helpers.MockRole(**after_role_data, guild=self.guild_id) await self.cog.on_guild_role_update(before_role, after_role) @@ -193,11 +211,17 @@ class SyncCogListenerTests(SyncCogTestCase): else: self.bot.api_client.put.assert_not_called() + async def test_sync_cog_on_guild_role_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=0) + await self.cog.on_guild_role_update(role, role) + self.bot.api_client.put.assert_not_awaited() + async def test_sync_cog_on_member_remove(self): - """Member should patched to set in_guild as False.""" + """Member should be patched to set in_guild as False.""" self.assertTrue(self.cog.on_member_remove.__cog_listener__) - member = helpers.MockMember() + member = helpers.MockMember(guild=self.guild_id) await self.cog.on_member_remove(member) self.cog.patch_user.assert_called_once_with( @@ -205,14 +229,20 @@ class SyncCogListenerTests(SyncCogTestCase): updated_information={"in_guild": False} ) + async def test_sync_cog_on_member_remove_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=0) + await self.cog.on_member_remove(member) + self.cog.patch_user.assert_not_awaited() + async def test_sync_cog_on_member_update_roles(self): """Members should be patched if their roles have changed.""" self.assertTrue(self.cog.on_member_update.__cog_listener__) # Roles are intentionally unsorted. before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles) - after_member = helpers.MockMember(roles=before_roles[1:]) + before_member = helpers.MockMember(roles=before_roles, guild=self.guild_id) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild_id) await self.cog.on_member_update(before_member, after_member) @@ -233,13 +263,19 @@ class SyncCogListenerTests(SyncCogTestCase): with self.subTest(attribute=attribute): self.cog.patch_user.reset_mock() - before_member = helpers.MockMember(**{attribute: old_value}) - after_member = helpers.MockMember(**{attribute: new_value}) + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild_id) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild_id) await self.cog.on_member_update(before_member, after_member) self.cog.patch_user.assert_not_called() + async def test_sync_cog_on_member_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=0) + await self.cog.on_member_update(member, member) + self.cog.patch_user.assert_not_awaited() + async def test_sync_cog_on_user_update(self): """A user should be patched only if the name, discriminator, or avatar changes.""" self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -290,6 +326,7 @@ class SyncCogListenerTests(SyncCogTestCase): member = helpers.MockMember( discriminator="1234", roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + guild=self.guild_id, ) data = { @@ -334,6 +371,13 @@ class SyncCogListenerTests(SyncCogTestCase): self.bot.api_client.post.assert_not_called() + async def test_sync_cog_on_member_join_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=0) + await self.cog.on_member_join(member) + self.bot.api_client.post.assert_not_awaited() + self.bot.api_client.put.assert_not_awaited() + class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): """Tests for the commands in the Sync cog.""" -- cgit v1.2.3 From 4d6acdf32a323de8b88fed464358d70faf35c9d1 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 14 Jun 2020 23:47:40 -0700 Subject: Sync: ignore 404s in on_user_update 404s probably mean the user is from another guild. --- bot/cogs/sync/cog.py | 14 ++++++++------ tests/bot/cogs/sync/test_cog.py | 17 ++++++++++------- 2 files changed, 18 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 97ea31ba5..578cccfc9 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -34,14 +34,15 @@ class Sync(Cog): for syncer in (self.role_syncer, self.user_syncer): await syncer.sync(guild) - async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None: + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: """Send a PATCH request to partially update a user in the database.""" try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information) + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) except ResponseCodeError as e: if e.response.status != 404: raise - log.warning("Unable to update user, got 404. Assuming race condition from join event.") + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") @Cog.listener() async def on_guild_role_create(self, role: Role) -> None: @@ -137,7 +138,7 @@ class Sync(Cog): if member.guild != constants.Guild.id: return - await self.patch_user(member.id, updated_information={"in_guild": False}) + await self.patch_user(member.id, json={"in_guild": False}) @Cog.listener() async def on_member_update(self, before: Member, after: Member) -> None: @@ -147,7 +148,7 @@ class Sync(Cog): if before.roles != after.roles: updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, updated_information=updated_information) + await self.patch_user(after.id, json=updated_information) @Cog.listener() async def on_user_update(self, before: User, after: User) -> None: @@ -158,7 +159,8 @@ class Sync(Cog): "name": after.name, "discriminator": int(after.discriminator), } - await self.patch_user(after.id, updated_information=updated_information) + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) @commands.group(name='sync') @commands.has_permissions(administrator=True) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index d7d60e961..e5be14391 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -226,7 +226,7 @@ class SyncCogListenerTests(SyncCogTestCase): self.cog.patch_user.assert_called_once_with( member.id, - updated_information={"in_guild": False} + json={"in_guild": False} ) async def test_sync_cog_on_member_remove_ignores_guilds(self): @@ -247,7 +247,7 @@ class SyncCogListenerTests(SyncCogTestCase): await self.cog.on_member_update(before_member, after_member) data = {"roles": sorted(role.id for role in after_member.roles)} - self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + self.cog.patch_user.assert_called_once_with(after_member.id, json=data) async def test_sync_cog_on_member_update_other(self): """Members should not be patched if other attributes have changed.""" @@ -308,12 +308,15 @@ class SyncCogListenerTests(SyncCogTestCase): # Don't care if *all* keys are present; only the changed one is required call_args = self.cog.patch_user.call_args - self.assertEqual(call_args[0][0], after_user.id) - self.assertIn("updated_information", call_args[1]) + self.assertEqual(call_args.args[0], after_user.id) + self.assertIn("json", call_args.kwargs) - updated_information = call_args[1]["updated_information"] - self.assertIn(api_field, updated_information) - self.assertEqual(updated_information[api_field], api_value) + self.assertIn("ignore_404", call_args.kwargs) + self.assertTrue(call_args.kwargs["ignore_404"]) + + json = call_args.kwargs["json"] + self.assertIn(api_field, json) + self.assertEqual(json[api_field], api_value) else: self.cog.patch_user.assert_not_called() -- cgit v1.2.3 From c7373fa1143a2d2f2d784a59d40bcb40ee765bfb Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 15 Jun 2020 10:26:23 -0700 Subject: Token remover: ignore DMs It's a private channel so there's no risk of a token "leaking". Furthermore, messages cannot be deleted in DMs. --- bot/cogs/token_remover.py | 3 +++ tests/bot/cogs/test_token_remover.py | 10 ++++++++++ 2 files changed, 13 insertions(+) (limited to 'tests') diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index d55e079e9..493479df9 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -63,6 +63,9 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ + if not msg.guild: + return # Ignore DMs; can't delete messages in there anyway. + found_token = self.find_token_in_message(msg) if found_token: await self.take_action(msg, found_token) diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a10124d2d..22c31d7b1 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -121,6 +121,16 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): find_token_in_message.assert_called_once_with(self.msg) take_action.assert_not_awaited() + @autospec(TokenRemover, "find_token_in_message") + async def test_on_message_ignores_dms(self, find_token_in_message): + """Shouldn't parse a message if it is a DM.""" + cog = TokenRemover(self.bot) + self.msg.guild = None + + await cog.on_message(self.msg) + + find_token_in_message.assert_not_called() + @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_ignores_bot_messages(self, token_re): """The token finder should ignore messages authored by bots.""" -- cgit v1.2.3 From 2fa7429327e787a65803c16609da21463723bfeb Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 15 Jun 2020 10:38:46 -0700 Subject: Token remover: move bot check to on_message It just makes more sense to me to filter out messages at an earlier stage. --- bot/cogs/token_remover.py | 8 +++----- tests/bot/cogs/test_token_remover.py | 23 +++++++---------------- 2 files changed, 10 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 493479df9..1f7517501 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -63,8 +63,9 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ - if not msg.guild: - return # Ignore DMs; can't delete messages in there anyway. + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return found_token = self.find_token_in_message(msg) if found_token: @@ -115,9 +116,6 @@ class TokenRemover(Cog): @classmethod def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: """Return a seemingly valid token found in `msg` or `None` if no token is found.""" - if msg.author.bot: - return - # Use finditer rather than search to guard against method calls prematurely returning the # token check (e.g. `message.channel.send` also matches our token pattern) for match in TOKEN_RE.finditer(msg.content): diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 22c31d7b1..98ea9f823 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -122,24 +122,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): take_action.assert_not_awaited() @autospec(TokenRemover, "find_token_in_message") - async def test_on_message_ignores_dms(self, find_token_in_message): - """Shouldn't parse a message if it is a DM.""" + async def test_on_message_ignores_dms_bots(self, find_token_in_message): + """Shouldn't parse a message if it is a DM or authored by a bot.""" cog = TokenRemover(self.bot) - self.msg.guild = None + dm_msg = MockMessage(guild=None) + bot_msg = MockMessage(author=MagicMock(bot=True)) - await cog.on_message(self.msg) - - find_token_in_message.assert_not_called() - - @autospec("bot.cogs.token_remover", "TOKEN_RE") - def test_find_token_ignores_bot_messages(self, token_re): - """The token finder should ignore messages authored by bots.""" - self.msg.author.bot = True - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_not_called() + for msg in (dm_msg, bot_msg): + await cog.on_message(msg) + find_token_in_message.assert_not_called() @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_no_matches(self, token_re): -- cgit v1.2.3 From 3aecf14419c87e533d47fe082abeb54ca9edb73c Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 15 Jun 2020 10:49:18 -0700 Subject: Token remover: exit early if message already deleted --- bot/cogs/token_remover.py | 10 ++++++++-- tests/bot/cogs/test_token_remover.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 1f7517501..ef979f222 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -4,7 +4,7 @@ import logging import re import typing as t -from discord import Colour, Message +from discord import Colour, Message, NotFound from discord.ext.commands import Cog from bot import utils @@ -83,7 +83,13 @@ class TokenRemover(Cog): async def take_action(self, msg: Message, found_token: Token) -> None: """Remove the `msg` containing the `found_token` and send a mod log message.""" self.mod_log.ignore(Event.message_delete, msg.id) - await msg.delete() + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") + return + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) log_message = self.format_log_message(msg, found_token) diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 98ea9f823..3349caa73 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -3,7 +3,7 @@ from re import Match from unittest import mock from unittest.mock import MagicMock -from discord import Colour +from discord import Colour, NotFound from bot import constants from bot.cogs import token_remover @@ -282,6 +282,19 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): channel_id=constants.Channels.mod_alerts ) + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + async def test_take_action_delete_failure(self, mod_log_property): + """Shouldn't send any messages if the token message can't be deleted.""" + cog = TokenRemover(self.bot) + mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) + self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) + + token = mock.create_autospec(Token, spec_set=True, instance=True) + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_not_awaited() + class TokenRemoverExtensionTests(unittest.TestCase): """Tests for the token_remover extension.""" -- cgit v1.2.3 From ebc0eae42c1da67f61f040a67bc1b70e53a6f97e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 17 Jun 2020 16:46:43 -0700 Subject: Sync: fix guild ID check Need to compare the IDs against each other rather than the Guild object against the ID. --- bot/cogs/sync/cog.py | 12 ++++++------ tests/bot/cogs/sync/test_cog.py | 35 +++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 578cccfc9..5ace957e7 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -47,7 +47,7 @@ class Sync(Cog): @Cog.listener() async def on_guild_role_create(self, role: Role) -> None: """Adds newly create role to the database table over the API.""" - if role.guild != constants.Guild.id: + if role.guild.id != constants.Guild.id: return await self.bot.api_client.post( @@ -64,7 +64,7 @@ class Sync(Cog): @Cog.listener() async def on_guild_role_delete(self, role: Role) -> None: """Deletes role from the database when it's deleted from the guild.""" - if role.guild != constants.Guild.id: + if role.guild.id != constants.Guild.id: return await self.bot.api_client.delete(f'bot/roles/{role.id}') @@ -72,7 +72,7 @@ class Sync(Cog): @Cog.listener() async def on_guild_role_update(self, before: Role, after: Role) -> None: """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild != constants.Guild.id: + if after.guild.id != constants.Guild.id: return was_updated = ( @@ -103,7 +103,7 @@ class Sync(Cog): previously left), it will update the user's information. If the user is not yet known by the database, the user is added. """ - if member.guild != constants.Guild.id: + if member.guild.id != constants.Guild.id: return packed = { @@ -135,7 +135,7 @@ class Sync(Cog): @Cog.listener() async def on_member_remove(self, member: Member) -> None: """Set the in_guild field to False when a member leaves the guild.""" - if member.guild != constants.Guild.id: + if member.guild.id != constants.Guild.id: return await self.patch_user(member.id, json={"in_guild": False}) @@ -143,7 +143,7 @@ class Sync(Cog): @Cog.listener() async def on_member_update(self, before: Member, after: Member) -> None: """Update the roles of the member in the database if a change is detected.""" - if after.guild != constants.Guild.id: + if after.guild.id != constants.Guild.id: return if before.roles != after.roles: diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index e5be14391..120bc991d 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -134,6 +134,9 @@ class SyncCogListenerTests(SyncCogTestCase): self.guild_id_patcher = mock.patch("bot.cogs.sync.cog.constants.Guild.id", 5) self.guild_id = self.guild_id_patcher.start() + self.guild = helpers.MockGuild(id=self.guild_id) + self.other_guild = helpers.MockGuild(id=0) + def tearDown(self): self.guild_id_patcher.stop() @@ -148,14 +151,14 @@ class SyncCogListenerTests(SyncCogTestCase): "permissions": 8, "position": 23, } - role = helpers.MockRole(**role_data, guild=self.guild_id) + role = helpers.MockRole(**role_data, guild=self.guild) await self.cog.on_guild_role_create(role) self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) async def test_sync_cog_on_guild_role_create_ignores_guilds(self): """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=0) + role = helpers.MockRole(guild=self.other_guild) await self.cog.on_guild_role_create(role) self.bot.api_client.post.assert_not_awaited() @@ -163,14 +166,14 @@ class SyncCogListenerTests(SyncCogTestCase): """A DELETE request should be sent.""" self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - role = helpers.MockRole(id=99, guild=self.guild_id) + role = helpers.MockRole(id=99, guild=self.guild) await self.cog.on_guild_role_delete(role) self.bot.api_client.delete.assert_called_once_with("bot/roles/99") async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=0) + role = helpers.MockRole(guild=self.other_guild) await self.cog.on_guild_role_delete(role) self.bot.api_client.delete.assert_not_awaited() @@ -198,8 +201,8 @@ class SyncCogListenerTests(SyncCogTestCase): after_role_data = role_data.copy() after_role_data[attribute] = 876 - before_role = helpers.MockRole(**role_data, guild=self.guild_id) - after_role = helpers.MockRole(**after_role_data, guild=self.guild_id) + before_role = helpers.MockRole(**role_data, guild=self.guild) + after_role = helpers.MockRole(**after_role_data, guild=self.guild) await self.cog.on_guild_role_update(before_role, after_role) @@ -213,7 +216,7 @@ class SyncCogListenerTests(SyncCogTestCase): async def test_sync_cog_on_guild_role_update_ignores_guilds(self): """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=0) + role = helpers.MockRole(guild=self.other_guild) await self.cog.on_guild_role_update(role, role) self.bot.api_client.put.assert_not_awaited() @@ -221,7 +224,7 @@ class SyncCogListenerTests(SyncCogTestCase): """Member should be patched to set in_guild as False.""" self.assertTrue(self.cog.on_member_remove.__cog_listener__) - member = helpers.MockMember(guild=self.guild_id) + member = helpers.MockMember(guild=self.guild) await self.cog.on_member_remove(member) self.cog.patch_user.assert_called_once_with( @@ -231,7 +234,7 @@ class SyncCogListenerTests(SyncCogTestCase): async def test_sync_cog_on_member_remove_ignores_guilds(self): """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=0) + member = helpers.MockMember(guild=self.other_guild) await self.cog.on_member_remove(member) self.cog.patch_user.assert_not_awaited() @@ -241,8 +244,8 @@ class SyncCogListenerTests(SyncCogTestCase): # Roles are intentionally unsorted. before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles, guild=self.guild_id) - after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild_id) + before_member = helpers.MockMember(roles=before_roles, guild=self.guild) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) await self.cog.on_member_update(before_member, after_member) @@ -263,8 +266,8 @@ class SyncCogListenerTests(SyncCogTestCase): with self.subTest(attribute=attribute): self.cog.patch_user.reset_mock() - before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild_id) - after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild_id) + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) await self.cog.on_member_update(before_member, after_member) @@ -272,7 +275,7 @@ class SyncCogListenerTests(SyncCogTestCase): async def test_sync_cog_on_member_update_ignores_guilds(self): """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=0) + member = helpers.MockMember(guild=self.other_guild) await self.cog.on_member_update(member, member) self.cog.patch_user.assert_not_awaited() @@ -329,7 +332,7 @@ class SyncCogListenerTests(SyncCogTestCase): member = helpers.MockMember( discriminator="1234", roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], - guild=self.guild_id, + guild=self.guild, ) data = { @@ -376,7 +379,7 @@ class SyncCogListenerTests(SyncCogTestCase): async def test_sync_cog_on_member_join_ignores_guilds(self): """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=0) + member = helpers.MockMember(guild=self.other_guild) await self.cog.on_member_join(member) self.bot.api_client.post.assert_not_awaited() self.bot.api_client.put.assert_not_awaited() -- cgit v1.2.3 From 40e00ff17465fc5a5fe6b46487bfea37655cd7b9 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Thu, 18 Jun 2020 19:33:59 +0200 Subject: Incidents tests: write tests for `process_event` This also breaks the helpers import statement into a vertical list, as the amount of imports has grown too much. I still believe that this is a preferred alternative to accessing the helpers via module namespace, as we use them a lot, and the added visual noise would be annoying to read - their names are already descriptive enough. --- tests/bot/cogs/moderation/test_incidents.py | 102 +++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index c093afc8a..6158d5d20 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -1,3 +1,4 @@ +import asyncio import enum import logging import unittest @@ -7,7 +8,16 @@ import aiohttp import discord from bot.cogs.moderation import Incidents, incidents -from tests.helpers import MockAsyncWebhook, MockBot, MockMessage, MockReaction, MockTextChannel, MockUser +from tests.helpers import ( + MockAsyncWebhook, + MockBot, + MockMember, + MockMessage, + MockReaction, + MockRole, + MockTextChannel, + MockUser, +) class MockSignal(enum.Enum): @@ -250,6 +260,96 @@ class TestMakeConfirmationTask(TestIncidents): self.assertFalse(created_check(payload=MagicMock(message_id=0))) +@patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2}) +@patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable +class TestProcessEvent(TestIncidents): + """Tests for the `Incidents.process_event` coroutine.""" + + @patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2}) + async def test_process_event_bad_role(self): + """The reaction is removed when the author lacks all allowed roles.""" + incident = MockMessage() + member = MockMember(roles=[MockRole(id=0)]) # Must have role 1 or 2 + + await self.cog_instance.process_event("reaction", incident, member) + incident.remove_reaction.assert_called_once_with("reaction", member) + + async def test_process_event_bad_emoji(self): + """ + The reaction is removed when an invalid emoji is used. + + This requires that we pass in a `member` with valid roles, as we need the role check + to succeed. + """ + incident = MockMessage() + member = MockMember(roles=[MockRole(id=1)]) # Member has allowed role + + await self.cog_instance.process_event("invalid_signal", incident, member) + incident.remove_reaction.assert_called_once_with("invalid_signal", member) + + async def test_process_event_no_archive_on_investigating(self): + """Message is not archived on `Signal.INVESTIGATING`.""" + with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: + await self.cog_instance.process_event( + reaction=incidents.Signal.INVESTIGATING.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]), + ) + + mocked_archive.assert_not_called() + + async def test_process_event_no_delete_if_archive_fails(self): + """ + Original message is not deleted when `Incidents.archive` returns False. + + This is the way of signaling that the relay failed, and we should not remove the original, + as that would result in losing the incident record. + """ + incident = MockMessage() + + with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=incident, + member=MockMember(roles=[MockRole(id=1)]) + ) + + incident.delete.assert_not_called() + + async def test_process_event_confirmation_task_is_awaited(self): + """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" + mock_task = AsyncMock() + + with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]) + ) + + mock_task.assert_awaited() + + async def test_process_event_confirmation_task_timeout_is_handled(self): + """ + Confirmation task `asyncio.TimeoutError` is handled gracefully. + + We have `make_confirmation_task` return a mock with a side effect, and then catch the + exception should it propagate out of `process_event`. This is so that we can then manually + fail the test with a more informative message than just the plain traceback. + """ + mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) + + try: + with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]) + ) + except asyncio.TimeoutError: + self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!") + + class TestResolveMessage(TestIncidents): """Tests for the `Incidents.resolve_message` coroutine.""" -- cgit v1.2.3 From ed4097629601704f0c65fc40cceb5fd6757d4779 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Fri, 19 Jun 2020 14:32:31 +0200 Subject: Incidents tests: add helper for mocking async for-loops See the docstring. This does not make the ambition to be powerful enough to be included in `tests.helpers`, and is only intended for local purposes. --- tests/bot/cogs/moderation/test_incidents.py | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 6158d5d20..7fa8847ef 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -1,6 +1,7 @@ import asyncio import enum import logging +import typing as t import unittest from unittest.mock import AsyncMock, MagicMock, call, patch @@ -20,6 +21,42 @@ from tests.helpers import ( ) +class MockAsyncIterable: + """ + Helper for mocking asynchronous for loops. + + It does not appear that the `unittest` library currently provides anything that would + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. + + We therefore write our own helper to wrap a regular synchronous iterable, and feed + its values via `__anext__` rather than `__next__`. + + This class was written for the purposes of testing the `Incidents` cog - it may not + be generic enough to be placed in the `tests.helpers` module. + """ + + def __init__(self, messages: t.Iterable): + """Take a sync iterable to be wrapped.""" + self.iter_messages = iter(messages) + + def __aiter__(self): + """Return `self` as we provide the `__anext__` method.""" + return self + + async def __anext__(self): + """ + Feed the next item, or raise `StopAsyncIteration`. + + Since we're wrapping a sync iterator, it will communicate that it has been depleted + by raising a `StopIteration`. The `async for` construct does not expect it, and we + therefore need to substitute it for the appropriate exception type. + """ + try: + return next(self.iter_messages) + except StopIteration: + raise StopAsyncIteration + + class MockSignal(enum.Enum): A = "A" B = "B" -- cgit v1.2.3 From d93ed5d801c08b7fb084427906e7ac484ac3563f Mon Sep 17 00:00:00 2001 From: kwzrd Date: Fri, 19 Jun 2020 14:37:44 +0200 Subject: Incidents tests: write tests for `crawl_incidents` --- tests/bot/cogs/moderation/test_incidents.py | 58 +++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 7fa8847ef..4e6dfd5f7 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -209,6 +209,64 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): self.cog_instance = Incidents(MockBot()) +@patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test +class TestCrawlIncidents(TestIncidents): + """ + Tests for the `Incidents.crawl_incidents` coroutine. + + Apart from `test_crawl_incidents_waits_until_cache_ready`, all tests in this class + will patch the return values of `is_incident` and `has_signal` and then observe + whether the `AsyncMock` for `add_signals` was awaited or not. + + The `add_signals` mock is added by each test separately to ensure it is clean (has not + been awaited by another test yet). The mock can be reset, but this appears to be the + cleaner way. + + For each test, we inject a mock channel with a history of 1 message only (see: `setUp`). + """ + + def setUp(self): + """For each test, ensure `bot.get_channel` returns a channel with 1 arbitrary message.""" + super().setUp() # First ensure we get `cog_instance` from parent + + incidents_history = MagicMock(return_value=MockAsyncIterable([MockMessage()])) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(history=incidents_history)) + + async def test_crawl_incidents_waits_until_cache_ready(self): + """ + The coroutine will await the `wait_until_guild_available` event. + + Since this task is schedule in the `__init__`, it is critical that it waits for the + cache to be ready, so that it can safely get the #incidents channel. + """ + await self.cog_instance.crawl_incidents() + self.cog_instance.bot.wait_until_guild_available.assert_awaited() + + @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify + @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) + async def test_crawl_incidents_noop_if_is_not_incident(self): + """Signals are not added for a non-incident message.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_not_awaited() + + @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals + async def test_crawl_incidents_noop_if_message_already_has_signals(self): + """Signals are not added for messages which already have them.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_not_awaited() + + @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals + async def test_crawl_incidents_add_signals_called(self): + """Message has signals added as it does not have them yet and qualifies as an incident.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_awaited_once() + + class TestArchive(TestIncidents): """Tests for the `Incidents.archive` coroutine.""" -- cgit v1.2.3 From 9a58b45cad51c961ad34fa9de9aaa060446c54fd Mon Sep 17 00:00:00 2001 From: kwzrd Date: Fri, 19 Jun 2020 16:57:15 +0200 Subject: Incidents tests: write tests for `on_raw_reaction_add` --- tests/bot/cogs/moderation/test_incidents.py | 128 ++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 4e6dfd5f7..55b15ec9e 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -520,6 +520,134 @@ class TestResolveMessage(TestIncidents): self.assertIsNone(await self.cog_instance.resolve_message(123)) +@patch("bot.constants.Channels.incidents", 123) +class TestOnRawReactionAdd(TestIncidents): + """ + Tests for the `Incidents.on_raw_reaction_add` listener. + + Writing tests for this listener comes with additional complexity due to the listener + awaiting the `crawl_task` task. See `asyncSetUp` for further details, which attempts + to make unit testing this function possible. + """ + + def setUp(self): + """ + Prepare & assign `payload` attribute. + + This attribute represents an *ideal* payload which will not be rejected by the + listener. As each test will receive a fresh instance, it can be mutated to + observe how the listener's behaviour changes with different attributes on + the passed payload. + """ + super().setUp() # Ensure `cog_instance` is assigned + + self.payload = MagicMock( + discord.RawReactionActionEvent, + channel_id=123, # Patched at class level + message_id=456, + member=MockMember(bot=False), + emoji="reaction", + ) + + async def asyncSetUp(self): # noqa: N802 + """ + Prepare an empty task and assign it as `crawl_task`. + + It appears that the `unittest` framework does not provide anything for mocking + asyncio tasks. An `AsyncMock` instance can be called and then awaited, however, + it does not provide the `done` method or any other parts of the `asyncio.Task` + interface. + + Although we do not need to make any assertions about the task itself while + testing the listener, the code will still await it and call the `done` method, + and so we must inject something that will not fail on either action. + + Note that this is done in an `asyncSetUp`, which runs after `setUp`. + The justification is that creating an actual task requires the event + loop to be ready, which is not the case in the `setUp`. + """ + mock_task = asyncio.create_task(AsyncMock()()) # Mock async func, then a coro + self.cog_instance.crawl_task = mock_task + + async def test_on_raw_reaction_add_wrong_channel(self): + """ + Events outside of #incidents will be ignored. + + We check this by asserting that `resolve_message` was never queried. + """ + self.payload.channel_id = 0 + self.cog_instance.resolve_message = AsyncMock() + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.resolve_message.assert_not_called() + + async def test_on_raw_reaction_add_user_is_bot(self): + """ + Events dispatched by bot accounts will be ignored. + + We check this by asserting that `resolve_message` was never queried. + """ + self.payload.member = MockMember(bot=True) + self.cog_instance.resolve_message = AsyncMock() + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.resolve_message.assert_not_called() + + async def test_on_raw_reaction_add_message_doesnt_exist(self): + """ + Listener gracefully handles the case where `resolve_message` gives None. + + We check this by asserting that `process_event` was never called. + """ + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=None) + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.process_event.assert_not_called() + + async def test_on_raw_reaction_add_message_is_not_an_incident(self): + """ + The event won't be processed if the related message is not an incident. + + This is an edge-case that can happen if someone manually leaves a reaction + on a pinned message, or a comment. + + We check this by asserting that `process_event` was never called. + """ + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) + + with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)): + await self.cog_instance.on_raw_reaction_add(self.payload) + + self.cog_instance.process_event.assert_not_called() + + async def test_on_raw_reaction_add_valid_event_is_processed(self): + """ + If the reaction event is valid, it is passed to `process_event`. + + This is the case when everything goes right: + * The reaction was placed in #incidents, and not by a bot + * The message was found successfully + * The message qualifies as an incident + + Additionally, we check that all arguments were passed as expected. + """ + incident = MockMessage(id=1) + + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=incident) + + with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)): + await self.cog_instance.on_raw_reaction_add(self.payload) + + self.cog_instance.process_event.assert_called_with( + "reaction", # Defined in `self.payload` + incident, + self.payload.member, + ) + + class TestOnMessage(TestIncidents): """ Tests for the `Incidents.on_message` listener. -- cgit v1.2.3 From 581573f2ece96a9ec666795431ff21068e949a63 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 20 Jun 2020 01:20:35 +0200 Subject: Write unit test for `sub_clyde` --- tests/bot/utils/test_messages.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/bot/utils/test_messages.py (limited to 'tests') diff --git a/tests/bot/utils/test_messages.py b/tests/bot/utils/test_messages.py new file mode 100644 index 000000000..9c22c9751 --- /dev/null +++ b/tests/bot/utils/test_messages.py @@ -0,0 +1,27 @@ +import unittest + +from bot.utils import messages + + +class TestMessages(unittest.TestCase): + """Tests for functions in the `bot.utils.messages` module.""" + + def test_sub_clyde(self): + """Uppercase E's and lowercase e's are substituted with their cyrillic counterparts.""" + sub_e = "\u0435" + sub_E = "\u0415" # noqa: N806: Uppercase E in variable name + + test_cases = ( + (None, None), + ("", ""), + ("clyde", f"clyd{sub_e}"), + ("CLYDE", f"CLYD{sub_E}"), + ("cLyDe", f"cLyD{sub_e}"), + ("BIGclyde", f"BIGclyd{sub_e}"), + ("small clydeus the unholy", f"small clyd{sub_e}us the unholy"), + ("BIGCLYDE, babyclyde", f"BIGCLYD{sub_E}, babyclyd{sub_e}"), + ) + + for username_in, username_out in test_cases: + with self.subTest(input=username_in, expected_output=username_out): + self.assertEqual(messages.sub_clyde(username_in), username_out) -- cgit v1.2.3 From 98b8947ab7865e33f18da8e2a62b26405676e8e4 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 20 Jun 2020 13:13:45 +0200 Subject: Incidents: try-except Signal creation Suggested by Mark during review. This follows the "ask for forgiveness rather than permission" paradigm, ends up being less code to read, and may be seen as more logical / safer. The `ALLOWED_EMOJI` set was renamed to `ALL_SIGNALS` as this now better communicates the set's purpose. Tests adjusted as appropriate. Co-authored-by: MarkKoz --- bot/cogs/moderation/incidents.py | 18 +++++++++++------- tests/bot/cogs/moderation/test_incidents.py | 8 ++++---- 2 files changed, 15 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 089a5bc9f..41a98bcb7 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -33,9 +33,11 @@ class Signal(Enum): INVESTIGATING = Emojis.incident_investigating -# Reactions from roles not listed here, or using emoji not listed here, will be removed +# Reactions from roles not listed here will be removed ALLOWED_ROLES: t.Set[int] = {Roles.moderators, Roles.admins, Roles.owners} -ALLOWED_EMOJI: t.Set[str] = {signal.value for signal in Signal} + +# Message must have all of these emoji to pass the `has_signals` check +ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} def is_incident(message: discord.Message) -> bool: @@ -56,7 +58,7 @@ def own_reactions(message: discord.Message) -> t.Set[str]: def has_signals(message: discord.Message) -> bool: """True if `message` already has all `Signal` reactions, False otherwise.""" - return ALLOWED_EMOJI.issubset(own_reactions(message)) + return ALL_SIGNALS.issubset(own_reactions(message)) async def add_signals(incident: discord.Message) -> None: @@ -96,7 +98,9 @@ class Incidents(Cog): * See: `on_message` On reaction: - * Remove reaction if not permitted (`ALLOWED_EMOJI`, `ALLOWED_ROLES`) + * Remove reaction if not permitted + * User does not have any of the roles in `ALLOWED_ROLES` + * Used emoji is not a `Signal` member * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to relay the incident message to #incidents-archive * If relay successful, delete original message @@ -217,13 +221,13 @@ class Incidents(Cog): await incident.remove_reaction(reaction, member) return - if reaction not in ALLOWED_EMOJI: + try: + signal = Signal(reaction) + except ValueError: log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") await incident.remove_reaction(reaction, member) return - # If we reach this point, we know that `emoji` is a `Signal` member - signal = Signal(reaction) log.trace(f"Received signal: {signal}") if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED): diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 55b15ec9e..862736785 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -131,17 +131,17 @@ class TestOwnReactions(unittest.TestCase): self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) -@patch("bot.cogs.moderation.incidents.ALLOWED_EMOJI", {"A", "B"}) +@patch("bot.cogs.moderation.incidents.ALL_SIGNALS", {"A", "B"}) class TestHasSignals(unittest.TestCase): """ Assertions for the `has_signals` function. - We patch `ALLOWED_EMOJI` globally. Each test function then patches `own_reactions` + We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` as appropriate. """ def test_has_signals_true(self): - """True when `own_reactions` returns all emoji in `ALLOWED_EMOJI`.""" + """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" message = MockMessage() own_reactions = MagicMock(return_value={"A", "B"}) @@ -149,7 +149,7 @@ class TestHasSignals(unittest.TestCase): self.assertTrue(incidents.has_signals(message)) def test_has_signals_false(self): - """False when `own_reactions` does not return all emoji in `ALLOWED_EMOJI`.""" + """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" message = MockMessage() own_reactions = MagicMock(return_value={"A", "C"}) -- cgit v1.2.3 From 20dbd177f227511b9c3cb678ab45a67558cd3d7f Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 20 Jun 2020 13:15:43 +0200 Subject: Incidents tests: remove unnecessary patch This is already being patched at class-level. --- tests/bot/cogs/moderation/test_incidents.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 862736785..9f0553216 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -360,7 +360,6 @@ class TestMakeConfirmationTask(TestIncidents): class TestProcessEvent(TestIncidents): """Tests for the `Incidents.process_event` coroutine.""" - @patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2}) async def test_process_event_bad_role(self): """The reaction is removed when the author lacks all allowed roles.""" incident = MockMessage() -- cgit v1.2.3 From f240a970c6b97d201959d25a79a8babafed1c2b1 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sat, 20 Jun 2020 14:15:34 +0200 Subject: Incidents tests: assert webhook username is de-clyded See: a8b4e394d9da57287cd9497cd9bb0a97fa467e84 --- tests/bot/cogs/moderation/test_incidents.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 9f0553216..2fc9180cf 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -319,6 +319,25 @@ class TestArchive(TestIncidents): # Finally check that the method returned True self.assertTrue(archive_return) + async def test_archive_clyde_username(self): + """ + The archive webhook username is cleansed using `sub_clyde`. + + Discord will reject any webhook with "clyde" in the username field, as it impersonates + the official Clyde bot. Since we do not control what the username will be (the incident + author name is used), we must ensure the name is cleansed, otherwise the relay may fail. + + This test assumes the username is passed as a kwarg. If this test fails, please review + whether the passed argument is being retrieved correctly. + """ + webhook = MockAsyncWebhook() + self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) + + message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal)) + + self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) + class TestMakeConfirmationTask(TestIncidents): """ -- cgit v1.2.3 From 6fa8caed037b247a7c194f58a4635de7dae21fd2 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sun, 21 Jun 2020 13:51:17 +0200 Subject: Incidents: implement `make_username` helper The justification is to incorporate the `actioned_by` name into the username in some way, and so the logical thing to do is to abstract this process into a helper so that it can easily be adjusted in the future. For now, I've chosen to separate the names by a pipe. Discord webhook username cannot exceed 80 characters in length, and so we cap it at this length by default. This is seen as more of an edge-case, but it should be accounted for, as we're not joining two names. The `max_length` param is configurable primarily for testing purposes, it probably should never be passed explicitly. This commit also provides two tests for the function. --- bot/cogs/moderation/incidents.py | 24 ++++++++++++++++++++++++ tests/bot/cogs/moderation/test_incidents.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 040f2c0c8..2cce9b6fe 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -41,6 +41,30 @@ ALLOWED_ROLES: t.Set[int] = {Roles.moderators, Roles.admins, Roles.owners} ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} +def make_username(reported_by: discord.Member, actioned_by: discord.Member, max_length: int = 80) -> str: + """ + Create a webhook-friendly username from the names of `reported_by` and `actioned_by`. + + If the resulting username length exceeds `max_length`, it will be capped at `max_length - 3` + and have 3 dots appended to the end. The default value is 80, which corresponds to the limit + Discord imposes on webhook username length. + + If the value of `max_length` is < 3, ValueError is raised. + """ + if max_length < 3: + raise ValueError(f"Maximum length cannot be less than 3: {max_length=}") + + username = f"{reported_by.name} | {actioned_by.name}" + log.trace(f"Generated webhook username: {username} (length: {len(username)})") + + if len(username) > max_length: + stop = max_length - 3 + username = f"{username[:stop]}..." + log.trace(f"Username capped at {max_length=}: {username}") + + return username + + def is_incident(message: discord.Message) -> bool: """True if `message` qualifies as an incident, False otherwise.""" conditions = ( diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 2fc9180cf..5700a5a35 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -68,6 +68,35 @@ mock_404 = discord.NotFound( ) +class TestMakeUsername(unittest.TestCase): + """Collection of tests for the `make_username` helper function.""" + + def test_make_username_raises(self): + """Raises `ValueError` on `max_length` < 3.""" + with self.assertRaises(ValueError): + incidents.make_username(MockMember(), MockMember(), max_length=2) + + def test_make_username_never_exceed_limit(self): + """ + The return string length is always less than or equal to `max_length`. + + For this test we pass `max_length=10` for convenience. The name of the first + user (`reported_by`) is always 1 character in length, but we generate names + for the `actioned_by` user starting at length 1 and up to length 20. + + Finally, we assert that the output length never exceeded 10 in total. + """ + user_a = MockMember(name="A") + + max_length = 10 + test_cases = (MockMember(name="B" * n) for n in range(1, 20)) + + for user_b in test_cases: + with self.subTest(user_a=user_a, user_b=user_b, max_length=max_length): + generated_username = incidents.make_username(user_a, user_b, max_length) + self.assertLessEqual(len(generated_username), max_length) + + @patch("bot.constants.Channels.incidents", 123) class TestIsIncident(unittest.TestCase): """ -- cgit v1.2.3 From a8d179d9b04f54b20c5e870bcfa85c78c42c8dca Mon Sep 17 00:00:00 2001 From: kwzrd Date: Sun, 21 Jun 2020 14:21:18 +0200 Subject: Incidents: append `actioned_by` to webhook username Incident author and the moderator who actioned report are now passed through `make_username` to create the webhook username. Tests adjusted as appropriate. --- bot/cogs/moderation/incidents.py | 9 +++++---- tests/bot/cogs/moderation/test_incidents.py | 23 +++++++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 2cce9b6fe..72cc4b26c 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -172,13 +172,14 @@ class Incidents(Cog): log.debug("Crawl task finished!") - async def archive(self, incident: discord.Message, outcome: Signal) -> bool: + async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: """ Relay `incident` to the #incidents-archive channel. The following pieces of information are relayed: * Incident message content (clean, pingless) - * Incident author name (as webhook author) + * Incident author name (as webhook username) + * Name of user who actioned the incident (appended to webhook username) * Incident author avatar (as webhook avatar) * Resolution signal (`outcome`) @@ -194,7 +195,7 @@ class Incidents(Cog): # Now relay the incident message: discord.Message = await webhook.send( content=incident.clean_content, # Clean content will prevent mentions from pinging - username=sub_clyde(incident.author.name), + username=sub_clyde(make_username(incident.author, actioned_by)), avatar_url=incident.author.avatar_url, wait=True, # This makes the method return the sent Message object ) @@ -259,7 +260,7 @@ class Incidents(Cog): log.debug("Reaction was valid, but no action is currently defined for it") return - relay_successful = await self.archive(incident, signal) + relay_successful = await self.archive(incident, signal, actioned_by=member) if not relay_successful: log.trace("Original message will not be deleted as we failed to relay it to the archive") return diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 5700a5a35..a811868e5 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -307,7 +307,9 @@ class TestArchive(TestIncidents): propagate out of the method, which is just as important. """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) - self.assertFalse(await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock())) + + result = await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + self.assertFalse(result) async def test_archive_relays_incident(self): """ @@ -332,12 +334,18 @@ class TestArchive(TestIncidents): author=MockUser(name="author_name", avatar_url="author_avatar"), id=123, ) - archive_return = await self.cog_instance.archive(incident, outcome=MagicMock(value="A")) + + with patch("bot.cogs.moderation.incidents.make_username", MagicMock(return_value="generated_username")): + archive_return = await self.cog_instance.archive( + incident=incident, + outcome=MagicMock(value="A"), + actioned_by=MockMember(name="moderator"), + ) # Check that the webhook was dispatched correctly webhook.send.assert_called_once_with( content="pingless message", - username="author_name", + username="generated_username", avatar_url="author_avatar", wait=True, ) @@ -354,7 +362,8 @@ class TestArchive(TestIncidents): Discord will reject any webhook with "clyde" in the username field, as it impersonates the official Clyde bot. Since we do not control what the username will be (the incident - author name is used), we must ensure the name is cleansed, otherwise the relay may fail. + author name, and actioning moderator names are used), we must ensure the name is cleansed, + otherwise the relay may fail. This test assumes the username is passed as a kwarg. If this test fails, please review whether the passed argument is being retrieved correctly. @@ -362,9 +371,11 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) - await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal)) + # The `make_username` helper will return a string with "clyde" in it + with patch("bot.cogs.moderation.incidents.make_username", MagicMock(return_value="clyde the great")): + await self.cog_instance.archive(MockMessage(), MagicMock(incidents.Signal), MockMember()) + # Assert that the "clyde" was never passed to `send` self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) -- cgit v1.2.3 From 195e0f9407d2a8b7ac5b3028b4f10c1b73af0a4f Mon Sep 17 00:00:00 2001 From: Kyle Stanley Date: Fri, 26 Jun 2020 02:08:48 -0400 Subject: Update LinePaginator.add_line() tests --- tests/bot/test_pagination.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index f2e2c27ce..74896f010 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -18,18 +18,18 @@ class LinePaginatorTests(TestCase): self.assertEqual(len(self.paginator._pages), 0) def test_add_line_works_on_long_lines(self): - """`add_line` should scale long lines up to `scale_to_size`.""" - self.paginator.add_line('x' * self.paginator.scale_to_size) - self.assertEqual(len(self.paginator._pages), 1) + """After additional lines after `max_size` is exceeded should go on the next page.""" + self.paginator.add_line('x' * self.paginator.max_size) + self.assertEqual(len(self.paginator._pages), 0) # Any additional lines should start a new page after `max_size` is exceeded. self.paginator.add_line('x') - self.assertEqual(len(self.paginator._pages), 2) + self.assertEqual(len(self.paginator._pages), 1) def test_add_line_continuation(self): """When `scale_to_size` is exceeded, remaining words should be split onto the next page.""" self.paginator.add_line('zyz ' * (self.paginator.scale_to_size//4 + 1)) - self.assertEqual(len(self.paginator._pages), 2) + self.assertEqual(len(self.paginator._pages), 1) def test_add_line_no_continuation(self): """If adding a new line to an existing page would exceed `max_size`, it should start a new -- cgit v1.2.3 From 77ce4c88695ca748059a7076de88d5b42b37d5f5 Mon Sep 17 00:00:00 2001 From: Kyle Stanley Date: Fri, 26 Jun 2020 03:22:30 -0400 Subject: In LinePaginator, truncate words that exceed scale_to_size --- bot/pagination.py | 11 ++++++----- tests/bot/test_pagination.py | 12 +++++------- 2 files changed, 11 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/bot/pagination.py b/bot/pagination.py index 746ec3696..cd602c715 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -88,11 +88,9 @@ class LinePaginator(Paginator): if len(line) > (max_chars := self.max_size - len(self.prefix) - 2): if len(line) > self.scale_to_size: line, remaining_words = self._split_remaining_words(line, max_chars) - # If line still exceeds scale_to_size, we were unable to split into a second - # page without truncating. if len(line) > self.scale_to_size: - raise RuntimeError(f'Line exceeds maximum scale_to_size {self.scale_to_size}' - ' and could not be split.') + log.debug("Could not continue to next page, truncating line.") + line = line[:self.scale_to_size] if self.max_lines is not None and self._linecount >= self.max_lines: log.debug("max_lines exceeded, creating new page.") @@ -144,11 +142,14 @@ class LinePaginator(Paginator): reduced_words.append(word) reduced_char_count += len(word) + 1 else: + # If reduced_words is empty, we were unable to split the words across pages + if not reduced_words: + return line, None is_full = True remaining_words.append(word) else: remaining_words.append(word) - + return ( " ".join(reduced_words), continuation_header + " ".join(remaining_words) if remaining_words else None diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index 74896f010..ce880d457 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -39,13 +39,11 @@ class LinePaginatorTests(TestCase): self.paginator.add_line('z') self.assertEqual(len(self.paginator._pages), 1) - def test_add_line_raises_on_very_long_words(self): - """`add_line` should raise if a single long word is added that exceeds `scale_to_size`. - - Note: truncation is also a potential option, but this should not occur from normal usage. - """ - with self.assertRaises(RuntimeError): - self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) + def test_add_line_truncates_very_long_words(self): + """`add_line` should truncate if a single long word exceeds `scale_to_size`.""" + self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) + # Note: item at index 1 is the truncated line, index 0 is prefix + self.assertEqual(self.paginator._current_page[1], 'x' * self.paginator.scale_to_size) class ImagePaginatorTests(TestCase): -- cgit v1.2.3 From be4a61fb70c485262d36ca2aabf992f3118abcff Mon Sep 17 00:00:00 2001 From: kwzrd Date: Tue, 30 Jun 2020 23:09:00 +0200 Subject: Incidents: revert latest 2 commits Decision was made to use embeds to archive incidents instead of webhooking the raw message. As such, we're reverting the branch to a state from which the adjustments will be easier to make. Reverted commits: * a8d179d9b04f54b20c5e870bcfa85c78c42c8dca * 6fa8caed037b247a7c194f58a4635de7dae21fd2 --- bot/cogs/moderation/incidents.py | 33 +++--------------- tests/bot/cogs/moderation/test_incidents.py | 52 ++++------------------------- 2 files changed, 10 insertions(+), 75 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 72cc4b26c..040f2c0c8 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -41,30 +41,6 @@ ALLOWED_ROLES: t.Set[int] = {Roles.moderators, Roles.admins, Roles.owners} ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} -def make_username(reported_by: discord.Member, actioned_by: discord.Member, max_length: int = 80) -> str: - """ - Create a webhook-friendly username from the names of `reported_by` and `actioned_by`. - - If the resulting username length exceeds `max_length`, it will be capped at `max_length - 3` - and have 3 dots appended to the end. The default value is 80, which corresponds to the limit - Discord imposes on webhook username length. - - If the value of `max_length` is < 3, ValueError is raised. - """ - if max_length < 3: - raise ValueError(f"Maximum length cannot be less than 3: {max_length=}") - - username = f"{reported_by.name} | {actioned_by.name}" - log.trace(f"Generated webhook username: {username} (length: {len(username)})") - - if len(username) > max_length: - stop = max_length - 3 - username = f"{username[:stop]}..." - log.trace(f"Username capped at {max_length=}: {username}") - - return username - - def is_incident(message: discord.Message) -> bool: """True if `message` qualifies as an incident, False otherwise.""" conditions = ( @@ -172,14 +148,13 @@ class Incidents(Cog): log.debug("Crawl task finished!") - async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: + async def archive(self, incident: discord.Message, outcome: Signal) -> bool: """ Relay `incident` to the #incidents-archive channel. The following pieces of information are relayed: * Incident message content (clean, pingless) - * Incident author name (as webhook username) - * Name of user who actioned the incident (appended to webhook username) + * Incident author name (as webhook author) * Incident author avatar (as webhook avatar) * Resolution signal (`outcome`) @@ -195,7 +170,7 @@ class Incidents(Cog): # Now relay the incident message: discord.Message = await webhook.send( content=incident.clean_content, # Clean content will prevent mentions from pinging - username=sub_clyde(make_username(incident.author, actioned_by)), + username=sub_clyde(incident.author.name), avatar_url=incident.author.avatar_url, wait=True, # This makes the method return the sent Message object ) @@ -260,7 +235,7 @@ class Incidents(Cog): log.debug("Reaction was valid, but no action is currently defined for it") return - relay_successful = await self.archive(incident, signal, actioned_by=member) + relay_successful = await self.archive(incident, signal) if not relay_successful: log.trace("Original message will not be deleted as we failed to relay it to the archive") return diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index a811868e5..2fc9180cf 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -68,35 +68,6 @@ mock_404 = discord.NotFound( ) -class TestMakeUsername(unittest.TestCase): - """Collection of tests for the `make_username` helper function.""" - - def test_make_username_raises(self): - """Raises `ValueError` on `max_length` < 3.""" - with self.assertRaises(ValueError): - incidents.make_username(MockMember(), MockMember(), max_length=2) - - def test_make_username_never_exceed_limit(self): - """ - The return string length is always less than or equal to `max_length`. - - For this test we pass `max_length=10` for convenience. The name of the first - user (`reported_by`) is always 1 character in length, but we generate names - for the `actioned_by` user starting at length 1 and up to length 20. - - Finally, we assert that the output length never exceeded 10 in total. - """ - user_a = MockMember(name="A") - - max_length = 10 - test_cases = (MockMember(name="B" * n) for n in range(1, 20)) - - for user_b in test_cases: - with self.subTest(user_a=user_a, user_b=user_b, max_length=max_length): - generated_username = incidents.make_username(user_a, user_b, max_length) - self.assertLessEqual(len(generated_username), max_length) - - @patch("bot.constants.Channels.incidents", 123) class TestIsIncident(unittest.TestCase): """ @@ -307,9 +278,7 @@ class TestArchive(TestIncidents): propagate out of the method, which is just as important. """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) - - result = await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) - self.assertFalse(result) + self.assertFalse(await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock())) async def test_archive_relays_incident(self): """ @@ -334,18 +303,12 @@ class TestArchive(TestIncidents): author=MockUser(name="author_name", avatar_url="author_avatar"), id=123, ) - - with patch("bot.cogs.moderation.incidents.make_username", MagicMock(return_value="generated_username")): - archive_return = await self.cog_instance.archive( - incident=incident, - outcome=MagicMock(value="A"), - actioned_by=MockMember(name="moderator"), - ) + archive_return = await self.cog_instance.archive(incident, outcome=MagicMock(value="A")) # Check that the webhook was dispatched correctly webhook.send.assert_called_once_with( content="pingless message", - username="generated_username", + username="author_name", avatar_url="author_avatar", wait=True, ) @@ -362,8 +325,7 @@ class TestArchive(TestIncidents): Discord will reject any webhook with "clyde" in the username field, as it impersonates the official Clyde bot. Since we do not control what the username will be (the incident - author name, and actioning moderator names are used), we must ensure the name is cleansed, - otherwise the relay may fail. + author name is used), we must ensure the name is cleansed, otherwise the relay may fail. This test assumes the username is passed as a kwarg. If this test fails, please review whether the passed argument is being retrieved correctly. @@ -371,11 +333,9 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - # The `make_username` helper will return a string with "clyde" in it - with patch("bot.cogs.moderation.incidents.make_username", MagicMock(return_value="clyde the great")): - await self.cog_instance.archive(MockMessage(), MagicMock(incidents.Signal), MockMember()) + message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal)) - # Assert that the "clyde" was never passed to `send` self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) -- cgit v1.2.3 From 968251660768297383401576902a71f8ac9edada Mon Sep 17 00:00:00 2001 From: kwzrd Date: Tue, 30 Jun 2020 23:15:02 +0200 Subject: Incidents: pass `actioned_by` to `archive` This is an important piece of information that shall be relayed. --- bot/cogs/moderation/incidents.py | 4 ++-- tests/bot/cogs/moderation/test_incidents.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 040f2c0c8..580a258fe 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -148,7 +148,7 @@ class Incidents(Cog): log.debug("Crawl task finished!") - async def archive(self, incident: discord.Message, outcome: Signal) -> bool: + async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: """ Relay `incident` to the #incidents-archive channel. @@ -235,7 +235,7 @@ class Incidents(Cog): log.debug("Reaction was valid, but no action is currently defined for it") return - relay_successful = await self.archive(incident, signal) + relay_successful = await self.archive(incident, signal, actioned_by=member) if not relay_successful: log.trace("Original message will not be deleted as we failed to relay it to the archive") return diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 2fc9180cf..c2e32fe6b 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -278,7 +278,9 @@ class TestArchive(TestIncidents): propagate out of the method, which is just as important. """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) - self.assertFalse(await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock())) + self.assertFalse( + await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + ) async def test_archive_relays_incident(self): """ @@ -303,7 +305,7 @@ class TestArchive(TestIncidents): author=MockUser(name="author_name", avatar_url="author_avatar"), id=123, ) - archive_return = await self.cog_instance.archive(incident, outcome=MagicMock(value="A")) + archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) # Check that the webhook was dispatched correctly webhook.send.assert_called_once_with( @@ -334,7 +336,7 @@ class TestArchive(TestIncidents): self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) - await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal)) + await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) -- cgit v1.2.3 From dd74105d4a4433bb9e9e6fa57960a4956c0f1231 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Tue, 30 Jun 2020 23:42:32 +0200 Subject: Incidents: implement `make_embed` helper & tests See `make_embed` docstring for further information. The tests are fairly loose and should be easily adjustable in the future should changes be made. --- bot/cogs/moderation/incidents.py | 32 ++++++++++++++++++++++++++++- tests/bot/cogs/moderation/test_incidents.py | 26 +++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 580a258fe..ca591fc6e 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -1,13 +1,14 @@ import asyncio import logging import typing as t +from datetime import datetime from enum import Enum import discord from discord.ext.commands import Cog from bot.bot import Bot -from bot.constants import Channels, Emojis, Roles, Webhooks +from bot.constants import Channels, Colours, Emojis, Roles, Webhooks from bot.utils.messages import sub_clyde log = logging.getLogger(__name__) @@ -41,6 +42,35 @@ ALLOWED_ROLES: t.Set[int] = {Roles.moderators, Roles.admins, Roles.owners} ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} +def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> discord.Embed: + """ + Create an embed representation of `incident` for the #incidents-archive channel. + + The name & discriminator of `actioned_by` and `outcome` will be presented in the + embed footer. Additionally, the embed is coloured based on `outcome`. + + The author of `incident` is not shown in the embed. It is assumed that this piece + of information will be relayed in other ways, e.g. webhook username. + + As mentions in embeds do not ping, we do not need to use `incident.clean_content`. + """ + if outcome is Signal.ACTIONED: + colour = Colours.soft_green + footer = f"Actioned by {actioned_by}" + else: + colour = Colours.soft_red + footer = f"Rejected by {actioned_by}" + + embed = discord.Embed( + description=incident.content, + timestamp=datetime.utcnow(), + colour=colour, + ) + embed.set_footer(text=footer, icon_url=actioned_by.avatar_url) + + return embed + + def is_incident(message: discord.Message) -> bool: """True if `message` qualifies as an incident, False otherwise.""" conditions = ( diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index c2e32fe6b..4731a786d 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -9,6 +9,7 @@ import aiohttp import discord from bot.cogs.moderation import Incidents, incidents +from bot.constants import Colours from tests.helpers import ( MockAsyncWebhook, MockBot, @@ -68,6 +69,31 @@ mock_404 = discord.NotFound( ) +class TestMakeEmbed(unittest.TestCase): + """Collection of tests for the `make_embed` helper function.""" + + def test_make_embed_actioned(self): + """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" + embed = incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + + self.assertEqual(embed.colour.value, Colours.soft_green) + self.assertIn("Actioned", embed.footer.text) + + def test_make_embed_not_actioned(self): + """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" + embed = incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + + self.assertEqual(embed.colour.value, Colours.soft_red) + self.assertIn("Rejected", embed.footer.text) + + def test_make_embed_content(self): + """Incident content appears as embed description.""" + incident = MockMessage(content="this is an incident") + embed = incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertEqual(incident.content, embed.description) + + @patch("bot.constants.Channels.incidents", 123) class TestIsIncident(unittest.TestCase): """ -- cgit v1.2.3 From 744aed585162cb0547e61a538734f116459ab510 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Wed, 1 Jul 2020 16:52:58 +0200 Subject: Incidents: relay incidents as embeds rather than raw content This applies the previously defined `make_embed` function. As the `archive` function is now simpler, I decided to reduce the amount of whitespace ~ it's a lot more compact now. Tests are adjusted as appropriate. --- bot/cogs/moderation/incidents.py | 24 ++++++++-------------- tests/bot/cogs/moderation/test_incidents.py | 32 ++++++++++------------------- 2 files changed, 19 insertions(+), 37 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index ca591fc6e..3a1a3d84e 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -180,38 +180,30 @@ class Incidents(Cog): async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: """ - Relay `incident` to the #incidents-archive channel. + Relay an embed representation of `incident` to the #incidents-archive channel. The following pieces of information are relayed: - * Incident message content (clean, pingless) + * Incident message content (as embed description) * Incident author name (as webhook author) * Incident author avatar (as webhook avatar) - * Resolution signal (`outcome`) + * Resolution signal `outcome` (as embed colour & footer) + * Moderator `actioned_by` (name & discriminator shown in footer) Return True if the relay finishes successfully. If anything goes wrong, meaning not all information was relayed, return False. This signals that the original message is not safe to be deleted, as we will lose some information. """ - log.debug(f"Archiving incident: {incident.id} with outcome: {outcome}") + log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") try: - # First we try to grab the webhook - webhook: discord.Webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) - - # Now relay the incident - message: discord.Message = await webhook.send( - content=incident.clean_content, # Clean content will prevent mentions from pinging + webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) + await webhook.send( + embed=make_embed(incident, outcome, actioned_by), username=sub_clyde(incident.author.name), avatar_url=incident.author.avatar_url, - wait=True, # This makes the method return the sent Message object ) - - # Finally add the `outcome` emoji - await message.add_reaction(outcome.value) - except Exception: log.exception(f"Failed to archive incident {incident.id} to #incidents-archive") return False - else: log.trace("Message archived successfully!") return True diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 4731a786d..70dfe6b5f 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -312,39 +312,29 @@ class TestArchive(TestIncidents): """ If webhook is found, method relays `incident` properly. - This test will assert the following: - * The fetched webhook's `send` method is fed the correct arguments - * The message returned by `send` will have `outcome` reaction added - * Finally, the `archive` method returns True - - Assertions are made specifically in this order. + This test will assert that the fetched webhook's `send` method is fed the correct arguments, + and that the `archive` method returns True. """ - webhook_message = MockMessage() # The message that will be returned by the webhook's `send` method - webhook = MockAsyncWebhook(send=AsyncMock(return_value=webhook_message)) - + webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook - # Now we'll pas our own `incident` to `archive` and capture the return value + # Define our own `incident` for archivation incident = MockMessage( - clean_content="pingless message", - content="pingful message", + content="this is an incident", author=MockUser(name="author_name", avatar_url="author_avatar"), id=123, ) - archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) + built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this - # Check that the webhook was dispatched correctly + with patch("bot.cogs.moderation.incidents.make_embed", MagicMock(return_value=built_embed)): + archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) + + # Now we check that the webhook was given the correct args, and that `archive` returned True webhook.send.assert_called_once_with( - content="pingless message", + embed=built_embed, username="author_name", avatar_url="author_avatar", - wait=True, ) - - # Now check that the correct emoji was added to the relayed message - webhook_message.add_reaction.assert_called_once_with("A") - - # Finally check that the method returned True self.assertTrue(archive_return) async def test_archive_clyde_username(self): -- cgit v1.2.3 From 83544ca0f91dd7bc8510e4fc7a64bc73712ddaf8 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Fri, 3 Jul 2020 10:47:47 +0200 Subject: Incidents: archive incident attachments There is no handling of file types as explained in the `archive` docstring. Testing indicates that relaying incidents with e.g. a text file attachment is simply a noop in the Discord GUI. If there is at least one attachment, we always only relay the one at index 0, as it is believed the user-sent messages can only contain one attachment at maximum. This also adds an extra test asserting the behaviour when an incident with an attachment is archived. The existing test for `archive` is adjusted to assume no attachments. Joe helped me conceive & test this. Co-authored-by: Joseph Banks --- bot/cogs/moderation/incidents.py | 21 +++++++++++++++++++- tests/bot/cogs/moderation/test_incidents.py | 30 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 8970c2c5c..1a12c8bbd 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -186,22 +186,41 @@ class Incidents(Cog): The following pieces of information are relayed: * Incident message content (as embed description) + * Incident attachment (if image, shown in archive embed) * Incident author name (as webhook author) * Incident author avatar (as webhook avatar) * Resolution signal `outcome` (as embed colour & footer) * Moderator `actioned_by` (name & discriminator shown in footer) + If `incident` contains an attachment, we try to add it to the archive embed. There is + no handing of extensions / file types - we simply dispatch the attachment file with the + webhook, and try to display it in the embed. Testing indicates that if the attachment + cannot be displayed (e.g. a text file), it's invisible in the embed, with no error. + Return True if the relay finishes successfully. If anything goes wrong, meaning not all information was relayed, return False. This signals that the original message is not safe to be deleted, as we will lose some information. """ log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") + embed = make_embed(incident, outcome, actioned_by) + + # If the incident had an attachment, we will try to relay it + if incident.attachments: + attachment = incident.attachments[0] # User-sent messages can only contain one attachment + log.debug(f"Attempting to archive incident attachment: {attachment.filename}") + + attachment_file = await attachment.to_file() # The file will be sent with the webhook + embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file + else: + attachment_file = None + try: webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) await webhook.send( - embed=make_embed(incident, outcome, actioned_by), + embed=embed, username=sub_clyde(incident.author.name), avatar_url=incident.author.avatar_url, + file=attachment_file, ) except Exception: log.exception(f"Failed to archive incident {incident.id} to #incidents-archive") diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 70dfe6b5f..f8d479cef 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -323,6 +323,7 @@ class TestArchive(TestIncidents): content="this is an incident", author=MockUser(name="author_name", avatar_url="author_avatar"), id=123, + attachments=[], # This incident has no attachments ) built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this @@ -334,9 +335,38 @@ class TestArchive(TestIncidents): embed=built_embed, username="author_name", avatar_url="author_avatar", + file=None, ) self.assertTrue(archive_return) + async def test_archive_relays_incident_with_attachments(self): + """ + Incident attachments are relayed and displayed in the embed. + + This test asserts the two things that need to happen in order to relay the attachment. + The embed returned by `make_embed` must have the `set_image` method called with the + attachment's filename, and the file must be passed to the webhook's send method. + """ + attachment_file = MagicMock(discord.File) + attachment = MagicMock( + discord.Attachment, + filename="abc.png", + to_file=AsyncMock(return_value=attachment_file), + ) + incident = MockMessage( + attachments=[attachment], + ) + built_embed = MagicMock(discord.Embed) + + with patch("bot.cogs.moderation.incidents.make_embed", MagicMock(return_value=built_embed)): + await self.cog_instance.archive(incident, incidents.Signal.ACTIONED, actioned_by=MockMember()) + + built_embed.set_image.assert_called_once_with(url="attachment://abc.png") + + send_kwargs = self.cog_instance.bot.fetch_webhook.return_value.send.call_args.kwargs + self.assertIn("file", send_kwargs) + self.assertIs(send_kwargs["file"], attachment_file) + async def test_archive_clyde_username(self): """ The archive webhook username is cleansed using `sub_clyde`. -- cgit v1.2.3 From 40719793f9c0d8a2c5761d3730b5920a146709c3 Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 04:08:27 +0000 Subject: Add tests for cog_check and get_slowmode --- tests/bot/cogs/test_slowmode.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/bot/cogs/test_slowmode.py (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py new file mode 100644 index 000000000..fb9f3c9ad --- /dev/null +++ b/tests/bot/cogs/test_slowmode.py @@ -0,0 +1,37 @@ +import unittest +from unittest import mock + +from bot.cogs.slowmode import Slowmode +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SlowmodeTests(unittest.IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Slowmode(self.bot) + self.text_channel = MockTextChannel() + self.ctx = MockContext(channel=self.text_channel) + + async def test_get_slowmode_no_channel(self) -> None: + """Get slowmode without a given channel""" + self.text_channel.mention = '#python-general' + self.text_channel.slowmode_delay = 5 + + await self.cog.get_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") + + async def test_get_slowmode_with_channel(self) -> None: + """Get slowmode without a given channel""" + self.text_channel.mention = '#python-language' + self.text_channel.slowmode_delay = 2 + + await self.cog.get_slowmode(self.cog, self.ctx, self.text_channel) + self.ctx.send.assert_called_once_with("The slowmode delay for #python-language is 2 seconds.") + + @mock.patch("bot.cogs.slowmode.with_role_check") + @mock.patch("bot.cogs.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) -- cgit v1.2.3 From e760b4312a5264fe9442cb1d53c9e357dbeb2b81 Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 04:55:42 +0000 Subject: Add tests for reset_slowmode --- tests/bot/cogs/test_slowmode.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py index fb9f3c9ad..a2e5ad346 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/cogs/test_slowmode.py @@ -2,6 +2,7 @@ import unittest from unittest import mock from bot.cogs.slowmode import Slowmode +from bot.constants import Emojis from tests.helpers import MockBot, MockContext, MockTextChannel @@ -14,7 +15,7 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(channel=self.text_channel) async def test_get_slowmode_no_channel(self) -> None: - """Get slowmode without a given channel""" + """Get slowmode without a given channel.""" self.text_channel.mention = '#python-general' self.text_channel.slowmode_delay = 5 @@ -22,12 +23,30 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") async def test_get_slowmode_with_channel(self) -> None: - """Get slowmode without a given channel""" + """Get slowmode with a given channel.""" self.text_channel.mention = '#python-language' self.text_channel.slowmode_delay = 2 await self.cog.get_slowmode(self.cog, self.ctx, self.text_channel) - self.ctx.send.assert_called_once_with("The slowmode delay for #python-language is 2 seconds.") + self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + + async def test_reset_slowmode_no_channel(self) -> None: + """Reset slowmode without a given channel.""" + self.text_channel.mention = '#careers' + + await self.cog.reset_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' + ) + + async def test_reset_slowmode_with_channel(self) -> None: + """Reset slowmode with a given channel.""" + self.text_channel.mention = '#meta' + + await self.cog.reset_slowmode(self.cog, self.ctx, self.text_channel) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' + ) @mock.patch("bot.cogs.slowmode.with_role_check") @mock.patch("bot.cogs.slowmode.MODERATION_ROLES", new=(1, 2, 3)) -- cgit v1.2.3 From 8613659cb191bedca925dc798c89623b49c9a90a Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 05:45:04 +0000 Subject: Add tests for set_slowmode --- tests/bot/cogs/test_slowmode.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py index a2e5ad346..5262ce34a 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/cogs/test_slowmode.py @@ -1,6 +1,8 @@ import unittest from unittest import mock +from dateutil.relativedelta import relativedelta + from bot.cogs.slowmode import Slowmode from bot.constants import Emojis from tests.helpers import MockBot, MockContext, MockTextChannel @@ -30,6 +32,24 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): await self.cog.get_slowmode(self.cog, self.ctx, self.text_channel) self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + async def test_set_slowmode_no_channel(self) -> None: + """Set slowmode without a given channel.""" + self.text_channel.mention = '#careers' + + await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=3)) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #careers is now 3 seconds.' + ) + + async def test_set_slowmode_with_channel(self) -> None: + """Set slowmode with a given channel.""" + self.text_channel.mention = '#meta' + + await self.cog.set_slowmode(self.cog, self.ctx, self.text_channel, relativedelta(seconds=4)) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #meta is now 4 seconds.' + ) + async def test_reset_slowmode_no_channel(self) -> None: """Reset slowmode without a given channel.""" self.text_channel.mention = '#careers' -- cgit v1.2.3 From 4935ed5ae632f5887bcff23ac67c781eab8527e9 Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 06:05:32 +0000 Subject: Use local text_channel instead of instance attribute --- tests/bot/cogs/test_slowmode.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py index 5262ce34a..663c9fd43 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/cogs/test_slowmode.py @@ -13,28 +13,25 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot() self.cog = Slowmode(self.bot) - self.text_channel = MockTextChannel() - self.ctx = MockContext(channel=self.text_channel) + self.ctx = MockContext() async def test_get_slowmode_no_channel(self) -> None: """Get slowmode without a given channel.""" - self.text_channel.mention = '#python-general' - self.text_channel.slowmode_delay = 5 + self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) await self.cog.get_slowmode(self.cog, self.ctx, None) self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") async def test_get_slowmode_with_channel(self) -> None: """Get slowmode with a given channel.""" - self.text_channel.mention = '#python-language' - self.text_channel.slowmode_delay = 2 + text_channel = MockTextChannel(name='python-language', slowmode_delay=2) - await self.cog.get_slowmode(self.cog, self.ctx, self.text_channel) + await self.cog.get_slowmode(self.cog, self.ctx, text_channel) self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') async def test_set_slowmode_no_channel(self) -> None: """Set slowmode without a given channel.""" - self.text_channel.mention = '#careers' + self.ctx.channel = MockTextChannel(name='careers') await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=3)) self.ctx.send.assert_called_once_with( @@ -43,16 +40,16 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): async def test_set_slowmode_with_channel(self) -> None: """Set slowmode with a given channel.""" - self.text_channel.mention = '#meta' + text_channel = MockTextChannel(name='meta') - await self.cog.set_slowmode(self.cog, self.ctx, self.text_channel, relativedelta(seconds=4)) + await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=4)) self.ctx.send.assert_called_once_with( f'{Emojis.check_mark} The slowmode delay for #meta is now 4 seconds.' ) async def test_reset_slowmode_no_channel(self) -> None: """Reset slowmode without a given channel.""" - self.text_channel.mention = '#careers' + self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) await self.cog.reset_slowmode(self.cog, self.ctx, None) self.ctx.send.assert_called_once_with( @@ -61,9 +58,9 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): async def test_reset_slowmode_with_channel(self) -> None: """Reset slowmode with a given channel.""" - self.text_channel.mention = '#meta' + text_channel = MockTextChannel(name='meta', slowmode_delay=1) - await self.cog.reset_slowmode(self.cog, self.ctx, self.text_channel) + await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) self.ctx.send.assert_called_once_with( f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' ) -- cgit v1.2.3 From 77a2e514dd2e200e23ccf45760677c2e7c40b9ff Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 06:11:00 +0000 Subject: Add multiple test cases for set_slowmode tests --- tests/bot/cogs/test_slowmode.py | 44 +++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py index 663c9fd43..e9835b8bd 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/cogs/test_slowmode.py @@ -31,22 +31,46 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): async def test_set_slowmode_no_channel(self) -> None: """Set slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='careers') - - await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=3)) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #careers is now 3 seconds.' + test_cases = ( + ('helpers', 23, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), + ('mods', 76526, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), + ('admins', 97, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') ) + for channel_name, seconds, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + result_msg=result_msg + ): + self.ctx.channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + async def test_set_slowmode_with_channel(self) -> None: """Set slowmode with a given channel.""" - text_channel = MockTextChannel(name='meta') - - await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=4)) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #meta is now 4 seconds.' + test_cases = ( + ('bot-commands', 12, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), + ('mod-spam', 21, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), + ('admin-spam', 4323598, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') ) + for channel_name, seconds, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + result_msg=result_msg + ): + text_channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + async def test_reset_slowmode_no_channel(self) -> None: """Reset slowmode without a given channel.""" self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) -- cgit v1.2.3 From 2d170b8af92c77bedea4d77fbdeedc515d3f2c59 Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 17:08:24 +0000 Subject: Improve set_slowmode tests by checking whether the channel was edited --- tests/bot/cogs/test_slowmode.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py index e9835b8bd..65b1534cb 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/cogs/test_slowmode.py @@ -32,20 +32,27 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): async def test_set_slowmode_no_channel(self) -> None: """Set slowmode without a given channel.""" test_cases = ( - ('helpers', 23, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), - ('mods', 76526, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), - ('admins', 97, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') + ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), + ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), + ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') ) - for channel_name, seconds, result_msg in test_cases: + for channel_name, seconds, edited, result_msg in test_cases: with self.subTest( channel_mention=channel_name, seconds=seconds, + edited=edited, result_msg=result_msg ): self.ctx.channel = MockTextChannel(name=channel_name) await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + + if edited: + self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + self.ctx.channel.edit.assert_not_called() + self.ctx.send.assert_called_once_with(result_msg) self.ctx.reset_mock() @@ -53,20 +60,27 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): async def test_set_slowmode_with_channel(self) -> None: """Set slowmode with a given channel.""" test_cases = ( - ('bot-commands', 12, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), - ('mod-spam', 21, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), - ('admin-spam', 4323598, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') + ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), + ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), + ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') ) - for channel_name, seconds, result_msg in test_cases: + for channel_name, seconds, edited, result_msg in test_cases: with self.subTest( channel_mention=channel_name, seconds=seconds, + edited=edited, result_msg=result_msg ): text_channel = MockTextChannel(name=channel_name) await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + + if edited: + text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + text_channel.edit.assert_not_called() + self.ctx.send.assert_called_once_with(result_msg) self.ctx.reset_mock() -- cgit v1.2.3 From cdeb41bfd283cb6cb1285993737e8e3abd5aea9f Mon Sep 17 00:00:00 2001 From: Den4200 Date: Mon, 6 Jul 2020 17:30:44 +0000 Subject: Fix imports in slowmode tests --- tests/bot/cogs/test_slowmode.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py index 65b1534cb..f442814c8 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/cogs/test_slowmode.py @@ -3,7 +3,7 @@ from unittest import mock from dateutil.relativedelta import relativedelta -from bot.cogs.slowmode import Slowmode +from bot.cogs.moderation.slowmode import Slowmode from bot.constants import Emojis from tests.helpers import MockBot, MockContext, MockTextChannel @@ -103,8 +103,8 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' ) - @mock.patch("bot.cogs.slowmode.with_role_check") - @mock.patch("bot.cogs.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + @mock.patch("bot.cogs.moderation.slowmode.with_role_check") + @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) def test_cog_check(self, role_check): """Role check is called with `MODERATION_ROLES`""" self.cog.cog_check(self.ctx) -- cgit v1.2.3 From ddb1f556ace346a97b8639f278fae8915078e78d Mon Sep 17 00:00:00 2001 From: kwzrd Date: Thu, 9 Jul 2020 12:11:16 +0200 Subject: Incidents tests: improve in-line comment wording Co-authored-by: MarkKoz --- tests/bot/cogs/moderation/test_incidents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index f8d479cef..789a37cd4 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -318,7 +318,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook - # Define our own `incident` for archivation + # Define our own `incident` to be archived incident = MockMessage( content="this is an incident", author=MockUser(name="author_name", avatar_url="author_avatar"), -- cgit v1.2.3 From df1730ef5d51223fe1d5a2cfe8c027e5177ae9c7 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 12 Jul 2020 16:30:03 +0200 Subject: Fix DuckPond tests now that send_webhook is gone. Some of the tests were failing because they were expecting send_webhook to be a method of the DuckPond cog, other tests simply were no longer applicable, and have been removed. https://github.com/python-discord/bot/issues/667 --- tests/bot/cogs/test_duck_pond.py | 51 ++++++++++------------------------------ 1 file changed, 12 insertions(+), 39 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index a8c0107c6..cfe10aebf 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -129,38 +129,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): ): self.assertEqual(expected_return, actual_return) - def test_send_webhook_correctly_passes_on_arguments(self): - """The `send_webhook` method should pass the arguments to the webhook correctly.""" - self.cog.webhook = helpers.MockAsyncWebhook() - - content = "fake content" - username = "fake username" - avatar_url = "fake avatar_url" - embed = "fake embed" - - asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed)) - - self.cog.webhook.send.assert_called_once_with( - content=content, - username=username, - avatar_url=avatar_url, - embed=embed - ) - - def test_send_webhook_logs_when_sending_message_fails(self): - """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly.""" - self.cog.webhook = helpers.MockAsyncWebhook() - self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.") - - log = logging.getLogger('bot.cogs.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.send_webhook()) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - def _get_reaction( self, emoji: typing.Union[str, helpers.MockEmoji], @@ -280,16 +248,20 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): async def test_relay_message_correctly_relays_content_and_attachments(self): """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" + send_webhook_path = f"{MODULE_PATH}.send_webhook" send_attachments_path = f"{MODULE_PATH}.send_attachments" + author = MagicMock( + display_name="x", + avatar_url="https://" + ) self.cog.webhook = helpers.MockAsyncWebhook() test_values = ( - (helpers.MockMessage(clean_content="", attachments=[]), False, False), - (helpers.MockMessage(clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True), + (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), + (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), ) for message, expect_webhook_call, expect_attachment_call in test_values: @@ -314,14 +286,14 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): for side_effect in side_effects: # pragma: no cover send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook: + with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: with self.subTest(side_effect=type(side_effect).__name__): with self.assertNotLogs(logger=log, level=logging.ERROR): await self.cog.relay_message(message) self.assertEqual(send_webhook.call_count, 2) - @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): """The `relay_message` method should handle irretrievable attachments.""" @@ -337,6 +309,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): await self.cog.relay_message(message) send_webhook.assert_called_once_with( + webhook=self.cog.webhook, content=message.clean_content, username=message.author.display_name, avatar_url=message.author.avatar_url -- cgit v1.2.3 From c4e9060a76a901c7d2e6035e6ca19d51770a4ab3 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Mon, 13 Jul 2020 15:04:40 +0200 Subject: Incidents: add `download_file` helper & tests Co-authored-by: MarkKoz --- bot/cogs/moderation/incidents.py | 13 +++++++++++++ tests/bot/cogs/moderation/test_incidents.py | 20 ++++++++++++++++++++ 2 files changed, 33 insertions(+) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index be46c8202..65b0e458e 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -42,6 +42,19 @@ ALLOWED_ROLES: t.Set[int] = set(Guild.moderation_roles) ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} +async def download_file(attachment: discord.Attachment) -> t.Optional[discord.File]: + """ + Download & return `attachment` file. + + If the download fails, the reason is logged and None will be returned. + """ + log.debug(f"Attempting to download attachment: {attachment.filename}") + try: + return await attachment.to_file() + except Exception: + log.exception("Failed to download attachment") + + def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> discord.Embed: """ Create an embed representation of `incident` for the #incidents-archive channel. diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 789a37cd4..273916199 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -12,6 +12,7 @@ from bot.cogs.moderation import Incidents, incidents from bot.constants import Colours from tests.helpers import ( MockAsyncWebhook, + MockAttachment, MockBot, MockMember, MockMessage, @@ -69,6 +70,25 @@ mock_404 = discord.NotFound( ) +class TestDownloadFile(unittest.IsolatedAsyncioTestCase): + """Collection of tests for the `download_file` helper function.""" + + async def test_download_file_success(self): + """If `to_file` succeeds, function returns the acquired `discord.File`.""" + file = MagicMock(discord.File, filename="bigbadlemon.jpg") + attachment = MockAttachment(to_file=AsyncMock(return_value=file)) + + acquired_file = await incidents.download_file(attachment) + self.assertIs(file, acquired_file) + + async def test_download_file_fail(self): + """If `to_file` fails, function handles the exception & returns None.""" + attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) + + acquired_file = await incidents.download_file(attachment) + self.assertIsNone(acquired_file) + + class TestMakeEmbed(unittest.TestCase): """Collection of tests for the `make_embed` helper function.""" -- cgit v1.2.3 From f1b1d0cb723abbbf7d4b49ac4b42fe0b7f266692 Mon Sep 17 00:00:00 2001 From: Slushie Date: Mon, 13 Jul 2020 16:09:08 +0100 Subject: edit snekbox tests to work with filtering --- tests/bot/cogs/test_snekbox.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index cf9adbee0..98dee7a1b 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -233,6 +233,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + await self.cog.send_eval(ctx, 'MyAwesomeCode') ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' @@ -254,6 +258,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + await self.cog.send_eval(ctx, 'MyAwesomeCode') ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.' @@ -275,6 +283,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + await self.cog.send_eval(ctx, 'MyAwesomeCode') ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' -- cgit v1.2.3 From 6e48d666b31d13d801c394b527ce545b039b478f Mon Sep 17 00:00:00 2001 From: kwzrd Date: Tue, 14 Jul 2020 17:28:53 +0200 Subject: Incidents: link `proxy_url` if attachment fails to download Suggested by Mark during review. If the download fails, we fallback on showing an informative message, which will link the attachment cdn link. The attachment-handling logic was moved from the `archive` coroutine into `make_embed`, which now also returns the file, if available. In the end, this appears to be the smoothest approach. Co-authored-by: MarkKoz --- bot/cogs/moderation/incidents.py | 36 +++++++++----- tests/bot/cogs/moderation/test_incidents.py | 73 ++++++++++++++--------------- 2 files changed, 59 insertions(+), 50 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 65b0e458e..018538040 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -41,6 +41,10 @@ ALLOWED_ROLES: t.Set[int] = set(Guild.moderation_roles) # Message must have all of these emoji to pass the `has_signals` check ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} +# An embed coupled with an optional file to be dispatched +# If the file is not None, the embed attempts to show it in its body +FileEmbed = t.Tuple[discord.Embed, t.Optional[discord.File]] + async def download_file(attachment: discord.Attachment) -> t.Optional[discord.File]: """ @@ -55,7 +59,7 @@ async def download_file(attachment: discord.Attachment) -> t.Optional[discord.Fi log.exception("Failed to download attachment") -def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> discord.Embed: +async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: """ Create an embed representation of `incident` for the #incidents-archive channel. @@ -66,6 +70,11 @@ def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord. of information will be relayed in other ways, e.g. webhook username. As mentions in embeds do not ping, we do not need to use `incident.clean_content`. + + If `incident` contains attachments, the first attachment will be downloaded and + returned alongside the embed. The embed attempts to display the attachment. + Should the download fail, we fallback on linking the `proxy_url`, which should + remain functional for some time after the original message is deleted. """ log.trace(f"Creating embed for {incident.id=}") @@ -83,7 +92,18 @@ def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord. ) embed.set_footer(text=footer, icon_url=actioned_by.avatar_url) - return embed + if incident.attachments: + attachment = incident.attachments[0] # User-sent messages can only contain one attachment + file = await download_file(attachment) + + if file is not None: + embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file + else: + embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file + else: + file = None + + return embed, file def is_incident(message: discord.Message) -> bool: @@ -215,17 +235,7 @@ class Incidents(Cog): message is not safe to be deleted, as we will lose some information. """ log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") - embed = make_embed(incident, outcome, actioned_by) - - # If the incident had an attachment, we will try to relay it - if incident.attachments: - attachment = incident.attachments[0] # User-sent messages can only contain one attachment - log.debug(f"Attempting to archive incident attachment: {attachment.filename}") - - attachment_file = await attachment.to_file() # The file will be sent with the webhook - embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file - else: - attachment_file = None + embed, attachment_file = await make_embed(incident, outcome, actioned_by) try: webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 273916199..9b6054f55 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -89,30 +89,58 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): self.assertIsNone(acquired_file) -class TestMakeEmbed(unittest.TestCase): +class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): """Collection of tests for the `make_embed` helper function.""" - def test_make_embed_actioned(self): + async def test_make_embed_actioned(self): """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed = incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) self.assertEqual(embed.colour.value, Colours.soft_green) self.assertIn("Actioned", embed.footer.text) - def test_make_embed_not_actioned(self): + async def test_make_embed_not_actioned(self): """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed = incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) self.assertEqual(embed.colour.value, Colours.soft_red) self.assertIn("Rejected", embed.footer.text) - def test_make_embed_content(self): + async def test_make_embed_content(self): """Incident content appears as embed description.""" incident = MockMessage(content="this is an incident") - embed = incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) self.assertEqual(incident.content, embed.description) + async def test_make_embed_with_attachment_succeeds(self): + """Incident's attachment is downloaded and displayed in the embed's image field.""" + file = MagicMock(discord.File, filename="bigbadjoe.jpg") + attachment = MockAttachment(filename="bigbadjoe.jpg") + incident = MockMessage(content="this is an incident", attachments=[attachment]) + + # Patch `download_file` to return our `file` + with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=file)): + embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertIs(file, returned_file) + self.assertEqual("attachment://bigbadjoe.jpg", embed.image.url) + + async def test_make_embed_with_attachment_fails(self): + """Incident's attachment fails to download, proxy url is linked instead.""" + attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") + incident = MockMessage(content="this is an incident", attachments=[attachment]) + + # Patch `download_file` to return None as if the download failed + with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=None)): + embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertIsNone(returned_file) + + # The author name field is simply expected to have something in it, we do not assert the message + self.assertGreater(len(embed.author.name), 0) + self.assertEqual(embed.author.url, "discord.com/bigbadjoe.jpg") # However, it should link the exact url + @patch("bot.constants.Channels.incidents", 123) class TestIsIncident(unittest.TestCase): @@ -343,11 +371,10 @@ class TestArchive(TestIncidents): content="this is an incident", author=MockUser(name="author_name", avatar_url="author_avatar"), id=123, - attachments=[], # This incident has no attachments ) built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this - with patch("bot.cogs.moderation.incidents.make_embed", MagicMock(return_value=built_embed)): + with patch("bot.cogs.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) # Now we check that the webhook was given the correct args, and that `archive` returned True @@ -359,34 +386,6 @@ class TestArchive(TestIncidents): ) self.assertTrue(archive_return) - async def test_archive_relays_incident_with_attachments(self): - """ - Incident attachments are relayed and displayed in the embed. - - This test asserts the two things that need to happen in order to relay the attachment. - The embed returned by `make_embed` must have the `set_image` method called with the - attachment's filename, and the file must be passed to the webhook's send method. - """ - attachment_file = MagicMock(discord.File) - attachment = MagicMock( - discord.Attachment, - filename="abc.png", - to_file=AsyncMock(return_value=attachment_file), - ) - incident = MockMessage( - attachments=[attachment], - ) - built_embed = MagicMock(discord.Embed) - - with patch("bot.cogs.moderation.incidents.make_embed", MagicMock(return_value=built_embed)): - await self.cog_instance.archive(incident, incidents.Signal.ACTIONED, actioned_by=MockMember()) - - built_embed.set_image.assert_called_once_with(url="attachment://abc.png") - - send_kwargs = self.cog_instance.bot.fetch_webhook.return_value.send.call_args.kwargs - self.assertIn("file", send_kwargs) - self.assertIs(send_kwargs["file"], attachment_file) - async def test_archive_clyde_username(self): """ The archive webhook username is cleansed using `sub_clyde`. -- cgit v1.2.3 From c115dcfb72e4d4a86b66bb84a72984705a2afcd4 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 15 Jul 2020 02:45:31 +0200 Subject: Change tests to work with the new file layout. 326beebe9b097731a39ecc9868e5e1f2bd762aae --- tests/bot/utils/test_init.py | 74 ---------------------------------------- tests/bot/utils/test_services.py | 74 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 74 deletions(-) delete mode 100644 tests/bot/utils/test_init.py create mode 100644 tests/bot/utils/test_services.py (limited to 'tests') diff --git a/tests/bot/utils/test_init.py b/tests/bot/utils/test_init.py deleted file mode 100644 index f3a8f5939..000000000 --- a/tests/bot/utils/test_init.py +++ /dev/null @@ -1,74 +0,0 @@ -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -from aiohttp import ClientConnectorError - -from bot.utils import FAILED_REQUEST_ATTEMPTS, send_to_paste_service - - -class PasteTests(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - self.http_session = MagicMock() - - @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") - async def test_url_and_sent_contents(self): - """Correct url was used and post was called with expected data.""" - response = MagicMock( - json=AsyncMock(return_value={"key": ""}) - ) - self.http_session.post().__aenter__.return_value = response - self.http_session.post.reset_mock() - await send_to_paste_service(self.http_session, "Content") - self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") - - @patch("bot.utils.URLs.paste_service", "https://paste_service.com/{key}") - async def test_paste_returns_correct_url_on_success(self): - """Url with specified extension is returned on successful requests.""" - key = "paste_key" - test_cases = ( - (f"https://paste_service.com/{key}.txt", "txt"), - (f"https://paste_service.com/{key}.py", "py"), - (f"https://paste_service.com/{key}", ""), - ) - response = MagicMock( - json=AsyncMock(return_value={"key": key}) - ) - self.http_session.post().__aenter__.return_value = response - - for expected_output, extension in test_cases: - with self.subTest(msg=f"Send contents with extension {repr(extension)}"): - self.assertEqual( - await send_to_paste_service(self.http_session, "", extension=extension), - expected_output - ) - - async def test_request_repeated_on_json_errors(self): - """Json with error message and invalid json are handled as errors and requests repeated.""" - test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) - self.http_session.post().__aenter__.return_value = response = MagicMock() - self.http_session.post.reset_mock() - - for error_json in test_cases: - with self.subTest(error_json=error_json): - response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) - - self.http_session.post.reset_mock() - - async def test_request_repeated_on_connection_errors(self): - """Requests are repeated in the case of connection errors.""" - self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) - - async def test_general_error_handled_and_request_repeated(self): - """All `Exception`s are handled, logged and request repeated.""" - self.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertLogs("bot.utils", logging.ERROR) - self.assertIsNone(result) diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py new file mode 100644 index 000000000..5e0855704 --- /dev/null +++ b/tests/bot/utils/test_services.py @@ -0,0 +1,74 @@ +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +from aiohttp import ClientConnectorError + +from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service + + +class PasteTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.http_session = MagicMock() + + @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") + async def test_url_and_sent_contents(self): + """Correct url was used and post was called with expected data.""" + response = MagicMock( + json=AsyncMock(return_value={"key": ""}) + ) + self.http_session.post().__aenter__.return_value = response + self.http_session.post.reset_mock() + await send_to_paste_service(self.http_session, "Content") + self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + + @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") + async def test_paste_returns_correct_url_on_success(self): + """Url with specified extension is returned on successful requests.""" + key = "paste_key" + test_cases = ( + (f"https://paste_service.com/{key}.txt", "txt"), + (f"https://paste_service.com/{key}.py", "py"), + (f"https://paste_service.com/{key}", ""), + ) + response = MagicMock( + json=AsyncMock(return_value={"key": key}) + ) + self.http_session.post().__aenter__.return_value = response + + for expected_output, extension in test_cases: + with self.subTest(msg=f"Send contents with extension {repr(extension)}"): + self.assertEqual( + await send_to_paste_service(self.http_session, "", extension=extension), + expected_output + ) + + async def test_request_repeated_on_json_errors(self): + """Json with error message and invalid json are handled as errors and requests repeated.""" + test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) + self.http_session.post().__aenter__.return_value = response = MagicMock() + self.http_session.post.reset_mock() + + for error_json in test_cases: + with self.subTest(error_json=error_json): + response.json = AsyncMock(return_value=error_json) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + self.http_session.post.reset_mock() + + async def test_request_repeated_on_connection_errors(self): + """Requests are repeated in the case of connection errors.""" + self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + async def test_general_error_handled_and_request_repeated(self): + """All `Exception`s are handled, logged and request repeated.""" + self.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertLogs("bot.utils", logging.ERROR) + self.assertIsNone(result) -- cgit v1.2.3 From 6f5fb205bcc3f9b468ef585f83e123e5b19d7340 Mon Sep 17 00:00:00 2001 From: kwzrd Date: Thu, 16 Jul 2020 17:03:02 +0200 Subject: Incidents: reduce log level of 404 exception Co-authored-by: MarkKoz --- bot/cogs/moderation/incidents.py | 2 ++ tests/bot/cogs/moderation/test_incidents.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 018538040..2d5f26f20 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -55,6 +55,8 @@ async def download_file(attachment: discord.Attachment) -> t.Optional[discord.Fi log.debug(f"Attempting to download attachment: {attachment.filename}") try: return await attachment.to_file() + except discord.NotFound as not_found: + log.debug(f"Failed to download attachment: {not_found}") except Exception: log.exception("Failed to download attachment") diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 9b6054f55..435a1cd51 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -81,13 +81,23 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): acquired_file = await incidents.download_file(attachment) self.assertIs(file, acquired_file) - async def test_download_file_fail(self): - """If `to_file` fails, function handles the exception & returns None.""" + async def test_download_file_404(self): + """If `to_file` encounters a 404, function handles the exception & returns None.""" attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) acquired_file = await incidents.download_file(attachment) self.assertIsNone(acquired_file) + async def test_download_file_fail(self): + """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" + arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) + + with self.assertLogs(logger=incidents.log, level=logging.ERROR): + acquired_file = await incidents.download_file(attachment) + + self.assertIsNone(acquired_file) + class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): """Collection of tests for the `make_embed` helper function.""" -- cgit v1.2.3 From 1c569f2f38fe18d6210deec001046cf9ee68ea53 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 18 Jul 2020 16:54:01 +0200 Subject: Remove AntiMalWare constants, use cache data. Also updates the tests for this cog. --- bot/bot.py | 2 +- bot/cogs/antimalware.py | 24 ++++++++++++++---------- bot/constants.py | 6 ------ config-default.yml | 29 ----------------------------- tests/bot/cogs/test_antimalware.py | 24 +++++++++++++++--------- 5 files changed, 30 insertions(+), 55 deletions(-) (limited to 'tests') diff --git a/bot/bot.py b/bot/bot.py index 6c02e72a7..962c8dd93 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -34,6 +34,7 @@ class Bot(commands.Bot): self.redis_ready = asyncio.Event() self.redis_closed = False self.api_client = api.APIClient(loop=self.loop) + self.allow_deny_list_cache = {} self._connector = None self._resolver = None @@ -52,7 +53,6 @@ class Bot(commands.Bot): async def _cache_allow_deny_list_data(self) -> None: """Cache all the data in the AllowDenyList on the site.""" full_cache = await self.api_client.get('bot/allow_deny_lists') - self.allow_deny_list_cache = {} for item in full_cache: type_ = item.get("type") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index ea257442e..38ff1133d 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -6,7 +6,7 @@ from discord import Embed, Message, NotFound from discord.ext.commands import Cog from bot.bot import Bot -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs +from bot.constants import Channels, STAFF_ROLES, URLs log = logging.getLogger(__name__) @@ -27,7 +27,7 @@ TXT_EMBED_DESCRIPTION = ( DISALLOWED_EMBED_DESCRIPTION = ( "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n" + "We currently allow the following file types: **{joined_whitelist}**.\n\n" "Feel free to ask in {meta_channel_mention} if you think this is a mistake." ) @@ -38,6 +38,16 @@ class AntiMalware(Cog): def __init__(self, bot: Bot): self.bot = bot + def _get_whitelisted_file_formats(self) -> list: + """Get the file formats currently on the whitelist.""" + return [item.get('content') for item in self.bot.allow_deny_list_cache['file_format.True']] + + def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) + return extensions_blocked + @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" @@ -51,7 +61,7 @@ class AntiMalware(Cog): return embed = Embed() - extensions_blocked = self.get_disallowed_extensions(message) + extensions_blocked = self._get_disallowed_extensions(message) blocked_extensions_str = ', '.join(extensions_blocked) if ".py" in extensions_blocked: # Short-circuit on *.py files to provide a pastebin link @@ -63,6 +73,7 @@ class AntiMalware(Cog): elif extensions_blocked: meta_channel = self.bot.get_channel(Channels.meta) embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=', '.join(self._get_whitelisted_file_formats()), blocked_extensions_str=blocked_extensions_str, meta_channel_mention=meta_channel.mention, ) @@ -81,13 +92,6 @@ class AntiMalware(Cog): except NotFound: log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - @classmethod - def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]: - """Get an iterable containing all the disallowed extensions of attachments.""" - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) - return extensions_blocked - def setup(bot: Bot) -> None: """Load the AntiMalware cog.""" diff --git a/bot/constants.py b/bot/constants.py index f5245ca50..857e6c4f0 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -527,12 +527,6 @@ class AntiSpam(metaclass=YAMLGetter): rules: Dict[str, Dict[str, int]] -class AntiMalware(metaclass=YAMLGetter): - section = "anti_malware" - - whitelist: list - - class BigBrother(metaclass=YAMLGetter): section = 'big_brother' diff --git a/config-default.yml b/config-default.yml index 81c8c40d5..503cc2b52 100644 --- a/config-default.yml +++ b/config-default.yml @@ -386,35 +386,6 @@ anti_spam: max: 3 -anti_malware: - whitelist: - - '.3gp' - - '.3g2' - - '.avi' - - '.bmp' - - '.gif' - - '.h264' - - '.jpg' - - '.jpeg' - - '.m4v' - - '.mkv' - - '.mov' - - '.mp4' - - '.mpeg' - - '.mpg' - - '.png' - - '.tiff' - - '.wmv' - - '.svg' - - '.psd' # Photoshop - - '.ai' # Illustrator - - '.aep' # After Effects - - '.xcf' # GIMP - - '.mp3' - - '.wav' - - '.ogg' - - reddit: subreddits: - 'r/Python' diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index f219fc1ba..1e010d2ce 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,28 +1,33 @@ import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from bot.constants import Channels, STAFF_ROLES from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole -MODULE = "bot.cogs.antimalware" - -@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Test the AntiMalware cog.""" def setUp(self): """Sets up fresh objects for each test.""" self.bot = MockBot() + self.bot.allow_deny_list_cache = { + "file_format.True": [ + {"content": ".first"}, + {"content": ".second"}, + {"content": ".third"} + ] + } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") + attachment = MockAttachment(filename="python.first") self.message.attachments = [attachment] await self.cog.on_message(self.message) @@ -93,7 +98,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - async def test_other_disallowed_extention_embed_description(self): + async def test_other_disallowed_extension_embed_description(self): """Test the description for a non .py/.txt disallowed extension.""" attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] @@ -109,6 +114,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), blocked_extensions_str=".disallowed", meta_channel_mention=meta_channel.mention ) @@ -135,7 +141,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """The return value should include all non-whitelisted extensions.""" test_values = ( ([], []), - (AntiMalwareConfig.whitelist, []), + (self.whitelist, []), ([".first"], []), ([".first", ".disallowed"], [".disallowed"]), ([".disallowed"], [".disallowed"]), @@ -145,7 +151,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): for extensions, expected_disallowed_extensions in test_values: with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog.get_disallowed_extensions(self.message) + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) -- cgit v1.2.3 From e93fdaf57d1d35394b466a6bd1c84712e29415d7 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 20 Jul 2020 16:05:32 +0100 Subject: Edited tests to reflect changes (removed py formatting) --- tests/bot/cogs/test_snekbox.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 98dee7a1b..343e37db9 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -239,7 +239,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.send_eval(ctx, 'MyAwesomeCode') ctx.send.assert_called_once_with( - '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' + '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' ) self.cog.post_eval.assert_called_once_with('MyAwesomeCode') self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) @@ -265,7 +265,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.send_eval(ctx, 'MyAwesomeCode') ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.' - '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com' + '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) self.cog.post_eval.assert_called_once_with('MyAwesomeCode') self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) @@ -289,7 +289,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.send_eval(ctx, 'MyAwesomeCode') ctx.send.assert_called_once_with( - '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' + '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) self.cog.post_eval.assert_called_once_with('MyAwesomeCode') self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) -- cgit v1.2.3 From 63c7827d9d9025c7505747904237b37eb46464df Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 11:42:31 -0700 Subject: Jam Tests: fix utils patch stop needs to be called on the patcher, not the mock. Furthermore, using addCleanup is safer than tearDown because the latter may not be called if an exception is raised in setUp. --- tests/bot/cogs/test_jams.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 2f2cb4695..28eb1ab53 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -16,11 +16,12 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.guild = MockGuild([self.admin_role]) self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) self.cog = CodeJams(self.bot) - self.utils_mock = patch("bot.cogs.jams.utils").start() - self.default_args = [self.cog, self.ctx, "foo"] - def tearDown(self): - self.utils_mock.stop() + utils_patcher = patch("bot.cogs.jams.utils") + self.utils_mock = utils_patcher.start() + self.addCleanup(utils_patcher.stop) + + self.default_args = [self.cog, self.ctx, "foo"] async def test_too_small_amount_of_team_members_passed(self): """Should `ctx.send` and exit early when too small amount of members.""" -- cgit v1.2.3 From f7e177357e7a47d9a43b492aac7703961af72c19 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 11:43:58 -0700 Subject: Jam Tests: re-arrange tests to follow definition order in the cog --- tests/bot/cogs/test_jams.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 28eb1ab53..e0018e006 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -48,6 +48,16 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.cog.create_channels.assert_not_awaited() self.cog.add_roles.assert_not_awaited() + async def test_result_sending(self): + """Should call `ctx.send` when everything goes right.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + members = [MockMember() for _ in range(5)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + self.cog.create_channels.assert_awaited_once() + self.cog.add_roles.assert_awaited_once() + self.ctx.send.assert_awaited_once() + async def test_category_dont_exist(self): """Should create code jam category.""" self.utils_mock.get.return_value = None @@ -125,16 +135,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): for member in members: member.add_roles.assert_any_await(jam_role) - async def test_result_sending(self): - """Should call `ctx.send` when everything goes right.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - members = [MockMember() for _ in range(5)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) - self.cog.create_channels.assert_awaited_once() - self.cog.add_roles.assert_awaited_once() - self.ctx.send.assert_awaited_once() - class CodeJamSetup(unittest.TestCase): """Test for `setup` function of `CodeJam` cog.""" -- cgit v1.2.3 From b1d0f36356ecf4eee729bf276c8b0ed10653ad54 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 11:47:49 -0700 Subject: Jam Tests: remove default_args attribute Kind of redundant since it's only used by two tests. --- tests/bot/cogs/test_jams.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index e0018e006..0fce2a67c 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -21,8 +21,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.utils_mock = utils_patcher.start() self.addCleanup(utils_patcher.stop) - self.default_args = [self.cog, self.ctx, "foo"] - async def test_too_small_amount_of_team_members_passed(self): """Should `ctx.send` and exit early when too small amount of members.""" for case in (1, 2): @@ -32,7 +30,8 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.ctx.reset_mock() self.utils_mock.reset_mock() - await self.cog.createteam(*self.default_args, (MockMember() for _ in range(case))) + members = (MockMember() for _ in range(case)) + await self.cog.createteam(self.cog, self.ctx, "foo", members) self.ctx.send.assert_awaited_once() self.cog.create_channels.assert_not_awaited() @@ -43,7 +42,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.cog.create_channels = AsyncMock() self.cog.add_roles = AsyncMock() member = MockMember() - await self.cog.createteam(*self.default_args, (member for _ in range(5))) + await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) self.ctx.send.assert_awaited_once() self.cog.create_channels.assert_not_awaited() self.cog.add_roles.assert_not_awaited() -- cgit v1.2.3 From 44cd1d989d491d692d48324228ccc9593a545cd2 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 11:51:08 -0700 Subject: Jam Tests: space out lines for readability --- tests/bot/cogs/test_jams.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 0fce2a67c..81fbcb798 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -41,8 +41,10 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" self.cog.create_channels = AsyncMock() self.cog.add_roles = AsyncMock() + member = MockMember() await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + self.ctx.send.assert_awaited_once() self.cog.create_channels.assert_not_awaited() self.cog.add_roles.assert_not_awaited() @@ -51,8 +53,10 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Should call `ctx.send` when everything goes right.""" self.cog.create_channels = AsyncMock() self.cog.add_roles = AsyncMock() + members = [MockMember() for _ in range(5)] await self.cog.createteam(self.cog, self.ctx, "foo", members) + self.cog.create_channels.assert_awaited_once() self.cog.add_roles.assert_awaited_once() self.ctx.send.assert_awaited_once() @@ -60,7 +64,9 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_dont_exist(self): """Should create code jam category.""" self.utils_mock.get.return_value = None + await self.cog.get_category(self.guild) + self.guild.create_category_channel.assert_awaited_once() category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] -- cgit v1.2.3 From 12168766a153d9d1bd134ff64f74997eef8ff7b0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 16:11:32 -0700 Subject: Jam tests: fix category test --- tests/bot/cogs/test_jams.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 81fbcb798..54a096703 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -1,11 +1,22 @@ import unittest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, create_autospec -from bot.cogs.jams import CodeJams, setup +from discord import CategoryChannel + +from bot.cogs import jams from bot.constants import Roles from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel +def get_mock_category(channel_count: int, name: str) -> CategoryChannel: + """Return a mocked code jam category.""" + category = create_autospec(CategoryChannel, spec_set=True, instance=True) + category.name = name + category.channels = [MockTextChannel() for _ in range(channel_count)] + + return category + + class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): """Tests for `createteam` command.""" @@ -15,11 +26,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.command_user = MockMember([self.admin_role]) self.guild = MockGuild([self.admin_role]) self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) - self.cog = CodeJams(self.bot) - - utils_patcher = patch("bot.cogs.jams.utils") - self.utils_mock = utils_patcher.start() - self.addCleanup(utils_patcher.stop) + self.cog = jams.CodeJams(self.bot) async def test_too_small_amount_of_team_members_passed(self): """Should `ctx.send` and exit early when too small amount of members.""" @@ -29,7 +36,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.cog.add_roles = AsyncMock() self.ctx.reset_mock() - self.utils_mock.reset_mock() members = (MockMember() for _ in range(case)) await self.cog.createteam(self.cog, self.ctx, "foo", members) @@ -63,8 +69,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_dont_exist(self): """Should create code jam category.""" - self.utils_mock.get.return_value = None - await self.cog.get_category(self.guild) self.guild.create_category_channel.assert_awaited_once() @@ -75,8 +79,15 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_channel_exist(self): """Should not try to create category channel.""" - await self.cog.get_category(self.guild) - self.guild.create_category_channel.assert_not_awaited() + expected_category = get_mock_category(48, jams.CATEGORY_NAME) + self.guild.categories = [ + get_mock_category(48, "other"), + expected_category, + get_mock_category(6, jams.CATEGORY_NAME), + ] + + actual_category = await self.cog.get_category(self.guild) + self.assertEqual(expected_category, actual_category) async def test_channel_overwrites(self): """Should have correct permission overwrites for users and roles.""" @@ -103,7 +114,6 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_team_channels_creation(self): """Should create new voice and text channel for team.""" - self.utils_mock.get.return_value = "foo" members = [MockMember() for _ in range(5)] self.cog.get_overwrites = MagicMock() @@ -147,5 +157,5 @@ class CodeJamSetup(unittest.TestCase): def test_setup(self): """Should call `bot.add_cog`.""" bot = MockBot() - setup(bot) + jams.setup(bot) bot.add_cog.assert_called_once() -- cgit v1.2.3 From 92d3f88eb5c2348f3e4cb53a22a833bed61c6fb7 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 16:21:38 -0700 Subject: Jam tests: add subtests to non-existent category test The test has to account for not only the name not matching, but also a lack of available spaces for new channels. --- tests/bot/cogs/test_jams.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index 54a096703..e6b2ac588 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -67,15 +67,26 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.cog.add_roles.assert_awaited_once() self.ctx.send.assert_awaited_once() - async def test_category_dont_exist(self): - """Should create code jam category.""" - await self.cog.get_category(self.guild) + async def test_category_doesnt_exist(self): + """Should create a new code jam category.""" + subtests = ( + [], + [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], + [get_mock_category(48, "other")], + ) + + for categories in subtests: + self.guild.reset_mock() + self.guild.categories = categories + + with self.subTest(categories=categories): + await self.cog.get_category(self.guild) - self.guild.create_category_channel.assert_awaited_once() - category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + self.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - self.assertFalse(category_overwrites[self.guild.default_role].read_messages) - self.assertTrue(category_overwrites[self.guild.me].read_messages) + self.assertFalse(category_overwrites[self.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.guild.me].read_messages) async def test_category_channel_exist(self): """Should not try to create category channel.""" -- cgit v1.2.3 From ddba3f5fcfbda0f72baa3f15055c8a92e94c6d88 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 16:27:02 -0700 Subject: Jam tests: assert equality of new category --- tests/bot/cogs/test_jams.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index e6b2ac588..a76a8a051 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -80,13 +80,14 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): self.guild.categories = categories with self.subTest(categories=categories): - await self.cog.get_category(self.guild) + actual_category = await self.cog.get_category(self.guild) self.guild.create_category_channel.assert_awaited_once() category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] self.assertFalse(category_overwrites[self.guild.default_role].read_messages) self.assertTrue(category_overwrites[self.guild.me].read_messages) + self.assertEqual(self.guild.create_category_channel.return_value, actual_category) async def test_category_channel_exist(self): """Should not try to create category channel.""" -- cgit v1.2.3 From 8e3c05210f057ab76d135afbe12035847c9029f4 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 22 Jul 2020 16:54:53 -0700 Subject: Jam tests: use the MAX_CHANNELS constant more It's clearer to write MAX_CHANNELS - 2 than a literal 48. --- tests/bot/cogs/test_jams.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py index a76a8a051..b4ad8535f 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/cogs/test_jams.py @@ -72,7 +72,7 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): subtests = ( [], [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], - [get_mock_category(48, "other")], + [get_mock_category(jams.MAX_CHANNELS - 2, "other")], ) for categories in subtests: @@ -91,11 +91,11 @@ class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): async def test_category_channel_exist(self): """Should not try to create category channel.""" - expected_category = get_mock_category(48, jams.CATEGORY_NAME) + expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) self.guild.categories = [ - get_mock_category(48, "other"), + get_mock_category(jams.MAX_CHANNELS - 2, "other"), expected_category, - get_mock_category(6, jams.CATEGORY_NAME), + get_mock_category(0, jams.CATEGORY_NAME), ] actual_category = await self.cog.get_category(self.guild) -- cgit v1.2.3 From 3d5faa421756fadb42590db92e8fee64578390d4 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Mon, 27 Jul 2020 10:26:10 +0200 Subject: Rename AllowDenyList to FilterLists --- bot/__main__.py | 2 +- bot/bot.py | 14 +-- bot/cogs/allow_deny_lists.py | 218 ------------------------------------- bot/cogs/antimalware.py | 2 +- bot/cogs/filter_lists.py | 218 +++++++++++++++++++++++++++++++++++++ bot/cogs/filtering.py | 16 +-- bot/converters.py | 10 +- tests/bot/cogs/test_antimalware.py | 2 +- 8 files changed, 241 insertions(+), 241 deletions(-) delete mode 100644 bot/cogs/allow_deny_lists.py create mode 100644 bot/cogs/filter_lists.py (limited to 'tests') diff --git a/bot/__main__.py b/bot/__main__.py index 932aa705c..c2271cd16 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -53,7 +53,7 @@ bot.load_extension("bot.cogs.verification") # Feature cogs bot.load_extension("bot.cogs.alias") -bot.load_extension("bot.cogs.allow_deny_lists") +bot.load_extension("bot.cogs.filter_lists") bot.load_extension("bot.cogs.defcon") bot.load_extension("bot.cogs.dm_relay") bot.load_extension("bot.cogs.duck_pond") diff --git a/bot/bot.py b/bot/bot.py index d834c151b..3dfb4e948 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -34,7 +34,7 @@ class Bot(commands.Bot): self.redis_ready = asyncio.Event() self.redis_closed = False self.api_client = api.APIClient(loop=self.loop) - self.allow_deny_list_cache = {} + self.filter_list_cache = {} self._connector = None self._resolver = None @@ -50,9 +50,9 @@ class Bot(commands.Bot): self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot") - async def _cache_allow_deny_list_data(self) -> None: - """Cache all the data in the AllowDenyList on the site.""" - full_cache = await self.api_client.get('bot/allow_deny_lists') + async def _cache_filter_list_data(self) -> None: + """Cache all the data in the FilterList on the site.""" + full_cache = await self.api_client.get('bot/filter-lists') for item in full_cache: type_ = item.get("type") @@ -64,7 +64,7 @@ class Bot(commands.Bot): "created_at": item.get("created_at"), "updated_at": item.get("updated_at"), } - self.allow_deny_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) + self.filter_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) async def _create_redis_session(self) -> None: """ @@ -176,8 +176,8 @@ class Bot(commands.Bot): self.http_session = aiohttp.ClientSession(connector=self._connector) self.api_client.recreate(force=True, connector=self._connector) - # Build the AllowDenyList cache - self.loop.create_task(self._cache_allow_deny_list_data()) + # Build the FilterList cache + self.loop.create_task(self._cache_filter_list_data()) async def on_guild_available(self, guild: discord.Guild) -> None: """ diff --git a/bot/cogs/allow_deny_lists.py b/bot/cogs/allow_deny_lists.py deleted file mode 100644 index e28e32bd6..000000000 --- a/bot/cogs/allow_deny_lists.py +++ /dev/null @@ -1,218 +0,0 @@ -import logging -from typing import Optional - -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.converters import ValidAllowDenyListType, ValidDiscordServerInvite -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -class AllowDenyLists(Cog): - """Commands for blacklisting and whitelisting things.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - async def _add_data( - self, - ctx: Context, - allowed: bool, - list_type: ValidAllowDenyListType, - content: str, - comment: Optional[str] = None, - ) -> None: - """Add an item to an allow or denylist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we gotta validate it. - if list_type == "GUILD_INVITE": - log.trace(f"{content} is a guild invite, attempting to validate.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, content) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's convert the content to an ID. - log.trace(f"{content} validated as server invite. Converting to ID.") - content = guild_data.get("id") - - # Unless the user has specified another comment, let's - # use the server name as the comment so that the list - # of guild IDs will be more easily readable when we - # display it. - if not comment: - comment = guild_data.get("name") - - # Try to add the item to the database - log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") - payload = { - 'allowed': allowed, - 'type': list_type, - 'content': content, - 'comment': comment, - } - - try: - item = await self.bot.api_client.post( - "bot/allow_deny_lists", - json=payload - ) - except ResponseCodeError as e: - if e.status == 500: - await ctx.message.add_reaction("❌") - log.debug( - f"{ctx.author} tried to add data to a {allow_type}, but the API returned 500, " - "probably because the request violated the UniqueConstraint." - ) - raise BadArgument( - f"Unable to add the item to the {allow_type}. " - "The item probably already exists. Keep in mind that a " - "blacklist and a whitelist for the same item cannot co-exist, " - "and we do not permit any duplicates." - ) - raise - - # Insert the item into the cache - type_ = item.get("type") - allowed = item.get("allowed") - metadata = { - "content": item.get("content"), - "comment": item.get("comment"), - "id": item.get("id"), - "created_at": item.get("created_at"), - "updated_at": item.get("updated_at"), - } - self.bot.allow_deny_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) - await ctx.message.add_reaction("✅") - - async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidAllowDenyListType, content: str) -> None: - """Remove an item from an allow or denylist.""" - item = None - allow_type = "whitelist" if allowed else "blacklist" - id_converter = IDConverter() - - # If this is a server invite, we need to convert it. - if list_type == "GUILD_INVITE" and not id_converter._get_id_match(content): - log.trace(f"{content} is a guild invite, attempting to validate.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, content) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's convert the content to an ID. - log.trace(f"{content} validated as server invite. Converting to ID.") - content = guild_data.get("id") - - # Find the content and delete it. - log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - for allow_list in self.bot.allow_deny_list_cache.get(f"{list_type}.{allowed}", []): - if content == allow_list.get("content"): - item = allow_list - break - - if item is not None: - await self.bot.api_client.delete( - f"bot/allow_deny_lists/{item.get('id')}" - ) - self.bot.allow_deny_list_cache[f"{list_type}.{allowed}"].remove(item) - await ctx.message.add_reaction("✅") - - async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidAllowDenyListType) -> None: - """Paginate and display all items in an allow or denylist.""" - allow_type = "whitelist" if allowed else "blacklist" - result = self.bot.allow_deny_list_cache.get(f"{list_type}.{allowed}", []) - - # Build a list of lines we want to show in the paginator - lines = [] - for item in result: - line = f"• `{item.get('content')}`" - - if item.get("comment"): - line += f" - {item.get('comment')}" - - lines.append(line) - lines = sorted(lines) - - # Build the embed - list_type_plural = list_type.lower().replace("_", " ").title() + "s" - embed = Embed( - title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", - colour=Colour.blue() - ) - log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - - if result: - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - - @group(aliases=("allowlist", "allow", "al", "wl")) - async def whitelist(self, ctx: Context) -> None: - """Group for whitelisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @group(aliases=("denylist", "deny", "bl", "dl")) - async def blacklist(self, ctx: Context) -> None: - """Group for blacklisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @whitelist.command(name="add", aliases=("a", "set")) - async def allow_add( - self, - ctx: Context, - list_type: ValidAllowDenyListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified allowlist.""" - await self._add_data(ctx, True, list_type, content, comment) - - @blacklist.command(name="add", aliases=("a", "set")) - async def deny_add( - self, - ctx: Context, - list_type: ValidAllowDenyListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified denylist.""" - await self._add_data(ctx, False, list_type, content, comment) - - @whitelist.command(name="remove", aliases=("delete", "rm",)) - async def allow_delete(self, ctx: Context, list_type: ValidAllowDenyListType, content: str) -> None: - """Remove an item from the specified allowlist.""" - await self._delete_data(ctx, True, list_type, content) - - @blacklist.command(name="remove", aliases=("delete", "rm",)) - async def deny_delete(self, ctx: Context, list_type: ValidAllowDenyListType, content: str) -> None: - """Remove an item from the specified denylist.""" - await self._delete_data(ctx, False, list_type, content) - - @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def allow_get(self, ctx: Context, list_type: ValidAllowDenyListType) -> None: - """Get the contents of a specified allowlist.""" - await self._list_all_data(ctx, True, list_type) - - @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def deny_get(self, ctx: Context, list_type: ValidAllowDenyListType) -> None: - """Get the contents of a specified denylist.""" - await self._list_all_data(ctx, False, list_type) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the AllowDenyLists cog.""" - bot.add_cog(AllowDenyLists(bot)) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 5b56f937f..9a100b3fc 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -40,7 +40,7 @@ class AntiMalware(Cog): def _get_whitelisted_file_formats(self) -> list: """Get the file formats currently on the whitelist.""" - return [item['content'] for item in self.bot.allow_deny_list_cache['file_format.True']] + return [item['content'] for item in self.bot.filter_list_cache['file_format.True']] def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: """Get an iterable containing all the disallowed extensions of attachments.""" diff --git a/bot/cogs/filter_lists.py b/bot/cogs/filter_lists.py new file mode 100644 index 000000000..d1db9830e --- /dev/null +++ b/bot/cogs/filter_lists.py @@ -0,0 +1,218 @@ +import logging +from typing import Optional + +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.converters import ValidDiscordServerInvite, ValidFilterListType +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +class FilterLists(Cog): + """Commands for blacklisting and whitelisting things.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + async def _add_data( + self, + ctx: Context, + allowed: bool, + list_type: ValidFilterListType, + content: str, + comment: Optional[str] = None, + ) -> None: + """Add an item to a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we gotta validate it. + if list_type == "GUILD_INVITE": + log.trace(f"{content} is a guild invite, attempting to validate.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, content) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's convert the content to an ID. + log.trace(f"{content} validated as server invite. Converting to ID.") + content = guild_data.get("id") + + # Unless the user has specified another comment, let's + # use the server name as the comment so that the list + # of guild IDs will be more easily readable when we + # display it. + if not comment: + comment = guild_data.get("name") + + # Try to add the item to the database + log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") + payload = { + 'allowed': allowed, + 'type': list_type, + 'content': content, + 'comment': comment, + } + + try: + item = await self.bot.api_client.post( + "bot/filter-lists", + json=payload + ) + except ResponseCodeError as e: + if e.status == 500: + await ctx.message.add_reaction("❌") + log.debug( + f"{ctx.author} tried to add data to a {allow_type}, but the API returned 500, " + "probably because the request violated the UniqueConstraint." + ) + raise BadArgument( + f"Unable to add the item to the {allow_type}. " + "The item probably already exists. Keep in mind that a " + "blacklist and a whitelist for the same item cannot co-exist, " + "and we do not permit any duplicates." + ) + raise + + # Insert the item into the cache + type_ = item.get("type") + allowed = item.get("allowed") + metadata = { + "content": item.get("content"), + "comment": item.get("comment"), + "id": item.get("id"), + "created_at": item.get("created_at"), + "updated_at": item.get("updated_at"), + } + self.bot.filter_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) + await ctx.message.add_reaction("✅") + + async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from a filterlist.""" + item = None + allow_type = "whitelist" if allowed else "blacklist" + id_converter = IDConverter() + + # If this is a server invite, we need to convert it. + if list_type == "GUILD_INVITE" and not id_converter._get_id_match(content): + log.trace(f"{content} is a guild invite, attempting to validate.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, content) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's convert the content to an ID. + log.trace(f"{content} validated as server invite. Converting to ID.") + content = guild_data.get("id") + + # Find the content and delete it. + log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") + for allow_list in self.bot.filter_list_cache.get(f"{list_type}.{allowed}", []): + if content == allow_list.get("content"): + item = allow_list + break + + if item is not None: + await self.bot.api_client.delete( + f"bot/filter-lists/{item.get('id')}" + ) + self.bot.filter_list_cache[f"{list_type}.{allowed}"].remove(item) + await ctx.message.add_reaction("✅") + + async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: + """Paginate and display all items in a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + result = self.bot.filter_list_cache.get(f"{list_type}.{allowed}", []) + + # Build a list of lines we want to show in the paginator + lines = [] + for item in result: + line = f"• `{item.get('content')}`" + + if item.get("comment"): + line += f" - {item.get('comment')}" + + lines.append(line) + lines = sorted(lines) + + # Build the embed + list_type_plural = list_type.lower().replace("_", " ").title() + "s" + embed = Embed( + title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", + colour=Colour.blue() + ) + log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") + + if result: + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + + @group(aliases=("allowlist", "allow", "al", "wl")) + async def whitelist(self, ctx: Context) -> None: + """Group for whitelisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @group(aliases=("denylist", "deny", "bl", "dl")) + async def blacklist(self, ctx: Context) -> None: + """Group for blacklisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @whitelist.command(name="add", aliases=("a", "set")) + async def allow_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified allowlist.""" + await self._add_data(ctx, True, list_type, content, comment) + + @blacklist.command(name="add", aliases=("a", "set")) + async def deny_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified denylist.""" + await self._add_data(ctx, False, list_type, content, comment) + + @whitelist.command(name="remove", aliases=("delete", "rm",)) + async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified allowlist.""" + await self._delete_data(ctx, True, list_type, content) + + @blacklist.command(name="remove", aliases=("delete", "rm",)) + async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified denylist.""" + await self._delete_data(ctx, False, list_type, content) + + @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified allowlist.""" + await self._list_all_data(ctx, True, list_type) + + @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified denylist.""" + await self._list_all_data(ctx, False, list_type) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the FilterLists cog.""" + bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 8897cbaf9..652af5ff5 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -99,9 +99,9 @@ class Filtering(Cog): self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) - def _get_allowlist_items(self, list_type: str, *, allowed: bool, compiled: Optional[bool] = False) -> list: - """Fetch items from the allow_deny_list_cache.""" - items = self.bot.allow_deny_list_cache.get(f"{list_type.upper()}.{allowed}", []) + def _get_filterlist_items(self, list_type: str, *, allowed: bool, compiled: Optional[bool] = False) -> list: + """Fetch items from the filter_list_cache.""" + items = self.bot.filter_list_cache.get(f"{list_type.upper()}.{allowed}", []) if compiled: return [re.compile(fr'{item["content"]}', flags=re.IGNORECASE) for item in items] @@ -143,7 +143,7 @@ class Filtering(Cog): def get_name_matches(self, name: str) -> List[re.Match]: """Check bad words from passed string (name). Return list of matches.""" matches = [] - watchlist_patterns = self._get_allowlist_items('word_watchlist', allowed=False, compiled=True) + watchlist_patterns = self._get_filterlist_items('word_watchlist', allowed=False, compiled=True) for pattern in watchlist_patterns: if match := pattern.search(name): matches.append(match) @@ -408,7 +408,7 @@ class Filtering(Cog): if URL_RE.search(text): return False - watchlist_patterns = self._get_allowlist_items('word_watchlist', allowed=False, compiled=True) + watchlist_patterns = self._get_filterlist_items('word_watchlist', allowed=False, compiled=True) for pattern in watchlist_patterns: match = pattern.search(text) if match: @@ -420,7 +420,7 @@ class Filtering(Cog): return False text = text.lower() - domain_blacklist = self._get_allowlist_items("domain_name", allowed=False) + domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) for url in domain_blacklist: if url.lower() in text: @@ -468,8 +468,8 @@ class Filtering(Cog): return True guild_id = guild.get("id") - guild_invite_whitelist = self._get_allowlist_items("guild_invite", allowed=True) - guild_invite_blacklist = self._get_allowlist_items("guild_invite", allowed=False) + guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) + guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) # Is this invite allowed? guild_partnered_or_verified = ( diff --git a/bot/converters.py b/bot/converters.py index 41cd3f3e5..158bf1a16 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -72,18 +72,18 @@ class ValidDiscordServerInvite(Converter): raise BadArgument("This does not appear to be a valid Discord server invite.") -class ValidAllowDenyListType(Converter): +class ValidFilterListType(Converter): """ - A converter that checks whether the given string is a valid AllowDenyList type. + A converter that checks whether the given string is a valid FilterList type. - Raises `BadArgument` if the argument is not a valid AllowDenyList type, and simply + Raises `BadArgument` if the argument is not a valid FilterList type, and simply passes through the given argument otherwise. """ async def convert(self, ctx: Context, list_type: str) -> str: - """Checks whether the given string is a valid AllowDenyList type.""" + """Checks whether the given string is a valid FilterList type.""" try: - valid_types = await ctx.bot.api_client.get('bot/allow_deny_lists/get_types') + valid_types = await ctx.bot.api_client.get('bot/filter-lists/get-types') except ResponseCodeError: raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 1e010d2ce..664fa8f19 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -14,7 +14,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Sets up fresh objects for each test.""" self.bot = MockBot() - self.bot.allow_deny_list_cache = { + self.bot.filter_list_cache = { "file_format.True": [ {"content": ".first"}, {"content": ".second"}, -- cgit v1.2.3 From e0837f4f6dd7c5c2d6fc0811dccfaf1ecae768ba Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 29 Jul 2020 20:14:52 +0200 Subject: Restructure bot.filter_list_cache. This is an optimization designed to eliminate all the list comprehensions we were doing inside antimalware and filtering. The cache is now structured so that the content is the key and the metadata is the value. --- bot/bot.py | 8 ++++---- bot/cogs/antimalware.py | 2 +- bot/cogs/filter_lists.py | 18 +++++++++--------- bot/cogs/filtering.py | 3 +-- tests/bot/cogs/test_antimalware.py | 10 +++++----- 5 files changed, 20 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/bot/bot.py b/bot/bot.py index 5deb986ec..4492feaa9 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -35,7 +35,7 @@ class Bot(commands.Bot): self.redis_ready = asyncio.Event() self.redis_closed = False self.api_client = api.APIClient(loop=self.loop) - self.filter_list_cache = defaultdict(list) + self.filter_list_cache = defaultdict(dict) self._connector = None self._resolver = None @@ -169,14 +169,14 @@ class Bot(commands.Bot): """Add an item to the bots filter_list_cache.""" type_ = item["type"] allowed = item["allowed"] - metadata = { + content = item["content"] + + self.filter_list_cache[f"{type_}.{allowed}"][content] = { "id": item["id"], - "content": item["content"], "comment": item["comment"], "created_at": item["created_at"], "updated_at": item["updated_at"], } - self.filter_list_cache[f"{type_}.{allowed}"].append(metadata) async def login(self, *args, **kwargs) -> None: """Re-create the connector and set up sessions before logging into Discord.""" diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 9a100b3fc..c76bd2c60 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -40,7 +40,7 @@ class AntiMalware(Cog): def _get_whitelisted_file_formats(self) -> list: """Get the file formats currently on the whitelist.""" - return [item['content'] for item in self.bot.filter_list_cache['file_format.True']] + return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: """Get an iterable containing all the disallowed extensions of attachments.""" diff --git a/bot/cogs/filter_lists.py b/bot/cogs/filter_lists.py index a93de2de9..3331be014 100644 --- a/bot/cogs/filter_lists.py +++ b/bot/cogs/filter_lists.py @@ -88,16 +88,16 @@ class FilterLists(Cog): # Find the content and delete it. log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - for allow_list in self.bot.filter_list_cache[f"{list_type}.{allowed}"]: - if content == allow_list.get("content"): - item = allow_list + for allow_list, metadata in self.bot.filter_list_cache[f"{list_type}.{allowed}"].items(): + if content == allow_list: + item = metadata break if item is not None: await self.bot.api_client.delete( - f"bot/filter-lists/{item.get('id')}" + f"bot/filter-lists/{item['id']}" ) - self.bot.filter_list_cache[f"{list_type}.{allowed}"].remove(item) + del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] await ctx.message.add_reaction("✅") async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: @@ -107,11 +107,11 @@ class FilterLists(Cog): # Build a list of lines we want to show in the paginator lines = [] - for item in result: - line = f"• `{item.get('content')}`" + for content, metadata in result.items(): + line = f"• `{content}`" - if item.get("comment"): - line += f" - {item.get('comment')}" + if metadata.get("comment"): + line += f" - {metadata.get('comment')}" lines.append(line) lines = sorted(lines) diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 7787d396d..0951cb740 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -101,8 +101,7 @@ class Filtering(Cog): def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: """Fetch items from the filter_list_cache.""" - items = self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"] - return [item["content"] for item in items] + return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() @staticmethod def _expand_spoilers(text: str) -> str: diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 664fa8f19..82eadf226 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -15,11 +15,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Sets up fresh objects for each test.""" self.bot = MockBot() self.bot.filter_list_cache = { - "file_format.True": [ - {"content": ".first"}, - {"content": ".second"}, - {"content": ".third"} - ] + "file_format.True": { + ".first": {}, + ".second": {}, + ".third": {}, + } } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() -- cgit v1.2.3 From 0cfc918c6d68764c380f1188f3bc5508e6b27030 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 29 Jul 2020 20:24:06 +0200 Subject: Fix broken antimalware tests. --- tests/bot/cogs/test_antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 82eadf226..ecb7abf00 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -15,7 +15,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Sets up fresh objects for each test.""" self.bot = MockBot() self.bot.filter_list_cache = { - "file_format.True": { + "FILE_FORMAT.True": { ".first": {}, ".second": {}, ".third": {}, -- cgit v1.2.3 From 0fca2445e2979d6e4bebf6a974c974a5ddd14fbe Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 14 Jun 2020 17:29:43 -0700 Subject: Move extensions into sub-directories --- bot/cogs/alias.py | 2 +- bot/cogs/antimalware.py | 98 ---- bot/cogs/antispam.py | 288 ----------- bot/cogs/backend/__init__.py | 0 bot/cogs/backend/config_verifier.py | 40 ++ bot/cogs/backend/error_handler.py | 287 +++++++++++ bot/cogs/backend/logging.py | 42 ++ bot/cogs/backend/sync/__init__.py | 7 + bot/cogs/backend/sync/cog.py | 180 +++++++ bot/cogs/backend/sync/syncers.py | 347 +++++++++++++ bot/cogs/bot.py | 385 --------------- bot/cogs/clean.py | 272 ---------- bot/cogs/config_verifier.py | 40 -- bot/cogs/defcon.py | 258 ---------- bot/cogs/doc.py | 511 ------------------- bot/cogs/error_handler.py | 287 ----------- bot/cogs/eval.py | 202 -------- bot/cogs/extensions.py | 236 --------- bot/cogs/filter_lists.py | 273 ---------- bot/cogs/filtering.py | 575 ---------------------- bot/cogs/filters/__init__.py | 0 bot/cogs/filters/antimalware.py | 98 ++++ bot/cogs/filters/antispam.py | 288 +++++++++++ bot/cogs/filters/filter_lists.py | 273 ++++++++++ bot/cogs/filters/filtering.py | 575 ++++++++++++++++++++++ bot/cogs/filters/security.py | 31 ++ bot/cogs/filters/token_remover.py | 182 +++++++ bot/cogs/filters/webhook_remover.py | 84 ++++ bot/cogs/help.py | 375 -------------- bot/cogs/info/__init__.py | 0 bot/cogs/info/doc.py | 511 +++++++++++++++++++ bot/cogs/info/help.py | 375 ++++++++++++++ bot/cogs/info/information.py | 422 ++++++++++++++++ bot/cogs/info/python_news.py | 232 +++++++++ bot/cogs/info/reddit.py | 304 ++++++++++++ bot/cogs/info/site.py | 146 ++++++ bot/cogs/info/source.py | 141 ++++++ bot/cogs/info/stats.py | 129 +++++ bot/cogs/info/tags.py | 277 +++++++++++ bot/cogs/info/wolfram.py | 280 +++++++++++ bot/cogs/information.py | 422 ---------------- bot/cogs/jams.py | 150 ------ bot/cogs/logging.py | 42 -- bot/cogs/moderation/__init__.py | 6 +- bot/cogs/moderation/defcon.py | 258 ++++++++++ bot/cogs/moderation/infraction/__init__.py | 0 bot/cogs/moderation/infraction/infractions.py | 370 ++++++++++++++ bot/cogs/moderation/infraction/management.py | 305 ++++++++++++ bot/cogs/moderation/infraction/scheduler.py | 463 +++++++++++++++++ bot/cogs/moderation/infraction/superstarify.py | 239 +++++++++ bot/cogs/moderation/infraction/utils.py | 201 ++++++++ bot/cogs/moderation/infractions.py | 370 -------------- bot/cogs/moderation/management.py | 305 ------------ bot/cogs/moderation/scheduler.py | 463 ----------------- bot/cogs/moderation/superstarify.py | 239 --------- bot/cogs/moderation/utils.py | 201 -------- bot/cogs/moderation/verification.py | 191 +++++++ bot/cogs/moderation/watchchannels/__init__.py | 9 + bot/cogs/moderation/watchchannels/bigbrother.py | 165 +++++++ bot/cogs/moderation/watchchannels/talentpool.py | 264 ++++++++++ bot/cogs/moderation/watchchannels/watchchannel.py | 348 +++++++++++++ bot/cogs/python_news.py | 232 --------- bot/cogs/reddit.py | 304 ------------ bot/cogs/reminders.py | 427 ---------------- bot/cogs/security.py | 31 -- bot/cogs/site.py | 146 ------ bot/cogs/snekbox.py | 349 ------------- bot/cogs/source.py | 141 ------ bot/cogs/stats.py | 129 ----- bot/cogs/sync/__init__.py | 7 - bot/cogs/sync/cog.py | 180 ------- bot/cogs/sync/syncers.py | 347 ------------- bot/cogs/tags.py | 277 ----------- bot/cogs/token_remover.py | 182 ------- bot/cogs/utils.py | 265 ---------- bot/cogs/utils/__init__.py | 0 bot/cogs/utils/bot.py | 385 +++++++++++++++ bot/cogs/utils/clean.py | 272 ++++++++++ bot/cogs/utils/eval.py | 202 ++++++++ bot/cogs/utils/extensions.py | 236 +++++++++ bot/cogs/utils/jams.py | 150 ++++++ bot/cogs/utils/reminders.py | 427 ++++++++++++++++ bot/cogs/utils/snekbox.py | 349 +++++++++++++ bot/cogs/utils/utils.py | 265 ++++++++++ bot/cogs/verification.py | 191 ------- bot/cogs/watchchannels/__init__.py | 9 - bot/cogs/watchchannels/bigbrother.py | 165 ------- bot/cogs/watchchannels/talentpool.py | 264 ---------- bot/cogs/watchchannels/watchchannel.py | 348 ------------- bot/cogs/webhook_remover.py | 84 ---- bot/cogs/wolfram.py | 280 ----------- tests/bot/cogs/moderation/test_infractions.py | 2 +- tests/bot/cogs/sync/test_base.py | 2 +- tests/bot/cogs/sync/test_cog.py | 4 +- tests/bot/cogs/sync/test_roles.py | 2 +- tests/bot/cogs/sync/test_users.py | 2 +- tests/bot/cogs/test_antimalware.py | 2 +- tests/bot/cogs/test_antispam.py | 2 +- tests/bot/cogs/test_information.py | 2 +- tests/bot/cogs/test_security.py | 2 +- tests/bot/cogs/test_snekbox.py | 4 +- tests/bot/cogs/test_token_remover.py | 4 +- 102 files changed, 10368 insertions(+), 10368 deletions(-) delete mode 100644 bot/cogs/antimalware.py delete mode 100644 bot/cogs/antispam.py create mode 100644 bot/cogs/backend/__init__.py create mode 100644 bot/cogs/backend/config_verifier.py create mode 100644 bot/cogs/backend/error_handler.py create mode 100644 bot/cogs/backend/logging.py create mode 100644 bot/cogs/backend/sync/__init__.py create mode 100644 bot/cogs/backend/sync/cog.py create mode 100644 bot/cogs/backend/sync/syncers.py delete mode 100644 bot/cogs/bot.py delete mode 100644 bot/cogs/clean.py delete mode 100644 bot/cogs/config_verifier.py delete mode 100644 bot/cogs/defcon.py delete mode 100644 bot/cogs/doc.py delete mode 100644 bot/cogs/error_handler.py delete mode 100644 bot/cogs/eval.py delete mode 100644 bot/cogs/extensions.py delete mode 100644 bot/cogs/filter_lists.py delete mode 100644 bot/cogs/filtering.py create mode 100644 bot/cogs/filters/__init__.py create mode 100644 bot/cogs/filters/antimalware.py create mode 100644 bot/cogs/filters/antispam.py create mode 100644 bot/cogs/filters/filter_lists.py create mode 100644 bot/cogs/filters/filtering.py create mode 100644 bot/cogs/filters/security.py create mode 100644 bot/cogs/filters/token_remover.py create mode 100644 bot/cogs/filters/webhook_remover.py delete mode 100644 bot/cogs/help.py create mode 100644 bot/cogs/info/__init__.py create mode 100644 bot/cogs/info/doc.py create mode 100644 bot/cogs/info/help.py create mode 100644 bot/cogs/info/information.py create mode 100644 bot/cogs/info/python_news.py create mode 100644 bot/cogs/info/reddit.py create mode 100644 bot/cogs/info/site.py create mode 100644 bot/cogs/info/source.py create mode 100644 bot/cogs/info/stats.py create mode 100644 bot/cogs/info/tags.py create mode 100644 bot/cogs/info/wolfram.py delete mode 100644 bot/cogs/information.py delete mode 100644 bot/cogs/jams.py delete mode 100644 bot/cogs/logging.py create mode 100644 bot/cogs/moderation/defcon.py create mode 100644 bot/cogs/moderation/infraction/__init__.py create mode 100644 bot/cogs/moderation/infraction/infractions.py create mode 100644 bot/cogs/moderation/infraction/management.py create mode 100644 bot/cogs/moderation/infraction/scheduler.py create mode 100644 bot/cogs/moderation/infraction/superstarify.py create mode 100644 bot/cogs/moderation/infraction/utils.py delete mode 100644 bot/cogs/moderation/infractions.py delete mode 100644 bot/cogs/moderation/management.py delete mode 100644 bot/cogs/moderation/scheduler.py delete mode 100644 bot/cogs/moderation/superstarify.py delete mode 100644 bot/cogs/moderation/utils.py create mode 100644 bot/cogs/moderation/verification.py create mode 100644 bot/cogs/moderation/watchchannels/__init__.py create mode 100644 bot/cogs/moderation/watchchannels/bigbrother.py create mode 100644 bot/cogs/moderation/watchchannels/talentpool.py create mode 100644 bot/cogs/moderation/watchchannels/watchchannel.py delete mode 100644 bot/cogs/python_news.py delete mode 100644 bot/cogs/reddit.py delete mode 100644 bot/cogs/reminders.py delete mode 100644 bot/cogs/security.py delete mode 100644 bot/cogs/site.py delete mode 100644 bot/cogs/snekbox.py delete mode 100644 bot/cogs/source.py delete mode 100644 bot/cogs/stats.py delete mode 100644 bot/cogs/sync/__init__.py delete mode 100644 bot/cogs/sync/cog.py delete mode 100644 bot/cogs/sync/syncers.py delete mode 100644 bot/cogs/tags.py delete mode 100644 bot/cogs/token_remover.py delete mode 100644 bot/cogs/utils.py create mode 100644 bot/cogs/utils/__init__.py create mode 100644 bot/cogs/utils/bot.py create mode 100644 bot/cogs/utils/clean.py create mode 100644 bot/cogs/utils/eval.py create mode 100644 bot/cogs/utils/extensions.py create mode 100644 bot/cogs/utils/jams.py create mode 100644 bot/cogs/utils/reminders.py create mode 100644 bot/cogs/utils/snekbox.py create mode 100644 bot/cogs/utils/utils.py delete mode 100644 bot/cogs/verification.py delete mode 100644 bot/cogs/watchchannels/__init__.py delete mode 100644 bot/cogs/watchchannels/bigbrother.py delete mode 100644 bot/cogs/watchchannels/talentpool.py delete mode 100644 bot/cogs/watchchannels/watchchannel.py delete mode 100644 bot/cogs/webhook_remover.py delete mode 100644 bot/cogs/wolfram.py (limited to 'tests') diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 55c7efe65..3c5a35c24 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -8,7 +8,7 @@ from discord.ext.commands import ( ) from bot.bot import Bot -from bot.cogs.extensions import Extension +from bot.cogs.utils.extensions import Extension from bot.converters import FetchedMember, TagNameConverter from bot.pagination import LinePaginator diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py deleted file mode 100644 index c76bd2c60..000000000 --- a/bot/cogs/antimalware.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -import typing as t -from os.path import splitext - -from discord import Embed, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, STAFF_ROLES, URLs - -log = logging.getLogger(__name__) - -PY_EMBED_DESCRIPTION = ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" -) - -TXT_EMBED_DESCRIPTION = ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - "{cmd_channel_mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" -) - -DISALLOWED_EMBED_DESCRIPTION = ( - "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - "We currently allow the following file types: **{joined_whitelist}**.\n\n" - "Feel free to ask in {meta_channel_mention} if you think this is a mistake." -) - - -class AntiMalware(Cog): - """Delete messages which contain attachments with non-whitelisted file extensions.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_whitelisted_file_formats(self) -> list: - """Get the file formats currently on the whitelist.""" - return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() - - def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: - """Get an iterable containing all the disallowed extensions of attachments.""" - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) - return extensions_blocked - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Identify messages with prohibited attachments.""" - # Return when message don't have attachment and don't moderate DMs - if not message.attachments or not message.guild: - return - - # Check if user is staff, if is, return - # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance - if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): - return - - embed = Embed() - extensions_blocked = self._get_disallowed_extensions(message) - blocked_extensions_str = ', '.join(extensions_blocked) - if ".py" in extensions_blocked: - # Short-circuit on *.py files to provide a pastebin link - embed.description = PY_EMBED_DESCRIPTION - elif ".txt" in extensions_blocked: - # Work around Discord AutoConversion of messages longer than 2000 chars to .txt - cmd_channel = self.bot.get_channel(Channels.bot_commands) - embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) - elif extensions_blocked: - meta_channel = self.bot.get_channel(Channels.meta) - embed.description = DISALLOWED_EMBED_DESCRIPTION.format( - joined_whitelist=', '.join(self._get_whitelisted_file_formats()), - blocked_extensions_str=blocked_extensions_str, - meta_channel_mention=meta_channel.mention, - ) - - if embed.description: - log.info( - f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", - extra={"attachment_list": [attachment.filename for attachment in message.attachments]} - ) - - await message.channel.send(f"Hey {message.author.mention}!", embed=embed) - - # Delete the offending message: - try: - await message.delete() - except NotFound: - log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - - -def setup(bot: Bot) -> None: - """Load the AntiMalware cog.""" - bot.add_cog(AntiMalware(bot)) diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py deleted file mode 100644 index 0bcca578d..000000000 --- a/bot/cogs/antispam.py +++ /dev/null @@ -1,288 +0,0 @@ -import asyncio -import logging -from collections.abc import Mapping -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from operator import itemgetter -from typing import Dict, Iterable, List, Set - -from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Cog - -from bot import rules -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import ( - AntiSpam as AntiSpamConfig, Channels, - Colours, DEBUG_MODE, Event, Filter, - Guild as GuildConfig, Icons, - STAFF_ROLES, -) -from bot.converters import Duration -from bot.utils.messages import send_attachments - - -log = logging.getLogger(__name__) - -RULE_FUNCTION_MAPPING = { - 'attachments': rules.apply_attachments, - 'burst': rules.apply_burst, - 'burst_shared': rules.apply_burst_shared, - 'chars': rules.apply_chars, - 'discord_emojis': rules.apply_discord_emojis, - 'duplicates': rules.apply_duplicates, - 'links': rules.apply_links, - 'mentions': rules.apply_mentions, - 'newlines': rules.apply_newlines, - 'role_mentions': rules.apply_role_mentions -} - - -@dataclass -class DeletionContext: - """Represents a Deletion Context for a single spam event.""" - - channel: TextChannel - members: Dict[int, Member] = field(default_factory=dict) - rules: Set[str] = field(default_factory=set) - messages: Dict[int, Message] = field(default_factory=dict) - attachments: List[List[str]] = field(default_factory=list) - - async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: - """Adds new rule violation events to the deletion context.""" - self.rules.add(rule_name) - - for member in members: - if member.id not in self.members: - self.members[member.id] = member - - for message in messages: - if message.id not in self.messages: - self.messages[message.id] = message - - # Re-upload attachments - destination = message.guild.get_channel(Channels.attachment_log) - urls = await send_attachments(message, destination, link_large=False) - self.attachments.append(urls) - - async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: - """Method that takes care of uploading the queue and posting modlog alert.""" - triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) - - mod_alert_message = ( - f"**Triggered by:** {triggered_by_users}\n" - f"**Channel:** {self.channel.mention}\n" - f"**Rules:** {', '.join(rule for rule in self.rules)}\n" - ) - - # For multiple messages or those with excessive newlines, use the logs API - if len(self.messages) > 1 or 'newlines' in self.rules: - url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) - mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" - else: - mod_alert_message += "Message:\n" - [message] = self.messages.values() - content = message.clean_content - remaining_chars = 2040 - len(mod_alert_message) - - if len(content) > remaining_chars: - content = content[:remaining_chars] + "..." - - mod_alert_message += f"{content}" - - *_, last_message = self.messages.values() - await modlog.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title="Spam detected!", - text=mod_alert_message, - thumbnail=last_message.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=AntiSpamConfig.ping_everyone - ) - - -class AntiSpam(Cog): - """Cog that controls our anti-spam measures.""" - - def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: - self.bot = bot - self.validation_errors = validation_errors - role_id = AntiSpamConfig.punishment['role_id'] - self.muted_role = Object(role_id) - self.expiration_date_converter = Duration() - - self.message_deletion_queue = dict() - - self.bot.loop.create_task(self.alert_on_validation_error()) - - @property - def mod_log(self) -> ModLog: - """Allows for easy access of the ModLog cog.""" - return self.bot.get_cog("ModLog") - - async def alert_on_validation_error(self) -> None: - """Unloads the cog and alerts admins if configuration validation failed.""" - await self.bot.wait_until_guild_available() - if self.validation_errors: - body = "**The following errors were encountered:**\n" - body += "\n".join(f"- {error}" for error in self.validation_errors.values()) - body += "\n\n**The cog has been unloaded.**" - - await self.mod_log.send_log_message( - title="Error: AntiSpam configuration validation failed!", - text=body, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Colour.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Applies the antispam rules to each received message.""" - if ( - not message.guild - or message.guild.id != GuildConfig.id - or message.author.bot - or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) - or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) - ): - return - - # Fetch the rule configuration with the highest rule interval. - max_interval_config = max( - AntiSpamConfig.rules.values(), - key=itemgetter('interval') - ) - max_interval = max_interval_config['interval'] - - # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. - earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] - - for rule_name in AntiSpamConfig.rules: - rule_config = AntiSpamConfig.rules[rule_name] - rule_function = RULE_FUNCTION_MAPPING[rule_name] - - # Create a list of messages that were sent in the interval that the rule cares about. - latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] - result = await rule_function(message, messages_for_rule, rule_config) - - # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` - # which contains the reason for why the message violated the rule and - # an iterable of all members that violated the rule. - if result is not None: - self.bot.stats.incr(f"mod_alerts.{rule_name}") - reason, members, relevant_messages = result - full_reason = f"`{rule_name}` rule: {reason}" - - # If there's no spam event going on for this channel, start a new Message Deletion Context - channel = message.channel - if channel.id not in self.message_deletion_queue: - log.trace(f"Creating queue for channel `{channel.id}`") - self.message_deletion_queue[message.channel.id] = DeletionContext(channel) - self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) - - # Add the relevant of this trigger to the Deletion Context - await self.message_deletion_queue[message.channel.id].add( - rule_name=rule_name, - members=members, - messages=relevant_messages - ) - - for member in members: - - # Fire it off as a background task to ensure - # that the sleep doesn't block further tasks - self.bot.loop.create_task( - self.punish(message, member, full_reason) - ) - - await self.maybe_delete_messages(channel, relevant_messages) - break - - async def punish(self, msg: Message, member: Member, reason: str) -> None: - """Punishes the given member for triggering an antispam rule.""" - if not any(role.id == self.muted_role.id for role in member.roles): - remove_role_after = AntiSpamConfig.punishment['remove_after'] - - # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes - context = await self.bot.get_context(msg) - context.author = self.bot.user - context.message.author = self.bot.user - - # Since we're going to invoke the tempmute command directly, we need to manually call the converter. - dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") - await context.invoke( - self.bot.get_command('tempmute'), - member, - dt_remove_role_after, - reason=reason - ) - - async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: - """Cleans the messages if cleaning is configured.""" - if AntiSpamConfig.clean_offending: - # If we have more than one message, we can use bulk delete. - if len(messages) > 1: - message_ids = [message.id for message in messages] - self.mod_log.ignore(Event.message_delete, *message_ids) - await channel.delete_messages(messages) - - # Otherwise, the bulk delete endpoint will throw up. - # Delete the message directly instead. - else: - self.mod_log.ignore(Event.message_delete, messages[0].id) - try: - await messages[0].delete() - except NotFound: - log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - - async def _process_deletion_context(self, context_id: int) -> None: - """Processes the Deletion Context queue.""" - log.trace("Sleeping before processing message deletion queue.") - await asyncio.sleep(10) - - if context_id not in self.message_deletion_queue: - log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") - return - - deletion_context = self.message_deletion_queue.pop(context_id) - await deletion_context.upload_messages(self.bot.user.id, self.mod_log) - - -def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: - """Validates the antispam configs.""" - validation_errors = {} - for name, config in rules_.items(): - if name not in RULE_FUNCTION_MAPPING: - log.error( - f"Unrecognized antispam rule `{name}`. " - f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" - ) - validation_errors[name] = f"`{name}` is not recognized as an antispam rule." - continue - for required_key in ('interval', 'max'): - if required_key not in config: - log.error( - f"`{required_key}` is required but was not " - f"set in rule `{name}`'s configuration." - ) - validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" - return validation_errors - - -def setup(bot: Bot) -> None: - """Validate the AntiSpam configs and load the AntiSpam cog.""" - validation_errors = validate_config() - bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/cogs/backend/__init__.py b/bot/cogs/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/backend/config_verifier.py b/bot/cogs/backend/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/cogs/backend/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/backend/error_handler.py b/bot/cogs/backend/error_handler.py new file mode 100644 index 000000000..f9d4de638 --- /dev/null +++ b/bot/cogs/backend/error_handler.py @@ -0,0 +1,287 @@ +import contextlib +import logging +import typing as t + +from discord import Embed +from discord.ext.commands import Cog, Context, errors +from sentry_sdk import push_scope + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Colours +from bot.converters import TagNameConverter +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + + +class ErrorHandler(Cog): + """Handles errors emitted from commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_error_embed(self, title: str, body: str) -> Embed: + """Return an embed that contains the exception.""" + return Embed( + title=title, + colour=Colours.soft_red, + description=body + ) + + @Cog.listener() + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: + """ + Provide generic command error handling. + + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command: + * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. + Otherwise if it matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` + """ + command = ctx.command + + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if await self.try_silence(ctx): + return + if ctx.channel.id != Channels.verification: + # Try to look for a tag with the command's name + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + @staticmethod + def get_help_command(ctx: Context) -> t.Coroutine: + """Return a prepared `help` command invocation coroutine.""" + if ctx.command: + return ctx.send_help(ctx.command) + + return ctx.send_help() + + async def try_silence(self, ctx: Context) -> bool: + """ + Attempt to invoke the silence or unsilence command if invoke with matches a pattern. + + Respecting the checks if: + * invoked with `shh+` silence channel for amount of h's*2 with max of 15. + * invoked with `unshh+` unsilence channel + Return bool depending on success of command. + """ + command = ctx.invoked_with.lower() + silence_command = self.bot.get_command("silence") + ctx.invoked_from_error_handler = True + try: + if not await silence_command.can_run(ctx): + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + except errors.CommandError: + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + if command.startswith("shh"): + await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) + return True + elif command.startswith("unshh"): + await ctx.invoke(self.bot.get_command("unsilence")) + return True + return False + + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + + try: + tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) + except errors.BadArgument: + log.debug( + f"{ctx.author} tried to use an invalid command " + f"and the fallback tag failed validation in TagNameConverter." + ) + else: + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=tag_name) + # Return to not raise the exception + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + prepared_help_command = self.get_help_command(ctx) + + if isinstance(e, errors.MissingRequiredArgument): + embed = self._get_error_embed("Missing required argument", e.param.name) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.missing_required_argument") + elif isinstance(e, errors.TooManyArguments): + embed = self._get_error_embed("Too many arguments", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.too_many_arguments") + elif isinstance(e, errors.BadArgument): + embed = self._get_error_embed("Bad argument", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.bad_argument") + elif isinstance(e, errors.BadUnionArgument): + embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") + await ctx.send(embed=embed) + self.bot.stats.incr("errors.bad_union_argument") + elif isinstance(e, errors.ArgumentParsingError): + embed = self._get_error_embed("Argument parsing error", str(e)) + await ctx.send(embed=embed) + self.bot.stats.incr("errors.argument_parsing_error") + else: + embed = self._get_error_embed( + "Input error", + "Something about your input seems off. Check the arguments and try again." + ) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.other_user_input_error") + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InWhitelistCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + ctx.bot.stats.incr("errors.bot_permission_error") + await ctx.send( + "Sorry, it looks like I don't have the permissions or roles I need to do that." + ) + elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): + ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") + await ctx.send(e) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + ctx.bot.stats.incr("errors.api_error_404") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + ctx.bot.stats.incr("errors.api_error_400") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") + ctx.bot.stats.incr("errors.api_internal_server_error") + else: + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") + ctx.bot.stats.incr(f"errors.api_error_{e.status}") + + @staticmethod + async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n" + f"```{e.__class__.__name__}: {e}```" + ) + + ctx.bot.stats.incr("errors.unexpected") + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) + + +def setup(bot: Bot) -> None: + """Load the ErrorHandler cog.""" + bot.add_cog(ErrorHandler(bot)) diff --git a/bot/cogs/backend/logging.py b/bot/cogs/backend/logging.py new file mode 100644 index 000000000..94fa2b139 --- /dev/null +++ b/bot/cogs/backend/logging.py @@ -0,0 +1,42 @@ +import logging + +from discord import Embed +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, DEBUG_MODE + + +log = logging.getLogger(__name__) + + +class Logging(Cog): + """Debug logging module.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.bot.loop.create_task(self.startup_greeting()) + + async def startup_greeting(self) -> None: + """Announce our presence to the configured devlog channel.""" + await self.bot.wait_until_guild_available() + log.info("Bot connected!") + + embed = Embed(description="Connected!") + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=( + "https://raw.githubusercontent.com/" + "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" + ) + ) + + if not DEBUG_MODE: + await self.bot.get_channel(Channels.dev_log).send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the Logging cog.""" + bot.add_cog(Logging(bot)) diff --git a/bot/cogs/backend/sync/__init__.py b/bot/cogs/backend/sync/__init__.py new file mode 100644 index 000000000..fe7df4e9b --- /dev/null +++ b/bot/cogs/backend/sync/__init__.py @@ -0,0 +1,7 @@ +from bot.bot import Bot +from .cog import Sync + + +def setup(bot: Bot) -> None: + """Load the Sync cog.""" + bot.add_cog(Sync(bot)) diff --git a/bot/cogs/backend/sync/cog.py b/bot/cogs/backend/sync/cog.py new file mode 100644 index 000000000..274845a50 --- /dev/null +++ b/bot/cogs/backend/sync/cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from . import syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = syncers.RoleSyncer(self.bot) + self.user_syncer = syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/syncers.py b/bot/cogs/backend/sync/syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/cogs/backend/sync/syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py deleted file mode 100644 index 79510739c..000000000 --- a/bot/cogs/bot.py +++ /dev/null @@ -1,385 +0,0 @@ -import ast -import logging -import re -import time -from typing import Optional, Tuple - -from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Cog, Context, command, group - -from bot.bot import Bot -from bot.cogs.token_remover import TokenRemover -from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.decorators import with_role -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -RE_MARKDOWN = re.compile(r'([*_~`|>])') - - -class BotCog(Cog, name="Bot"): - """Bot information commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - # Stores allowed channels plus epoch time since last call. - self.channel_cooldowns = { - Channels.python_discussion: 0, - } - - # These channels will also work, but will not be subject to cooldown - self.channel_whitelist = ( - Channels.bot_commands, - ) - - # Stores improperly formatted Python codeblock message ids and the corresponding bot message - self.codeblock_message_ids = {} - - @group(invoke_without_command=True, name="bot", hidden=True) - @with_role(Roles.verified) - async def botinfo_group(self, ctx: Context) -> None: - """Bot informational commands.""" - await ctx.send_help(ctx.command) - - @botinfo_group.command(name='about', aliases=('info',), hidden=True) - @with_role(Roles.verified) - async def about_command(self, ctx: Context) -> None: - """Get information about the bot.""" - embed = Embed( - description="A utility bot designed just for the Python server! Try `!help` for more info.", - url="https://github.com/python-discord/bot" - ) - - embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=URLs.bot_avatar - ) - - await ctx.send(embed=embed) - - @command(name='echo', aliases=('print',)) - @with_role(*MODERATION_ROLES) - async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Repeat the given message in either a specified channel or the current channel.""" - if channel is None: - await ctx.send(text) - else: - await channel.send(text) - - @command(name='embed') - @with_role(*MODERATION_ROLES) - async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Send the input within an embed to either a specified channel or the current channel.""" - embed = Embed(description=text) - - if channel is None: - await ctx.send(embed=embed) - else: - await channel.send(embed=embed) - - def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: - """ - Strip msg in order to find Python code. - - Tries to strip out Python code out of msg and returns the stripped block or - None if the block is a valid Python codeblock. - """ - if msg.count("\n") >= 3: - # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. - if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: - log.trace( - "Someone wrote a message that was already a " - "valid Python syntax highlighted code block. No action taken." - ) - return None - - else: - # Stripping backticks from every line of the message. - log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") - content = "" - for line in msg.splitlines(keepends=True): - content += line.strip("`") - - content = content.strip() - - # Remove "Python" or "Py" from start of the message if it exists. - log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") - pycode = False - if content.lower().startswith("python"): - content = content[6:] - pycode = True - elif content.lower().startswith("py"): - content = content[2:] - pycode = True - - if pycode: - content = content.splitlines(keepends=True) - - # Check if there might be code in the first line, and preserve it. - first_line = content[0] - if " " in content[0]: - first_space = first_line.index(" ") - content[0] = first_line[first_space:] - content = "".join(content) - - # If there's no code we can just get rid of the first line. - else: - content = "".join(content[1:]) - - # Strip it again to remove any leading whitespace. This is neccessary - # if the first line of the message looked like ```python - old = content.strip() - - # Strips REPL code out of the message if there is any. - content, repl_code = self.repl_stripping(old) - if old != content: - return (content, old), repl_code - - # Try to apply indentation fixes to the code. - content = self.fix_indentation(content) - - # Check if the code contains backticks, if it does ignore the message. - if "`" in content: - log.trace("Detected ` inside the code, won't reply") - return None - else: - log.trace(f"Returning message.\n\n{content}\n\n") - return (content,), repl_code - - def fix_indentation(self, msg: str) -> str: - """Attempts to fix badly indented code.""" - def unindent(code: str, skip_spaces: int = 0) -> str: - """Unindents all code down to the number of spaces given in skip_spaces.""" - final = "" - current = code[0] - leading_spaces = 0 - - # Get numbers of spaces before code in the first line. - while current == " ": - current = code[leading_spaces + 1] - leading_spaces += 1 - leading_spaces -= skip_spaces - - # If there are any, remove that number of spaces from every line. - if leading_spaces > 0: - for line in code.splitlines(keepends=True): - line = line[leading_spaces:] - final += line - return final - else: - return code - - # Apply fix for "all lines are overindented" case. - msg = unindent(msg) - - # If the first line does not end with a colon, we can be - # certain the next line will be on the same indentation level. - # - # If it does end with a colon, we will need to indent all successive - # lines one additional level. - first_line = msg.splitlines()[0] - code = "".join(msg.splitlines(keepends=True)[1:]) - if not first_line.endswith(":"): - msg = f"{first_line}\n{unindent(code)}" - else: - msg = f"{first_line}\n{unindent(code, 4)}" - return msg - - def repl_stripping(self, msg: str) -> Tuple[str, bool]: - """ - Strip msg in order to extract Python code out of REPL output. - - Tries to strip out REPL Python code out of msg and returns the stripped msg. - - Returns True for the boolean if REPL code was found in the input msg. - """ - final = "" - for line in msg.splitlines(keepends=True): - if line.startswith(">>>") or line.startswith("..."): - final += line[4:] - log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") - if not final: - log.trace(f"Found no REPL code in \n\n{msg}\n\n") - return msg, False - else: - log.trace(f"Found REPL code in \n\n{msg}\n\n") - return final.rstrip(), True - - def has_bad_ticks(self, msg: Message) -> bool: - """Check to see if msg contains ticks that aren't '`'.""" - not_backticks = [ - "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", - "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", - "\u3003\u3003\u3003" - ] - - return msg.content[:3] in not_backticks - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Detect poorly formatted Python code in new messages. - - If poorly formatted code is detected, send the user a helpful message explaining how to do - properly formatted Python syntax highlighting codeblocks. - """ - is_help_channel = ( - getattr(msg.channel, "category", None) - and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) - ) - parse_codeblock = ( - ( - is_help_channel - or msg.channel.id in self.channel_cooldowns - or msg.channel.id in self.channel_whitelist - ) - and not msg.author.bot - and len(msg.content.splitlines()) > 3 - and not TokenRemover.find_token_in_message(msg) - ) - - if parse_codeblock: # no token in the msg - on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 - if not on_cooldown or DEBUG_MODE: - try: - if self.has_bad_ticks(msg): - ticks = msg.content[:3] - content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) - if content is None: - return - - content, repl_code = content - - if len(content) == 2: - content = content[1] - else: - content = content[0] - - space_left = 204 - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto = ( - "It looks like you are trying to paste code into this channel.\n\n" - "You seem to be using the wrong symbols to indicate where the codeblock should start. " - f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" - "**Here is an example of how it should look:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - else: - howto = "" - content = self.codeblock_stripping(msg.content, False) - if content is None: - return - - content, repl_code = content - # Attempts to parse the message into an AST node. - # Invalid Python code will raise a SyntaxError. - tree = ast.parse(content[0]) - - # Multiple lines of single words could be interpreted as expressions. - # This check is to avoid all nodes being parsed as expressions. - # (e.g. words over multiple lines) - if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: - # Shorten the code to 10 lines and/or 204 characters. - space_left = 204 - if content and repl_code: - content = content[1] - else: - content = content[0] - - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto += ( - "It looks like you're trying to paste code into this channel.\n\n" - "Discord has support for Markdown, which allows you to post code with full " - "syntax highlighting. Please use these whenever you paste code, as this " - "helps improve the legibility and makes it easier for us to help you.\n\n" - f"**To do this, use the following method:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - log.debug(f"{msg.author} posted something that needed to be put inside python code " - "blocks. Sending the user some instructions.") - else: - log.trace("The code consists only of expressions, not sending instructions") - - if howto != "": - # Increase amount of codeblock correction in stats - self.bot.stats.incr("codeblock_corrections") - howto_embed = Embed(description=howto) - bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) - self.codeblock_message_ids[msg.id] = bot_message.id - - self.bot.loop.create_task( - wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) - ) - else: - return - - if msg.channel.id not in self.channel_whitelist: - self.channel_cooldowns[msg.channel.id] = time.time() - - except SyntaxError: - log.trace( - f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " - "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " - f"The message that was posted was:\n\n{msg.content}\n\n" - ) - - @Cog.listener() - async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: - """Check to see if an edited message (previously called out) still contains poorly formatted code.""" - if ( - # Checks to see if the message was called out by the bot - payload.message_id not in self.codeblock_message_ids - # Makes sure that there is content in the message - or payload.data.get("content") is None - # Makes sure there's a channel id in the message payload - or payload.data.get("channel_id") is None - ): - return - - # Retrieve channel and message objects for use later - channel = self.bot.get_channel(int(payload.data.get("channel_id"))) - user_message = await channel.fetch_message(payload.message_id) - - # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None - has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) - - # If the message is fixed, delete the bot message and the entry from the id dictionary - if has_fixed_codeblock is None: - bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - await bot_message.delete() - del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") - - -def setup(bot: Bot) -> None: - """Load the Bot cog.""" - bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py deleted file mode 100644 index f436e531a..000000000 --- a/bot/cogs/clean.py +++ /dev/null @@ -1,272 +0,0 @@ -import logging -import random -import re -from typing import Iterable, Optional - -from discord import Colour, Embed, Message, TextChannel, User -from discord.ext import commands -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import ( - Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES -) -from bot.decorators import with_role - -log = logging.getLogger(__name__) - - -class Clean(Cog): - """ - A cog that allows messages to be deleted in bulk, while applying various filters. - - You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a - specific regular expression. - - The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be - used to view the messages in the Discord dark theme style. - """ - - def __init__(self, bot: Bot): - self.bot = bot - self.cleaning = False - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def _clean_messages( - self, - amount: int, - ctx: Context, - channels: Iterable[TextChannel], - bots_only: bool = False, - user: User = None, - regex: Optional[str] = None, - until_message: Optional[Message] = None, - ) -> None: - """A helper function that does the actual message cleaning.""" - def predicate_bots_only(message: Message) -> bool: - """Return True if the message was sent by a bot.""" - return message.author.bot - - def predicate_specific_user(message: Message) -> bool: - """Return True if the message was sent by the user provided in the _clean_messages call.""" - return message.author == user - - def predicate_regex(message: Message) -> bool: - """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" - content = [message.content] - - # Add the content for all embed attributes - for embed in message.embeds: - content.append(embed.title) - content.append(embed.description) - content.append(embed.footer.text) - content.append(embed.author.name) - for field in embed.fields: - content.append(field.name) - content.append(field.value) - - # Get rid of empty attributes and turn it into a string - content = [attr for attr in content if attr] - content = "\n".join(content) - - # Now let's see if there's a regex match - if not content: - return False - else: - return bool(re.search(regex.lower(), content.lower())) - - # Is this an acceptable amount of messages to clean? - if amount > CleanMessages.message_limit: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description=f"You cannot clean more than {CleanMessages.message_limit} messages." - ) - await ctx.send(embed=embed) - return - - # Are we already performing a clean? - if self.cleaning: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description="Please wait for the currently ongoing clean operation to complete." - ) - await ctx.send(embed=embed) - return - - # Set up the correct predicate - if bots_only: - predicate = predicate_bots_only # Delete messages from bots - elif user: - predicate = predicate_specific_user # Delete messages from specific user - elif regex: - predicate = predicate_regex # Delete messages that match regex - else: - predicate = None # Delete all messages - - # Default to using the invoking context's channel - if not channels: - channels = [ctx.channel] - - # Delete the invocation first - self.mod_log.ignore(Event.message_delete, ctx.message.id) - await ctx.message.delete() - - messages = [] - message_ids = [] - self.cleaning = True - - # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. - for channel in channels: - async for message in channel.history(limit=amount): - - # If at any point the cancel command is invoked, we should stop. - if not self.cleaning: - return - - # If we are looking for specific message. - if until_message: - - # we could use ID's here however in case if the message we are looking for gets deleted, - # we won't have a way to figure that out thus checking for datetime should be more reliable - if message.created_at < until_message.created_at: - # means we have found the message until which we were supposed to be deleting. - break - - # Since we will be using `delete_messages` method of a TextChannel and we need message objects to - # use it as well as to send logs we will start appending messages here instead adding them from - # purge. - messages.append(message) - - # If the message passes predicate, let's save it. - if predicate is None or predicate(message): - message_ids.append(message.id) - - self.cleaning = False - - # Now let's delete the actual messages with purge. - self.mod_log.ignore(Event.message_delete, *message_ids) - for channel in channels: - if until_message: - for i in range(0, len(messages), 100): - # while purge automatically handles the amount of messages - # delete_messages only allows for up to 100 messages at once - # thus we need to paginate the amount to always be <= 100 - await channel.delete_messages(messages[i:i + 100]) - else: - messages += await channel.purge(limit=amount, check=predicate) - - # Reverse the list to restore chronological order - if messages: - messages = reversed(messages) - log_url = await self.mod_log.upload_log(messages, ctx.author.id) - else: - # Can't build an embed, nothing to clean! - embed = Embed( - color=Colour(Colours.soft_red), - description="No matching messages could be found." - ) - await ctx.send(embed=embed, delete_after=10) - return - - # Build the embed and send it - target_channels = ", ".join(channel.mention for channel in channels) - - message = ( - f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" - f"A log of the deleted messages can be found [here]({log_url})." - ) - - await self.mod_log.send_log_message( - icon_url=Icons.message_bulk_delete, - colour=Colour(Colours.soft_red), - title="Bulk message delete", - text=message, - channel_id=Channels.mod_log, - ) - - @group(invoke_without_command=True, name="clean", aliases=["purge"]) - @with_role(*MODERATION_ROLES) - async def clean_group(self, ctx: Context) -> None: - """Commands for cleaning messages in channels.""" - await ctx.send_help(ctx.command) - - @clean_group.command(name="user", aliases=["users"]) - @with_role(*MODERATION_ROLES) - async def clean_user( - self, - ctx: Context, - user: User, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user, channels=channels) - - @clean_group.command(name="all", aliases=["everything"]) - @with_role(*MODERATION_ROLES) - async def clean_all( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, channels=channels) - - @clean_group.command(name="bots", aliases=["bot"]) - @with_role(*MODERATION_ROLES) - async def clean_bots( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True, channels=channels) - - @clean_group.command(name="regex", aliases=["word", "expression"]) - @with_role(*MODERATION_ROLES) - async def clean_regex( - self, - ctx: Context, - regex: str, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex, channels=channels) - - @clean_group.command(name="message", aliases=["messages"]) - @with_role(*MODERATION_ROLES) - async def clean_message(self, ctx: Context, message: Message) -> None: - """Delete all messages until certain message, stop cleaning after hitting the `message`.""" - await self._clean_messages( - CleanMessages.message_limit, - ctx, - channels=[message.channel], - until_message=message - ) - - @clean_group.command(name="stop", aliases=["cancel", "abort"]) - @with_role(*MODERATION_ROLES) - async def clean_cancel(self, ctx: Context) -> None: - """If there is an ongoing cleaning process, attempt to immediately cancel it.""" - self.cleaning = False - - embed = Embed( - color=Colour.blurple(), - description="Clean interrupted." - ) - await ctx.send(embed=embed, delete_after=10) - - -def setup(bot: Bot) -> None: - """Load the Clean cog.""" - bot.add_cog(Clean(bot)) diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py deleted file mode 100644 index d72c6c22e..000000000 --- a/bot/cogs/config_verifier.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging - -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot - - -log = logging.getLogger(__name__) - - -class ConfigVerifier(Cog): - """Verify config on startup.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) - - async def verify_channels(self) -> None: - """ - Verify channels. - - If any channels in config aren't present in server, log them in a warning. - """ - await self.bot.wait_until_guild_available() - server = self.bot.get_guild(constants.Guild.id) - - server_channel_ids = {channel.id for channel in server.channels} - invalid_channels = [ - channel_name for channel_name, channel_id in constants.Channels - if channel_id not in server_channel_ids - ] - - if invalid_channels: - log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") - - -def setup(bot: Bot) -> None: - """Load the ConfigVerifier cog.""" - bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py deleted file mode 100644 index 4c0ad5914..000000000 --- a/bot/cogs/defcon.py +++ /dev/null @@ -1,258 +0,0 @@ -from __future__ import annotations - -import logging -from collections import namedtuple -from datetime import datetime, timedelta -from enum import Enum - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -REJECTION_MESSAGE = """ -Hi, {user} - Thanks for your interest in our server! - -Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since -your account is relatively new, we're unable to provide access to the server at this time. - -Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation -will be resolved soon. In the meantime, please feel free to peruse the resources on our site at -, and have a nice day! -""" - -BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" - - -class Action(Enum): - """Defcon Action.""" - - ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) - - ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") - DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") - UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") - - -class Defcon(Cog): - """Time-sensitive server defense mechanisms.""" - - days = None # type: timedelta - enabled = False # type: bool - - def __init__(self, bot: Bot): - self.bot = bot - self.channel = None - self.days = timedelta(days=0) - - self.bot.loop.create_task(self.sync_settings()) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def sync_settings(self) -> None: - """On cog load, try to synchronize DEFCON settings to the API.""" - await self.bot.wait_until_guild_available() - self.channel = await self.bot.fetch_channel(Channels.defcon) - - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - except Exception: # Yikes! - log.exception("Unable to get DEFCON settings!") - await self.bot.get_channel(Channels.dev_log).send( - f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" - ) - - else: - if data["enabled"]: - self.enabled = True - self.days = timedelta(days=data["days"]) - log.info(f"DEFCON enabled: {self.days.days} days") - - else: - self.enabled = False - self.days = timedelta(days=0) - log.info("DEFCON disabled") - - await self.update_channel_topic() - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" - if self.enabled and self.days.days > 0: - now = datetime.utcnow() - - if now - member.created_at < self.days: - log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") - - message_sent = False - - try: - await member.send(REJECTION_MESSAGE.format(user=member.mention)) - - message_sent = True - except Exception: - log.exception(f"Unable to send rejection message to user: {member}") - - await member.kick(reason="DEFCON active, user is too new") - self.bot.stats.incr("defcon.leaves") - - message = ( - f"{member} (`{member.id}`) was denied entry because their account is too new." - ) - - if not message_sent: - message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." - - await self.mod_log.send_log_message( - Icons.defcon_denied, Colours.soft_red, "Entry denied", - message, member.avatar_url_as(static_format="png") - ) - - @group(name='defcon', aliases=('dc',), invoke_without_command=True) - @with_role(Roles.admins, Roles.owners) - async def defcon_group(self, ctx: Context) -> None: - """Check the DEFCON status or run a subcommand.""" - await ctx.send_help(ctx.command) - - async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: - """Providing a structured way to do an defcon action.""" - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - if "enable_date" in data and action is Action.DISABLED: - enabled = datetime.fromisoformat(data["enable_date"]) - - delta = datetime.now() - enabled - - self.bot.stats.timing("defcon.enabled", delta) - except Exception: - pass - - error = None - try: - await self.bot.api_client.put( - 'bot/bot-settings/defcon', - json={ - 'name': 'defcon', - 'data': { - # TODO: retrieve old days count - 'days': days, - 'enabled': action is not Action.DISABLED, - 'enable_date': datetime.now().isoformat() - } - } - ) - except Exception as err: - log.exception("Unable to update DEFCON settings.") - error = err - finally: - await ctx.send(self.build_defcon_msg(action, error)) - await self.send_defcon_log(action, ctx.author, error) - - self.bot.stats.gauge("defcon.threshold", days) - - @defcon_group.command(name='enable', aliases=('on', 'e')) - @with_role(Roles.admins, Roles.owners) - async def enable_command(self, ctx: Context) -> None: - """ - Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! - - Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, - in days. - """ - self.enabled = True - await self._defcon_action(ctx, days=0, action=Action.ENABLED) - await self.update_channel_topic() - - @defcon_group.command(name='disable', aliases=('off', 'd')) - @with_role(Roles.admins, Roles.owners) - async def disable_command(self, ctx: Context) -> None: - """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" - self.enabled = False - await self._defcon_action(ctx, days=0, action=Action.DISABLED) - await self.update_channel_topic() - - @defcon_group.command(name='status', aliases=('s',)) - @with_role(Roles.admins, Roles.owners) - async def status_command(self, ctx: Context) -> None: - """Check the current status of DEFCON mode.""" - embed = Embed( - colour=Colour.blurple(), title="DEFCON Status", - description=f"**Enabled:** {self.enabled}\n" - f"**Days:** {self.days.days}" - ) - - await ctx.send(embed=embed) - - @defcon_group.command(name='days') - @with_role(Roles.admins, Roles.owners) - async def days_command(self, ctx: Context, days: int) -> None: - """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" - self.days = timedelta(days=days) - self.enabled = True - await self._defcon_action(ctx, days=days, action=Action.UPDATED) - await self.update_channel_topic() - - async def update_channel_topic(self) -> None: - """Update the #defcon channel topic with the current DEFCON status.""" - if self.enabled: - day_str = "days" if self.days.days > 1 else "day" - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" - else: - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" - - self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) - await self.channel.edit(topic=new_topic) - - def build_defcon_msg(self, action: Action, e: Exception = None) -> str: - """Build in-channel response string for DEFCON action.""" - if action is Action.ENABLED: - msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" - elif action is Action.DISABLED: - msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" - elif action is Action.UPDATED: - msg = ( - f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " - f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" - ) - - if e: - msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - return msg - - async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: - """Send log message for DEFCON action.""" - info = action.value - log_msg: str = ( - f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" - f"{info.template.format(days=self.days.days)}" - ) - status_msg = f"DEFCON {action.name.lower()}" - - if e: - log_msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) - - -def setup(bot: Bot) -> None: - """Load the Defcon cog.""" - bot.add_cog(Defcon(bot)) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py deleted file mode 100644 index 204cffb37..000000000 --- a/bot/cogs/doc.py +++ /dev/null @@ -1,511 +0,0 @@ -import asyncio -import functools -import logging -import re -import textwrap -from collections import OrderedDict -from contextlib import suppress -from types import SimpleNamespace -from typing import Any, Callable, Optional, Tuple - -import discord -from bs4 import BeautifulSoup -from bs4.element import PageElement, Tag -from discord.errors import NotFound -from discord.ext import commands -from markdownify import MarkdownConverter -from requests import ConnectTimeout, ConnectionError, HTTPError -from sphinx.ext import intersphinx -from urllib3.exceptions import ProtocolError - -from bot.bot import Bot -from bot.constants import MODERATION_ROLES, RedirectOutput -from bot.converters import ValidPythonIdentifier, ValidURL -from bot.decorators import with_role -from bot.pagination import LinePaginator - - -log = logging.getLogger(__name__) -logging.getLogger('urllib3').setLevel(logging.WARNING) - -# Since Intersphinx is intended to be used with Sphinx, -# we need to mock its configuration. -SPHINX_MOCK_APP = SimpleNamespace( - config=SimpleNamespace( - intersphinx_timeout=3, - tls_verify=True, - user_agent="python3:python-discord/bot:1.0.0" - ) -) - -NO_OVERRIDE_GROUPS = ( - "2to3fixer", - "token", - "label", - "pdbcommand", - "term", -) -NO_OVERRIDE_PACKAGES = ( - "python", -) - -SEARCH_END_TAG_ATTRS = ( - "data", - "function", - "class", - "exception", - "seealso", - "section", - "rubric", - "sphinxsidebar", -) -UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") -WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") - -FAILED_REQUEST_RETRY_AMOUNT = 3 -NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay - - -def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - # Assign the cache to the function itself so we can clear it from outside. - async_cache.cache = OrderedDict() - - def decorator(function: Callable) -> Callable: - """Define the async_cache decorator.""" - @functools.wraps(function) - async def wrapper(*args) -> Any: - """Decorator wrapper for the caching logic.""" - key = ':'.join(args[arg_offset:]) - - value = async_cache.cache.get(key) - if value is None: - if len(async_cache.cache) > max_size: - async_cache.cache.popitem(last=False) - - async_cache.cache[key] = await function(*args) - return async_cache.cache[key] - return wrapper - return decorator - - -class DocMarkdownConverter(MarkdownConverter): - """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" - - def convert_code(self, el: PageElement, text: str) -> str: - """Undo `markdownify`s underscore escaping.""" - return f"`{text}`".replace('\\', '') - - def convert_pre(self, el: PageElement, text: str) -> str: - """Wrap any codeblocks in `py` for syntax highlighting.""" - code = ''.join(el.strings) - return f"```py\n{code}```" - - -def markdownify(html: str) -> DocMarkdownConverter: - """Create a DocMarkdownConverter object from the input html.""" - return DocMarkdownConverter(bullets='•').convert(html) - - -class InventoryURL(commands.Converter): - """ - Represents an Intersphinx inventory URL. - - This converter checks whether intersphinx accepts the given inventory URL, and raises - `BadArgument` if that is not the case. - - Otherwise, it simply passes through the given URL. - """ - - @staticmethod - async def convert(ctx: commands.Context, url: str) -> str: - """Convert url to Intersphinx inventory URL.""" - try: - intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) - except AttributeError: - raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") - except ConnectionError: - if url.startswith('https'): - raise commands.BadArgument( - f"Cannot establish a connection to `{url}`. Does it support HTTPS?" - ) - raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") - except ValueError: - raise commands.BadArgument( - f"Failed to read Intersphinx inventory from URL `{url}`. " - "Are you sure that it's a valid inventory file?" - ) - return url - - -class Doc(commands.Cog): - """A set of commands for querying & displaying documentation.""" - - def __init__(self, bot: Bot): - self.base_urls = {} - self.bot = bot - self.inventories = {} - self.renamed_symbols = set() - - self.bot.loop.create_task(self.init_refresh_inventory()) - - async def init_refresh_inventory(self) -> None: - """Refresh documentation inventory on cog initialization.""" - await self.bot.wait_until_guild_available() - await self.refresh_inventory() - - async def update_single( - self, package_name: str, base_url: str, inventory_url: str - ) -> None: - """ - Rebuild the inventory for a single package. - - Where: - * `package_name` is the package name to use, appears in the log - * `base_url` is the root documentation URL for the specified package, used to build - absolute paths that link to specific symbols - * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running - `intersphinx.fetch_inventory` in an executor on the bot's event loop - """ - self.base_urls[package_name] = base_url - - package = await self._fetch_inventory(inventory_url) - if not package: - return None - - for group, value in package.items(): - for symbol, (package_name, _version, relative_doc_url, _) in value.items(): - absolute_doc_url = base_url + relative_doc_url - - if symbol in self.inventories: - group_name = group.split(":")[1] - symbol_base_url = self.inventories[symbol].split("/", 3)[2] - if ( - group_name in NO_OVERRIDE_GROUPS - or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) - ): - - symbol = f"{group_name}.{symbol}" - # If renamed `symbol` already exists, add library name in front to differentiate between them. - if symbol in self.renamed_symbols: - # Split `package_name` because of packages like Pillow that have spaces in them. - symbol = f"{package_name.split()[0]}.{symbol}" - - self.inventories[symbol] = absolute_doc_url - self.renamed_symbols.add(symbol) - continue - - self.inventories[symbol] = absolute_doc_url - - log.trace(f"Fetched inventory for {package_name}.") - - async def refresh_inventory(self) -> None: - """Refresh internal documentation inventory.""" - log.debug("Refreshing documentation inventory...") - - # Clear the old base URLS and inventories to ensure - # that we start from a fresh local dataset. - # Also, reset the cache used for fetching documentation. - self.base_urls.clear() - self.inventories.clear() - self.renamed_symbols.clear() - async_cache.cache = OrderedDict() - - # Run all coroutines concurrently - since each of them performs a HTTP - # request, this speeds up fetching the inventory data heavily. - coros = [ - self.update_single( - package["package"], package["base_url"], package["inventory_url"] - ) for package in await self.bot.api_client.get('bot/documentation-links') - ] - await asyncio.gather(*coros) - - async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: - """ - Given a Python symbol, return its signature and description. - - The first tuple element is the signature of the given symbol as a markup-free string, and - the second tuple element is the description of the given symbol with HTML markup included. - - If the given symbol is a module, returns a tuple `(None, str)` - else if the symbol could not be found, returns `None`. - """ - url = self.inventories.get(symbol) - if url is None: - return None - - async with self.bot.http_session.get(url) as response: - html = await response.text(encoding='utf-8') - - # Find the signature header and parse the relevant parts. - symbol_id = url.split('#')[-1] - soup = BeautifulSoup(html, 'lxml') - symbol_heading = soup.find(id=symbol_id) - search_html = str(soup) - - if symbol_heading is None: - return None - - if symbol_id == f"module-{symbol}": - # Get page content from the module headerlink to the - # first tag that has its class in `SEARCH_END_TAG_ATTRS` - start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) - if start_tag is None: - return [], "" - - end_tag = start_tag.find_next(self._match_end_tag) - if end_tag is None: - return [], "" - - description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) - description_end_index = search_html.find(str(end_tag)) - description = search_html[description_start_index:description_end_index] - signatures = None - - else: - signatures = [] - description = str(symbol_heading.find_next_sibling("dd")) - description_pos = search_html.find(description) - # Get text of up to 3 signatures, remove unwanted symbols - for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): - signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) - if signature and search_html.find(str(element)) < description_pos: - signatures.append(signature) - - return signatures, description.replace('¶', '') - - @async_cache(arg_offset=1) - async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: - """ - Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. - - If the symbol is known, an Embed with documentation about it is returned. - """ - scraped_html = await self.get_symbol_html(symbol) - if scraped_html is None: - return None - - signatures = scraped_html[0] - permalink = self.inventories[symbol] - description = markdownify(scraped_html[1]) - - # Truncate the description of the embed to the last occurrence - # of a double newline (interpreted as a paragraph) before index 1000. - if len(description) > 1000: - shortened = description[:1000] - description_cutoff = shortened.rfind('\n\n', 100) - if description_cutoff == -1: - # Search the shortened version for cutoff points in decreasing desirability, - # cutoff at 1000 if none are found. - for string in (". ", ", ", ",", " "): - description_cutoff = shortened.rfind(string) - if description_cutoff != -1: - break - else: - description_cutoff = 1000 - description = description[:description_cutoff] - - # If there is an incomplete code block, cut it out - if description.count("```") % 2: - codeblock_start = description.rfind('```py') - description = description[:codeblock_start].rstrip() - description += f"... [read more]({permalink})" - - description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if signatures is None: - # If symbol is a module, don't show signature. - embed_description = description - - elif not signatures: - # It's some "meta-page", for example: - # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - embed_description = "This appears to be a generic page not tied to a specific symbol." - - else: - embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) - embed_description += f"\n{description}" - - embed = discord.Embed( - title=f'`{symbol}`', - url=permalink, - description=embed_description - ) - # Show all symbols with the same name that were renamed in the footer. - embed.set_footer( - text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) - ) - return embed - - @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) - async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """Lookup documentation for Python symbols.""" - await ctx.invoke(self.get_command, symbol) - - @docs_group.command(name='get', aliases=('g',)) - async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """ - Return a documentation embed for a given symbol. - - If no symbol is given, return a list of all available inventories. - - Examples: - !docs - !docs aiohttp - !docs aiohttp.ClientSession - !docs get aiohttp.ClientSession - """ - if symbol is None: - inventory_embed = discord.Embed( - title=f"All inventories (`{len(self.base_urls)}` total)", - colour=discord.Colour.blue() - ) - - lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) - if self.base_urls: - await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) - - else: - inventory_embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=inventory_embed) - - else: - # Fetching documentation for a symbol (at least for the first time, since - # caching is used) takes quite some time, so let's send typing to indicate - # that we got the command, but are still working on it. - async with ctx.typing(): - doc_embed = await self.get_symbol_embed(symbol) - - if doc_embed is None: - error_embed = discord.Embed( - description=f"Sorry, I could not find any documentation for `{symbol}`.", - colour=discord.Colour.red() - ) - error_message = await ctx.send(embed=error_embed) - with suppress(NotFound): - await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) - await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) - else: - await ctx.send(embed=doc_embed) - - @docs_group.command(name='set', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def set_command( - self, ctx: commands.Context, package_name: ValidPythonIdentifier, - base_url: ValidURL, inventory_url: InventoryURL - ) -> None: - """ - Adds a new documentation metadata object to the site's database. - - The database will update the object, should an existing item with the specified `package_name` already exist. - - Example: - !docs set \ - python \ - https://docs.python.org/3/ \ - https://docs.python.org/3/objects.inv - """ - body = { - 'package': package_name, - 'base_url': base_url, - 'inventory_url': inventory_url - } - await self.bot.api_client.post('bot/documentation-links', json=body) - - log.info( - f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" - f"Package name: {package_name}\n" - f"Base url: {base_url}\n" - f"Inventory URL: {inventory_url}" - ) - - # Rebuilding the inventory can take some time, so lets send out a - # typing event to show that the Bot is still working. - async with ctx.typing(): - await self.refresh_inventory() - await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") - - @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: - """ - Removes the specified package from the database. - - Examples: - !docs delete aiohttp - """ - await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') - - async with ctx.typing(): - # Rebuild the inventory to ensure that everything - # that was from this package is properly deleted. - await self.refresh_inventory() - await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") - - @docs_group.command(name="refresh", aliases=("rfsh", "r")) - @with_role(*MODERATION_ROLES) - async def refresh_command(self, ctx: commands.Context) -> None: - """Refresh inventories and send differences to channel.""" - old_inventories = set(self.base_urls) - with ctx.typing(): - await self.refresh_inventory() - # Get differences of added and removed inventories - added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) - if added: - added = f"+ {added}" - - removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) - if removed: - removed = f"- {removed}" - - embed = discord.Embed( - title="Inventories refreshed", - description=f"```diff\n{added}\n{removed}```" if added or removed else "" - ) - await ctx.send(embed=embed) - - async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: - """Get and return inventory from `inventory_url`. If fetching fails, return None.""" - fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) - for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): - try: - package = await self.bot.loop.run_in_executor(None, fetch_func) - except ConnectTimeout: - log.error( - f"Fetching of inventory {inventory_url} timed out," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except ProtocolError: - log.error( - f"Connection lost while fetching inventory {inventory_url}," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except HTTPError as e: - log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") - return None - except ConnectionError: - log.error(f"Couldn't establish connection to inventory {inventory_url}.") - return None - else: - return package - log.error(f"Fetching of inventory {inventory_url} failed.") - return None - - @staticmethod - def _match_end_tag(tag: Tag) -> bool: - """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" - for attr in SEARCH_END_TAG_ATTRS: - if attr in tag.get("class", ()): - return True - - return tag.name == "table" - - -def setup(bot: Bot) -> None: - """Load the Doc cog.""" - bot.add_cog(Doc(bot)) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py deleted file mode 100644 index f9d4de638..000000000 --- a/bot/cogs/error_handler.py +++ /dev/null @@ -1,287 +0,0 @@ -import contextlib -import logging -import typing as t - -from discord import Embed -from discord.ext.commands import Cog, Context, errors -from sentry_sdk import push_scope - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Colours -from bot.converters import TagNameConverter -from bot.utils.checks import InWhitelistCheckFailure - -log = logging.getLogger(__name__) - - -class ErrorHandler(Cog): - """Handles errors emitted from commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_error_embed(self, title: str, body: str) -> Embed: - """Return an embed that contains the exception.""" - return Embed( - title=title, - colour=Colours.soft_red, - description=body - ) - - @Cog.listener() - async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: - """ - Provide generic command error handling. - - Error handling is deferred to any local error handler, if present. This is done by - checking for the presence of a `handled` attribute on the error. - - Error handling emits a single error message in the invoking context `ctx` and a log message, - prioritised as follows: - - 1. If the name fails to match a command: - * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. - Otherwise if it matches a tag, the tag is invoked - * If CommandNotFound is raised when invoking the tag (determined by the presence of the - `invoked_from_error_handler` attribute), this error is treated as being unexpected - and therefore sends an error message - * Commands in the verification channel are ignored - 2. UserInputError: see `handle_user_input_error` - 3. CheckFailure: see `handle_check_failure` - 4. CommandOnCooldown: send an error message in the invoking context - 5. ResponseCodeError: see `handle_api_error` - 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` - """ - command = ctx.command - - if hasattr(e, "handled"): - log.trace(f"Command {command} had its error already handled locally; ignoring.") - return - - if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): - if await self.try_silence(ctx): - return - if ctx.channel.id != Channels.verification: - # Try to look for a tag with the command's name - await self.try_get_tag(ctx) - return # Exit early to avoid logging. - elif isinstance(e, errors.UserInputError): - await self.handle_user_input_error(ctx, e) - elif isinstance(e, errors.CheckFailure): - await self.handle_check_failure(ctx, e) - elif isinstance(e, errors.CommandOnCooldown): - await ctx.send(e) - elif isinstance(e, errors.CommandInvokeError): - if isinstance(e.original, ResponseCodeError): - await self.handle_api_error(ctx, e.original) - else: - await self.handle_unexpected_error(ctx, e.original) - return # Exit early to avoid logging. - elif not isinstance(e, errors.DisabledCommand): - # ConversionError, MaxConcurrencyReached, ExtensionError - await self.handle_unexpected_error(ctx, e) - return # Exit early to avoid logging. - - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - - @staticmethod - def get_help_command(ctx: Context) -> t.Coroutine: - """Return a prepared `help` command invocation coroutine.""" - if ctx.command: - return ctx.send_help(ctx.command) - - return ctx.send_help() - - async def try_silence(self, ctx: Context) -> bool: - """ - Attempt to invoke the silence or unsilence command if invoke with matches a pattern. - - Respecting the checks if: - * invoked with `shh+` silence channel for amount of h's*2 with max of 15. - * invoked with `unshh+` unsilence channel - Return bool depending on success of command. - """ - command = ctx.invoked_with.lower() - silence_command = self.bot.get_command("silence") - ctx.invoked_from_error_handler = True - try: - if not await silence_command.can_run(ctx): - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - except errors.CommandError: - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - if command.startswith("shh"): - await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) - return True - elif command.startswith("unshh"): - await ctx.invoke(self.bot.get_command("unsilence")) - return True - return False - - async def try_get_tag(self, ctx: Context) -> None: - """ - Attempt to display a tag by interpreting the command name as a tag name. - - The invocation of tags get respects its checks. Any CommandErrors raised will be handled - by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to - the context to prevent infinite recursion in the case of a CommandNotFound exception. - """ - tags_get_command = self.bot.get_command("tags get") - ctx.invoked_from_error_handler = True - - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." - try: - if not await tags_get_command.can_run(ctx): - log.debug(log_msg) - return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - - try: - tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) - except errors.BadArgument: - log.debug( - f"{ctx.author} tried to use an invalid command " - f"and the fallback tag failed validation in TagNameConverter." - ) - else: - with contextlib.suppress(ResponseCodeError): - await ctx.invoke(tags_get_command, tag_name=tag_name) - # Return to not raise the exception - return - - async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: - """ - Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. - - * MissingRequiredArgument: send an error message with arg name and the help command - * TooManyArguments: send an error message and the help command - * BadArgument: send an error message and the help command - * BadUnionArgument: send an error message including the error produced by the last converter - * ArgumentParsingError: send an error message - * Other: send an error message and the help command - """ - prepared_help_command = self.get_help_command(ctx) - - if isinstance(e, errors.MissingRequiredArgument): - embed = self._get_error_embed("Missing required argument", e.param.name) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.missing_required_argument") - elif isinstance(e, errors.TooManyArguments): - embed = self._get_error_embed("Too many arguments", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.too_many_arguments") - elif isinstance(e, errors.BadArgument): - embed = self._get_error_embed("Bad argument", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.bad_argument") - elif isinstance(e, errors.BadUnionArgument): - embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") - await ctx.send(embed=embed) - self.bot.stats.incr("errors.bad_union_argument") - elif isinstance(e, errors.ArgumentParsingError): - embed = self._get_error_embed("Argument parsing error", str(e)) - await ctx.send(embed=embed) - self.bot.stats.incr("errors.argument_parsing_error") - else: - embed = self._get_error_embed( - "Input error", - "Something about your input seems off. Check the arguments and try again." - ) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.other_user_input_error") - - @staticmethod - async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: - """ - Send an error message in `ctx` for certain types of CheckFailure. - - The following types are handled: - - * BotMissingPermissions - * BotMissingRole - * BotMissingAnyRole - * NoPrivateMessage - * InWhitelistCheckFailure - """ - bot_missing_errors = ( - errors.BotMissingPermissions, - errors.BotMissingRole, - errors.BotMissingAnyRole - ) - - if isinstance(e, bot_missing_errors): - ctx.bot.stats.incr("errors.bot_permission_error") - await ctx.send( - "Sorry, it looks like I don't have the permissions or roles I need to do that." - ) - elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): - ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") - await ctx.send(e) - - @staticmethod - async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: - """Send an error message in `ctx` for ResponseCodeError and log it.""" - if e.status == 404: - await ctx.send("There does not seem to be anything matching your query.") - log.debug(f"API responded with 404 for command {ctx.command}") - ctx.bot.stats.incr("errors.api_error_404") - elif e.status == 400: - content = await e.response.json() - log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) - await ctx.send("According to the API, your request is malformed.") - ctx.bot.stats.incr("errors.api_error_400") - elif 500 <= e.status < 600: - await ctx.send("Sorry, there seems to be an internal issue with the API.") - log.warning(f"API responded with {e.status} for command {ctx.command}") - ctx.bot.stats.incr("errors.api_internal_server_error") - else: - await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") - log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") - ctx.bot.stats.incr(f"errors.api_error_{e.status}") - - @staticmethod - async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: - """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" - await ctx.send( - f"Sorry, an unexpected error occurred. Please let us know!\n\n" - f"```{e.__class__.__name__}: {e}```" - ) - - ctx.bot.stats.incr("errors.unexpected") - - with push_scope() as scope: - scope.user = { - "id": ctx.author.id, - "username": str(ctx.author) - } - - scope.set_tag("command", ctx.command.qualified_name) - scope.set_tag("message_id", ctx.message.id) - scope.set_tag("channel_id", ctx.channel.id) - - scope.set_extra("full_message", ctx.message.content) - - if ctx.guild is not None: - scope.set_extra( - "jump_to", - f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" - ) - - log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) - - -def setup(bot: Bot) -> None: - """Load the ErrorHandler cog.""" - bot.add_cog(ErrorHandler(bot)) diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py deleted file mode 100644 index eb8bfb1cf..000000000 --- a/bot/cogs/eval.py +++ /dev/null @@ -1,202 +0,0 @@ -import contextlib -import inspect -import logging -import pprint -import re -import textwrap -import traceback -from io import StringIO -from typing import Any, Optional, Tuple - -import discord -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role -from bot.interpreter import Interpreter - -log = logging.getLogger(__name__) - - -class CodeEval(Cog): - """Owner and admin feature that evaluates code and returns the result to the channel.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.env = {} - self.ln = 0 - self.stdout = StringIO() - - self.interpreter = Interpreter(bot) - - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: - """Format the eval output into a string & attempt to format it into an Embed.""" - self._ = out - - res = "" - - # Erase temp input we made - if inp.startswith("_ = "): - inp = inp[4:] - - # Get all non-empty lines - lines = [line for line in inp.split("\n") if line.strip()] - if len(lines) != 1: - lines += [""] - - # Create the input dialog - for i, line in enumerate(lines): - if i == 0: - # Start dialog - start = f"In [{self.ln}]: " - - else: - # Indent the 3 dots correctly; - # Normally, it's something like - # In [X]: - # ...: - # - # But if it's - # In [XX]: - # ...: - # - # You can see it doesn't look right. - # This code simply indents the dots - # far enough to align them. - # we first `str()` the line number - # then we get the length - # and use `str.rjust()` - # to indent it. - start = "...: ".rjust(len(str(self.ln)) + 7) - - if i == len(lines) - 2: - if line.startswith("return"): - line = line[6:].strip() - - # Combine everything - res += (start + line + "\n") - - self.stdout.seek(0) - text = self.stdout.read() - self.stdout.close() - self.stdout = StringIO() - - if text: - res += (text + "\n") - - if out is None: - # No output, return the input statement - return (res, None) - - res += f"Out[{self.ln}]: " - - if isinstance(out, discord.Embed): - # We made an embed? Send that as embed - res += "" - res = (res, out) - - else: - if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): - # Leave out the traceback message - out = "\n" + "\n".join(out.split("\n")[1:]) - - if isinstance(out, str): - pretty = out - else: - pretty = pprint.pformat(out, compact=True, width=60) - - if pretty != str(out): - # We're using the pretty version, start on the next line - res += "\n" - - if pretty.count("\n") > 20: - # Text too long, shorten - li = pretty.split("\n") - - pretty = ("\n".join(li[:3]) # First 3 lines - + "\n ...\n" # Ellipsis to indicate removed lines - + "\n".join(li[-3:])) # last 3 lines - - # Add the output - res += pretty - res = (res, None) - - return res # Return (text, embed) - - async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: - """Eval the input code string & send an embed to the invoking context.""" - self.ln += 1 - - if code.startswith("exit"): - self.ln = 0 - self.env = {} - return await ctx.send("```Reset history!```") - - env = { - "message": ctx.message, - "author": ctx.message.author, - "channel": ctx.channel, - "guild": ctx.guild, - "ctx": ctx, - "self": self, - "bot": self.bot, - "inspect": inspect, - "discord": discord, - "contextlib": contextlib - } - - self.env.update(env) - - # Ignore this code, it works - code_ = """ -async def func(): # (None,) -> Any - try: - with contextlib.redirect_stdout(self.stdout): -{0} - if '_' in locals(): - if inspect.isawaitable(_): - _ = await _ - return _ - finally: - self.env.update(locals()) -""".format(textwrap.indent(code, ' ')) - - try: - exec(code_, self.env) # noqa: B102,S102 - func = self.env['func'] - res = await func() - - except Exception: - res = traceback.format_exc() - - out, embed = self._format(code, res) - await ctx.send(f"```py\n{out}```", embed=embed) - - @group(name='internal', aliases=('int',)) - @with_role(Roles.owners, Roles.admins) - async def internal_group(self, ctx: Context) -> None: - """Internal commands. Top secret!""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @internal_group.command(name='eval', aliases=('e',)) - @with_role(Roles.admins, Roles.owners) - async def eval(self, ctx: Context, *, code: str) -> None: - """Run eval in a REPL-like format.""" - code = code.strip("`") - if re.match('py(thon)?\n', code): - code = "\n".join(code.split("\n")[1:]) - - if not re.search( # Check if it's an expression - r"^(return|import|for|while|def|class|" - r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( - code.split("\n")) == 1: - code = "_ = " + code - - await self._eval(ctx, code) - - -def setup(bot: Bot) -> None: - """Load the CodeEval cog.""" - bot.add_cog(CodeEval(bot)) diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py deleted file mode 100644 index 365f198ff..000000000 --- a/bot/cogs/extensions.py +++ /dev/null @@ -1,236 +0,0 @@ -import functools -import logging -import typing as t -from enum import Enum -from pkgutil import iter_modules - -from discord import Colour, Embed -from discord.ext import commands -from discord.ext.commands import Context, group - -from bot.bot import Bot -from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - -UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} -EXTENSIONS = frozenset( - ext.name - for ext in iter_modules(("bot/cogs",), "bot.cogs.") - if ext.name[-1] != "_" -) - - -class Action(Enum): - """Represents an action to perform on an extension.""" - - # Need to be partial otherwise they are considered to be function definitions. - LOAD = functools.partial(Bot.load_extension) - UNLOAD = functools.partial(Bot.unload_extension) - RELOAD = functools.partial(Bot.reload_extension) - - -class Extension(commands.Converter): - """ - Fully qualify the name of an extension and ensure it exists. - - The * and ** values bypass this when used with the reload command. - """ - - async def convert(self, ctx: Context, argument: str) -> str: - """Fully qualify the name of an extension and ensure it exists.""" - # Special values to reload all extensions - if argument == "*" or argument == "**": - return argument - - argument = argument.lower() - - if "." not in argument: - argument = f"bot.cogs.{argument}" - - if argument in EXTENSIONS: - return argument - else: - raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - - -class Extensions(commands.Cog): - """Extension management commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) - async def extensions_group(self, ctx: Context) -> None: - """Load, unload, reload, and list loaded extensions.""" - await ctx.send_help(ctx.command) - - @extensions_group.command(name="load", aliases=("l",)) - async def load_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Load extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "*" in extensions or "**" in extensions: - extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) - - msg = self.batch_manage(Action.LOAD, *extensions) - await ctx.send(msg) - - @extensions_group.command(name="unload", aliases=("ul",)) - async def unload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Unload currently loaded extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) - - if blacklisted: - msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" - else: - if "*" in extensions or "**" in extensions: - extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST - - msg = self.batch_manage(Action.UNLOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="reload", aliases=("r",)) - async def reload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Reload extensions given their fully qualified or unqualified names. - - If an extension fails to be reloaded, it will be rolled-back to the prior working state. - - If '\*' is given as the name, all currently loaded extensions will be reloaded. - If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "**" in extensions: - extensions = EXTENSIONS - elif "*" in extensions: - extensions = set(self.bot.extensions.keys()) | set(extensions) - extensions.remove("*") - - msg = self.batch_manage(Action.RELOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="list", aliases=("all",)) - async def list_command(self, ctx: Context) -> None: - """ - Get a list of all extensions, including their loaded status. - - Grey indicates that the extension is unloaded. - Green indicates that the extension is currently loaded. - """ - embed = Embed() - lines = [] - - embed.colour = Colour.blurple() - embed.set_author( - name="Extensions List", - url=URLs.github_bot_repo, - icon_url=URLs.bot_avatar - ) - - for ext in sorted(list(EXTENSIONS)): - if ext in self.bot.extensions: - status = Emojis.status_online - else: - status = Emojis.status_offline - - ext = ext.rsplit(".", 1)[1] - lines.append(f"{status} {ext}") - - log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") - await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) - - def batch_manage(self, action: Action, *extensions: str) -> str: - """ - Apply an action to multiple extensions and return a message with the results. - - If only one extension is given, it is deferred to `manage()`. - """ - if len(extensions) == 1: - msg, _ = self.manage(action, extensions[0]) - return msg - - verb = action.name.lower() - failures = {} - - for extension in extensions: - _, error = self.manage(action, extension) - if error: - failures[extension] = error - - emoji = ":x:" if failures else ":ok_hand:" - msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." - - if failures: - failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) - msg += f"\nFailures:```{failures}```" - - log.debug(f"Batch {verb}ed extensions.") - - return msg - - def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: - """Apply an action to an extension and return the status message and any error message.""" - verb = action.name.lower() - error_msg = None - - try: - action.value(self.bot, ext) - except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): - if action is Action.RELOAD: - # When reloading, just load the extension if it was not loaded. - return self.manage(Action.LOAD, ext) - - msg = f":x: Extension `{ext}` is already {verb}ed." - log.debug(msg[4:]) - except Exception as e: - if hasattr(e, "original"): - e = e.original - - log.exception(f"Extension '{ext}' failed to {verb}.") - - error_msg = f"{e.__class__.__name__}: {e}" - msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" - else: - msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." - log.debug(msg[10:]) - - return msg, error_msg - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators and core developers to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Handle BadArgument errors locally to prevent the help command from showing.""" - if isinstance(error, commands.BadArgument): - await ctx.send(str(error)) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the Extensions cog.""" - bot.add_cog(Extensions(bot)) diff --git a/bot/cogs/filter_lists.py b/bot/cogs/filter_lists.py deleted file mode 100644 index c15adc461..000000000 --- a/bot/cogs/filter_lists.py +++ /dev/null @@ -1,273 +0,0 @@ -import logging -from typing import Optional - -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.converters import ValidDiscordServerInvite, ValidFilterListType -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -class FilterLists(Cog): - """Commands for blacklisting and whitelisting things.""" - - methods_with_filterlist_types = [ - "allow_add", - "allow_delete", - "allow_get", - "deny_add", - "deny_delete", - "deny_get", - ] - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.bot.loop.create_task(self._amend_docstrings()) - - async def _amend_docstrings(self) -> None: - """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" - await self.bot.wait_until_guild_available() - - # Add valid filterlist types to the docstrings - valid_types = await ValidFilterListType.get_valid_types(self.bot) - valid_types = [f"`{type_.lower()}`" for type_ in valid_types] - - for method_name in self.methods_with_filterlist_types: - command = getattr(self, method_name) - command.help = ( - f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." - ) - - async def _add_data( - self, - ctx: Context, - allowed: bool, - list_type: ValidFilterListType, - content: str, - comment: Optional[str] = None, - ) -> None: - """Add an item to a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we gotta validate it. - if list_type == "GUILD_INVITE": - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # Unless the user has specified another comment, let's - # use the server name as the comment so that the list - # of guild IDs will be more easily readable when we - # display it. - if not comment: - comment = guild_data.get("name") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Try to add the item to the database - log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") - payload = { - "allowed": allowed, - "type": list_type, - "content": content, - "comment": comment, - } - - try: - item = await self.bot.api_client.post( - "bot/filter-lists", - json=payload - ) - except ResponseCodeError as e: - if e.status == 400: - await ctx.message.add_reaction("❌") - log.debug( - f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " - "probably because the request violated the UniqueConstraint." - ) - raise BadArgument( - f"Unable to add the item to the {allow_type}. " - "The item probably already exists. Keep in mind that a " - "blacklist and a whitelist for the same item cannot co-exist, " - "and we do not permit any duplicates." - ) - raise - - # Insert the item into the cache - self.bot.insert_item_into_filter_list_cache(item) - await ctx.message.add_reaction("✅") - - async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we need to convert it. - if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Find the content and delete it. - log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) - - if item is not None: - try: - await self.bot.api_client.delete( - f"bot/filter-lists/{item['id']}" - ) - del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to delete an item with the id {item['id']}, but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("❌") - - async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: - """Paginate and display all items in a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] - - # Build a list of lines we want to show in the paginator - lines = [] - for content, metadata in result.items(): - line = f"• `{content}`" - - if comment := metadata.get("comment"): - line += f" - {comment}" - - lines.append(line) - lines = sorted(lines) - - # Build the embed - list_type_plural = list_type.lower().replace("_", " ").title() + "s" - embed = Embed( - title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", - colour=Colour.blue() - ) - log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - - if result: - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - await ctx.message.add_reaction("❌") - - async def _sync_data(self, ctx: Context) -> None: - """Syncs the filterlists with the API.""" - try: - log.trace("Attempting to sync FilterList cache with data from the API.") - await self.bot.cache_filter_list_data() - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to sync FilterList cache data but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - - @staticmethod - async def _validate_guild_invite(ctx: Context, invite: str) -> dict: - """ - Validates a guild invite, and returns the guild info as a dict. - - Will raise a BadArgument if the guild invite is invalid. - """ - log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, invite) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's return a dict of guild information. - log.trace(f"{invite} validated as server invite. Converting to ID.") - return guild_data - - @group(aliases=("allowlist", "allow", "al", "wl")) - async def whitelist(self, ctx: Context) -> None: - """Group for whitelisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @group(aliases=("denylist", "deny", "bl", "dl")) - async def blacklist(self, ctx: Context) -> None: - """Group for blacklisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @whitelist.command(name="add", aliases=("a", "set")) - async def allow_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified allowlist.""" - await self._add_data(ctx, True, list_type, content, comment) - - @blacklist.command(name="add", aliases=("a", "set")) - async def deny_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified denylist.""" - await self._add_data(ctx, False, list_type, content, comment) - - @whitelist.command(name="remove", aliases=("delete", "rm",)) - async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified allowlist.""" - await self._delete_data(ctx, True, list_type, content) - - @blacklist.command(name="remove", aliases=("delete", "rm",)) - async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified denylist.""" - await self._delete_data(ctx, False, list_type, content) - - @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified allowlist.""" - await self._list_all_data(ctx, True, list_type) - - @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified denylist.""" - await self._list_all_data(ctx, False, list_type) - - @whitelist.command(name="sync", aliases=("s",)) - async def allow_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - @blacklist.command(name="sync", aliases=("s",)) - async def deny_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the FilterLists cog.""" - bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py deleted file mode 100644 index 93cc1c655..000000000 --- a/bot/cogs/filtering.py +++ /dev/null @@ -1,575 +0,0 @@ -import asyncio -import logging -import re -from datetime import datetime, timedelta -from typing import List, Mapping, Optional, Tuple, Union - -import dateutil -import discord.errors -from dateutil.relativedelta import relativedelta -from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel -from discord.ext.commands import Cog -from discord.utils import escape_markdown - -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import ( - Channels, Colours, - Filter, Icons, URLs -) -from bot.utils.redis_cache import RedisCache -from bot.utils.regex import INVITE_RE -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - -# Regular expressions -SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) -URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) -ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") - -# Other constants. -DAYS_BETWEEN_ALERTS = 3 -OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) - - -class Filtering(Cog): - """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" - - # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent - name_alerts = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - self.name_lock = asyncio.Lock() - - staff_mistake_str = "If you believe this was a mistake, please let staff know!" - self.filters = { - "filter_zalgo": { - "enabled": Filter.filter_zalgo, - "function": self._has_zalgo, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_zalgo, - "notification_msg": ( - "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " - f"{staff_mistake_str}" - ), - "schedule_deletion": False - }, - "filter_invites": { - "enabled": Filter.filter_invites, - "function": self._has_invites, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_invites, - "notification_msg": ( - f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" - r"Our server rules can be found here: " - ), - "schedule_deletion": False - }, - "filter_domains": { - "enabled": Filter.filter_domains, - "function": self._has_urls, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_domains, - "notification_msg": ( - f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" - ), - "schedule_deletion": False - }, - "watch_regex": { - "enabled": Filter.watch_regex, - "function": self._has_watch_regex_match, - "type": "watchlist", - "content_only": True, - "schedule_deletion": True - }, - "watch_rich_embeds": { - "enabled": Filter.watch_rich_embeds, - "function": self._has_rich_embed, - "type": "watchlist", - "content_only": False, - "schedule_deletion": False - } - } - - self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: - """Fetch items from the filter_list_cache.""" - return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() - - @staticmethod - def _expand_spoilers(text: str) -> str: - """Return a string containing all interpretations of a spoilered message.""" - split_text = SPOILER_RE.split(text) - return ''.join( - split_text[0::2] + split_text[1::2] + split_text - ) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Invoke message filter for new messages.""" - await self._filter_message(msg) - - # Ignore webhook messages. - if msg.webhook_id is None: - await self.check_bad_words_in_name(msg.author) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Invoke message filter for message edits. - - If there have been multiple edits, calculate the time delta from the previous edit. - """ - if not before.edited_at: - delta = relativedelta(after.edited_at, before.created_at).microseconds - else: - delta = relativedelta(after.edited_at, before.edited_at).microseconds - await self._filter_message(after, delta) - - def get_name_matches(self, name: str) -> List[re.Match]: - """Check bad words from passed string (name). Return list of matches.""" - matches = [] - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - if match := re.search(pattern, name, flags=re.IGNORECASE): - matches.append(match) - return matches - - async def check_send_alert(self, member: Member) -> bool: - """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" - if last_alert := await self.name_alerts.get(member.id): - last_alert = datetime.utcfromtimestamp(last_alert) - if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: - log.trace(f"Last alert was too recent for {member}'s nickname.") - return False - - return True - - async def check_bad_words_in_name(self, member: Member) -> None: - """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" - # Use lock to avoid race conditions - async with self.name_lock: - # Check whether the users display name contains any words in our blacklist - matches = self.get_name_matches(member.display_name) - - if not matches or not await self.check_send_alert(member): - return - - log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") - - log_string = ( - f"**User:** {member.mention} (`{member.id}`)\n" - f"**Display Name:** {member.display_name}\n" - f"**Bad Matches:** {', '.join(match.group() for match in matches)}" - ) - - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colours.soft_red, - title="Username filtering alert", - text=log_string, - channel_id=Channels.mod_alerts, - thumbnail=member.avatar_url - ) - - # Update time when alert sent - await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) - - async def filter_eval(self, result: str, msg: Message) -> bool: - """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. - - Also requires the original message, to check whether to filter and for mod logs. - Returns whether a filter was triggered or not. - """ - filter_triggered = False - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - # We also do not need to worry about filters that take the full message, - # since all we have is an arbitrary string. - if _filter["enabled"] and _filter["content_only"]: - match = await _filter["function"](result) - - if match: - # If this is a filter (not a watchlist), we set the variable so we know - # that it has been triggered - if _filter["type"] == "filter": - filter_triggered = True - - # We do not have to check against DM channels since !eval cannot be used there. - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, result - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} using !eval with " - f"[the following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - return filter_triggered - - async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: - """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - if _filter["enabled"]: - # Double trigger check for the embeds filter - if filter_name == "watch_rich_embeds": - # If the edit delta is less than 0.001 seconds, then we're probably dealing - # with a double filter trigger. - if delta is not None and delta < 100: - continue - - # Does the filter only need the message content or the full message? - if _filter["content_only"]: - match = await _filter["function"](msg.content) - else: - match = await _filter["function"](msg) - - if match: - is_private = msg.channel.type is discord.ChannelType.private - - # If this is a filter (not a watchlist) and not in a DM, delete the message. - if _filter["type"] == "filter" and not is_private: - try: - # Embeds (can?) trigger both the `on_message` and `on_message_edit` - # event handlers, triggering filtering twice for the same message. - # - # If `on_message`-triggered filtering already deleted the message - # then `on_message_edit`-triggered filtering will raise exception - # since the message no longer exists. - # - # In addition, to avoid sending two notifications to the user, the - # logs, and mod_alert, we return if the message no longer exists. - await msg.delete() - except discord.errors.NotFound: - return - - # Notify the user if the filter specifies - if _filter["user_notification"]: - await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) - - # If the message is classed as offensive, we store it in the site db and - # it will be deleted it after one week. - if _filter["schedule_deletion"] and not is_private: - delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() - data = { - 'id': msg.id, - 'channel_id': msg.channel.id, - 'delete_date': delete_date - } - - await self.bot.api_client.post('bot/offensive-messages', json=data) - self.schedule_msg_delete(data) - log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") - - if is_private: - channel_str = "via DM" - else: - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, msg.content - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} with [the " - f"following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone if not is_private else False, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ - str, Optional[List[discord.Embed]], Optional[str] - ]: - """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" - # Word and match stats for watch_regex - if name == "watch_regex": - surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] - message_content = ( - f"**Match:** '{match[0]}'\n" - f"**Location:** '...{escape_markdown(surroundings)}...'\n" - f"\n**Original Message:**\n{escape_markdown(content)}" - ) - else: # Use original content - message_content = content - - additional_embeds = None - additional_embeds_msg = None - - self.bot.stats.incr(f"filters.{name}") - - # The function returns True for invalid invites. - # They have no data so additional embeds can't be created for them. - if name == "filter_invites" and match is not True: - additional_embeds = [] - for _, data in match.items(): - embed = discord.Embed(description=( - f"**Members:**\n{data['members']}\n" - f"**Active:**\n{data['active']}" - )) - embed.set_author(name=data["name"]) - embed.set_thumbnail(url=data["icon"]) - embed.set_footer(text=f"Guild ID: {data['id']}") - additional_embeds.append(embed) - additional_embeds_msg = "For the following guild(s):" - - elif name == "watch_rich_embeds": - additional_embeds = match - additional_embeds_msg = "With the following embed(s):" - - return message_content, additional_embeds, additional_embeds_msg - - @staticmethod - def _check_filter(msg: Message) -> bool: - """Check whitelists to see if we should filter this message.""" - role_whitelisted = False - - if type(msg.author) is Member: # Only Member has roles, not User. - for role in msg.author.roles: - if role.id in Filter.role_whitelist: - role_whitelisted = True - - return ( - msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist - and not role_whitelisted # Role not in whitelist - and not msg.author.bot # Author not a bot - ) - - async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: - """ - Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. - - `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is - matched as-is. Spoilers are expanded, if any, and URLs are ignored. - """ - if SPOILER_RE.search(text): - text = self._expand_spoilers(text) - - # Make sure it's not a URL - if URL_RE.search(text): - return False - - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - match = re.search(pattern, text, flags=re.IGNORECASE) - if match: - return match - - async def _has_urls(self, text: str) -> bool: - """Returns True if the text contains one of the blacklisted URLs from the config file.""" - if not URL_RE.search(text): - return False - - text = text.lower() - domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) - - for url in domain_blacklist: - if url.lower() in text: - return True - - return False - - @staticmethod - async def _has_zalgo(text: str) -> bool: - """ - Returns True if the text contains zalgo characters. - - Zalgo range is \u0300 – \u036F and \u0489. - """ - return bool(ZALGO_RE.search(text)) - - async def _has_invites(self, text: str) -> Union[dict, bool]: - """ - Checks if there's any invites in the text content that aren't in the guild whitelist. - - If any are detected, a dictionary of invite data is returned, with a key per invite. - If none are detected, False is returned. - - Attempts to catch some of common ways to try to cheat the system. - """ - # Remove backslashes to prevent escape character aroundfuckery like - # discord\.gg/gdudes-pony-farm - text = text.replace("\\", "") - - invites = INVITE_RE.findall(text) - invite_data = dict() - for invite in invites: - if invite in invite_data: - continue - - response = await self.bot.http_session.get( - f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} - ) - response = await response.json() - guild = response.get("guild") - if guild is None: - # Lack of a "guild" key in the JSON response indicates either an group DM invite, an - # expired invite, or an invalid invite. The API does not currently differentiate - # between invalid and expired invites - return True - - guild_id = guild.get("id") - guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) - guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) - - # Is this invite allowed? - guild_partnered_or_verified = ( - 'PARTNERED' in guild.get("features", []) - or 'VERIFIED' in guild.get("features", []) - ) - invite_not_allowed = ( - guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. - or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. - and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. - ) - - if invite_not_allowed: - guild_icon_hash = guild["icon"] - guild_icon = ( - "https://cdn.discordapp.com/icons/" - f"{guild_id}/{guild_icon_hash}.png?size=512" - ) - - invite_data[invite] = { - "name": guild["name"], - "id": guild['id'], - "icon": guild_icon, - "members": response["approximate_member_count"], - "active": response["approximate_presence_count"] - } - - return invite_data if invite_data else False - - @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: - """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" - if msg.embeds: - for embed in msg.embeds: - if embed.type == "rich": - urls = URL_RE.findall(msg.content) - if not embed.url or embed.url not in urls: - # If `embed.url` does not exist or if `embed.url` is not part of the content - # of the message, it's unlikely to be an auto-generated embed by Discord. - return msg.embeds - else: - log.trace( - "Found a rich embed sent by a regular user account, " - "but it was likely just an automatic URL embed." - ) - return False - return False - - async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: - """ - Notify filtered_member about a moderation action with the reason str. - - First attempts to DM the user, fall back to in-channel notification if user has DMs disabled - """ - try: - await filtered_member.send(reason) - except discord.errors.Forbidden: - await channel.send(f"{filtered_member.mention} {reason}") - - def schedule_msg_delete(self, msg: dict) -> None: - """Delete an offensive message once its deletion date is reached.""" - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) - - async def reschedule_offensive_msg_deletion(self) -> None: - """Get all the pending message deletion from the API and reschedule them.""" - await self.bot.wait_until_ready() - response = await self.bot.api_client.get('bot/offensive-messages',) - - now = datetime.utcnow() - - for msg in response: - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - - if delete_at < now: - await self.delete_offensive_msg(msg) - else: - self.schedule_msg_delete(msg) - - async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: - """Delete an offensive message, and then delete it from the db.""" - try: - channel = self.bot.get_channel(msg['channel_id']) - if channel: - msg_obj = await channel.fetch_message(msg['id']) - await msg_obj.delete() - except NotFound: - log.info( - f"Tried to delete message {msg['id']}, but the message can't be found " - f"(it has been probably already deleted)." - ) - except HTTPException as e: - log.warning(f"Failed to delete message {msg['id']}: status {e.status}") - - await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') - log.info(f"Deleted the offensive message with id {msg['id']}.") - - -def setup(bot: Bot) -> None: - """Load the Filtering cog.""" - bot.add_cog(Filtering(bot)) diff --git a/bot/cogs/filters/__init__.py b/bot/cogs/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/filters/antimalware.py b/bot/cogs/filters/antimalware.py new file mode 100644 index 000000000..c76bd2c60 --- /dev/null +++ b/bot/cogs/filters/antimalware.py @@ -0,0 +1,98 @@ +import logging +import typing as t +from os.path import splitext + +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, STAFF_ROLES, URLs + +log = logging.getLogger(__name__) + +PY_EMBED_DESCRIPTION = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + "{cmd_channel_mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( + "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + "We currently allow the following file types: **{joined_whitelist}**.\n\n" + "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + + +class AntiMalware(Cog): + """Delete messages which contain attachments with non-whitelisted file extensions.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_whitelisted_file_formats(self) -> list: + """Get the file formats currently on the whitelist.""" + return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() + + def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) + return extensions_blocked + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Identify messages with prohibited attachments.""" + # Return when message don't have attachment and don't moderate DMs + if not message.attachments or not message.guild: + return + + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): + return + + embed = Embed() + extensions_blocked = self._get_disallowed_extensions(message) + blocked_extensions_str = ', '.join(extensions_blocked) + if ".py" in extensions_blocked: + # Short-circuit on *.py files to provide a pastebin link + embed.description = PY_EMBED_DESCRIPTION + elif ".txt" in extensions_blocked: + # Work around Discord AutoConversion of messages longer than 2000 chars to .txt + cmd_channel = self.bot.get_channel(Channels.bot_commands) + embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) + elif extensions_blocked: + meta_channel = self.bot.get_channel(Channels.meta) + embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=', '.join(self._get_whitelisted_file_formats()), + blocked_extensions_str=blocked_extensions_str, + meta_channel_mention=meta_channel.mention, + ) + + if embed.description: + log.info( + f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", + extra={"attachment_list": [attachment.filename for attachment in message.attachments]} + ) + + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) + + # Delete the offending message: + try: + await message.delete() + except NotFound: + log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + + +def setup(bot: Bot) -> None: + """Load the AntiMalware cog.""" + bot.add_cog(AntiMalware(bot)) diff --git a/bot/cogs/filters/antispam.py b/bot/cogs/filters/antispam.py new file mode 100644 index 000000000..0bcca578d --- /dev/null +++ b/bot/cogs/filters/antispam.py @@ -0,0 +1,288 @@ +import asyncio +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from operator import itemgetter +from typing import Dict, Iterable, List, Set + +from discord import Colour, Member, Message, NotFound, Object, TextChannel +from discord.ext.commands import Cog + +from bot import rules +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import ( + AntiSpam as AntiSpamConfig, Channels, + Colours, DEBUG_MODE, Event, Filter, + Guild as GuildConfig, Icons, + STAFF_ROLES, +) +from bot.converters import Duration +from bot.utils.messages import send_attachments + + +log = logging.getLogger(__name__) + +RULE_FUNCTION_MAPPING = { + 'attachments': rules.apply_attachments, + 'burst': rules.apply_burst, + 'burst_shared': rules.apply_burst_shared, + 'chars': rules.apply_chars, + 'discord_emojis': rules.apply_discord_emojis, + 'duplicates': rules.apply_duplicates, + 'links': rules.apply_links, + 'mentions': rules.apply_mentions, + 'newlines': rules.apply_newlines, + 'role_mentions': rules.apply_role_mentions +} + + +@dataclass +class DeletionContext: + """Represents a Deletion Context for a single spam event.""" + + channel: TextChannel + members: Dict[int, Member] = field(default_factory=dict) + rules: Set[str] = field(default_factory=set) + messages: Dict[int, Message] = field(default_factory=dict) + attachments: List[List[str]] = field(default_factory=list) + + async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: + """Adds new rule violation events to the deletion context.""" + self.rules.add(rule_name) + + for member in members: + if member.id not in self.members: + self.members[member.id] = member + + for message in messages: + if message.id not in self.messages: + self.messages[message.id] = message + + # Re-upload attachments + destination = message.guild.get_channel(Channels.attachment_log) + urls = await send_attachments(message, destination, link_large=False) + self.attachments.append(urls) + + async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: + """Method that takes care of uploading the queue and posting modlog alert.""" + triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) + + mod_alert_message = ( + f"**Triggered by:** {triggered_by_users}\n" + f"**Channel:** {self.channel.mention}\n" + f"**Rules:** {', '.join(rule for rule in self.rules)}\n" + ) + + # For multiple messages or those with excessive newlines, use the logs API + if len(self.messages) > 1 or 'newlines' in self.rules: + url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) + mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" + else: + mod_alert_message += "Message:\n" + [message] = self.messages.values() + content = message.clean_content + remaining_chars = 2040 - len(mod_alert_message) + + if len(content) > remaining_chars: + content = content[:remaining_chars] + "..." + + mod_alert_message += f"{content}" + + *_, last_message = self.messages.values() + await modlog.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title="Spam detected!", + text=mod_alert_message, + thumbnail=last_message.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=AntiSpamConfig.ping_everyone + ) + + +class AntiSpam(Cog): + """Cog that controls our anti-spam measures.""" + + def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: + self.bot = bot + self.validation_errors = validation_errors + role_id = AntiSpamConfig.punishment['role_id'] + self.muted_role = Object(role_id) + self.expiration_date_converter = Duration() + + self.message_deletion_queue = dict() + + self.bot.loop.create_task(self.alert_on_validation_error()) + + @property + def mod_log(self) -> ModLog: + """Allows for easy access of the ModLog cog.""" + return self.bot.get_cog("ModLog") + + async def alert_on_validation_error(self) -> None: + """Unloads the cog and alerts admins if configuration validation failed.""" + await self.bot.wait_until_guild_available() + if self.validation_errors: + body = "**The following errors were encountered:**\n" + body += "\n".join(f"- {error}" for error in self.validation_errors.values()) + body += "\n\n**The cog has been unloaded.**" + + await self.mod_log.send_log_message( + title="Error: AntiSpam configuration validation failed!", + text=body, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Colour.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Applies the antispam rules to each received message.""" + if ( + not message.guild + or message.guild.id != GuildConfig.id + or message.author.bot + or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) + or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) + ): + return + + # Fetch the rule configuration with the highest rule interval. + max_interval_config = max( + AntiSpamConfig.rules.values(), + key=itemgetter('interval') + ) + max_interval = max_interval_config['interval'] + + # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. + earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) + relevant_messages = [ + msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) + if not msg.author.bot + ] + + for rule_name in AntiSpamConfig.rules: + rule_config = AntiSpamConfig.rules[rule_name] + rule_function = RULE_FUNCTION_MAPPING[rule_name] + + # Create a list of messages that were sent in the interval that the rule cares about. + latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) + messages_for_rule = [ + msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp + ] + result = await rule_function(message, messages_for_rule, rule_config) + + # If the rule returns `None`, that means the message didn't violate it. + # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` + # which contains the reason for why the message violated the rule and + # an iterable of all members that violated the rule. + if result is not None: + self.bot.stats.incr(f"mod_alerts.{rule_name}") + reason, members, relevant_messages = result + full_reason = f"`{rule_name}` rule: {reason}" + + # If there's no spam event going on for this channel, start a new Message Deletion Context + channel = message.channel + if channel.id not in self.message_deletion_queue: + log.trace(f"Creating queue for channel `{channel.id}`") + self.message_deletion_queue[message.channel.id] = DeletionContext(channel) + self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) + + # Add the relevant of this trigger to the Deletion Context + await self.message_deletion_queue[message.channel.id].add( + rule_name=rule_name, + members=members, + messages=relevant_messages + ) + + for member in members: + + # Fire it off as a background task to ensure + # that the sleep doesn't block further tasks + self.bot.loop.create_task( + self.punish(message, member, full_reason) + ) + + await self.maybe_delete_messages(channel, relevant_messages) + break + + async def punish(self, msg: Message, member: Member, reason: str) -> None: + """Punishes the given member for triggering an antispam rule.""" + if not any(role.id == self.muted_role.id for role in member.roles): + remove_role_after = AntiSpamConfig.punishment['remove_after'] + + # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes + context = await self.bot.get_context(msg) + context.author = self.bot.user + context.message.author = self.bot.user + + # Since we're going to invoke the tempmute command directly, we need to manually call the converter. + dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") + await context.invoke( + self.bot.get_command('tempmute'), + member, + dt_remove_role_after, + reason=reason + ) + + async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: + """Cleans the messages if cleaning is configured.""" + if AntiSpamConfig.clean_offending: + # If we have more than one message, we can use bulk delete. + if len(messages) > 1: + message_ids = [message.id for message in messages] + self.mod_log.ignore(Event.message_delete, *message_ids) + await channel.delete_messages(messages) + + # Otherwise, the bulk delete endpoint will throw up. + # Delete the message directly instead. + else: + self.mod_log.ignore(Event.message_delete, messages[0].id) + try: + await messages[0].delete() + except NotFound: + log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") + + async def _process_deletion_context(self, context_id: int) -> None: + """Processes the Deletion Context queue.""" + log.trace("Sleeping before processing message deletion queue.") + await asyncio.sleep(10) + + if context_id not in self.message_deletion_queue: + log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") + return + + deletion_context = self.message_deletion_queue.pop(context_id) + await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + + +def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: + """Validates the antispam configs.""" + validation_errors = {} + for name, config in rules_.items(): + if name not in RULE_FUNCTION_MAPPING: + log.error( + f"Unrecognized antispam rule `{name}`. " + f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" + ) + validation_errors[name] = f"`{name}` is not recognized as an antispam rule." + continue + for required_key in ('interval', 'max'): + if required_key not in config: + log.error( + f"`{required_key}` is required but was not " + f"set in rule `{name}`'s configuration." + ) + validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" + return validation_errors + + +def setup(bot: Bot) -> None: + """Validate the AntiSpam configs and load the AntiSpam cog.""" + validation_errors = validate_config() + bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/cogs/filters/filter_lists.py b/bot/cogs/filters/filter_lists.py new file mode 100644 index 000000000..c15adc461 --- /dev/null +++ b/bot/cogs/filters/filter_lists.py @@ -0,0 +1,273 @@ +import logging +from typing import Optional + +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.converters import ValidDiscordServerInvite, ValidFilterListType +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +class FilterLists(Cog): + """Commands for blacklisting and whitelisting things.""" + + methods_with_filterlist_types = [ + "allow_add", + "allow_delete", + "allow_get", + "deny_add", + "deny_delete", + "deny_get", + ] + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.bot.loop.create_task(self._amend_docstrings()) + + async def _amend_docstrings(self) -> None: + """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" + await self.bot.wait_until_guild_available() + + # Add valid filterlist types to the docstrings + valid_types = await ValidFilterListType.get_valid_types(self.bot) + valid_types = [f"`{type_.lower()}`" for type_ in valid_types] + + for method_name in self.methods_with_filterlist_types: + command = getattr(self, method_name) + command.help = ( + f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." + ) + + async def _add_data( + self, + ctx: Context, + allowed: bool, + list_type: ValidFilterListType, + content: str, + comment: Optional[str] = None, + ) -> None: + """Add an item to a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we gotta validate it. + if list_type == "GUILD_INVITE": + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # Unless the user has specified another comment, let's + # use the server name as the comment so that the list + # of guild IDs will be more easily readable when we + # display it. + if not comment: + comment = guild_data.get("name") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Try to add the item to the database + log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") + payload = { + "allowed": allowed, + "type": list_type, + "content": content, + "comment": comment, + } + + try: + item = await self.bot.api_client.post( + "bot/filter-lists", + json=payload + ) + except ResponseCodeError as e: + if e.status == 400: + await ctx.message.add_reaction("❌") + log.debug( + f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " + "probably because the request violated the UniqueConstraint." + ) + raise BadArgument( + f"Unable to add the item to the {allow_type}. " + "The item probably already exists. Keep in mind that a " + "blacklist and a whitelist for the same item cannot co-exist, " + "and we do not permit any duplicates." + ) + raise + + # Insert the item into the cache + self.bot.insert_item_into_filter_list_cache(item) + await ctx.message.add_reaction("✅") + + async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we need to convert it. + if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Find the content and delete it. + log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") + item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) + + if item is not None: + try: + await self.bot.api_client.delete( + f"bot/filter-lists/{item['id']}" + ) + del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to delete an item with the id {item['id']}, but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("❌") + + async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: + """Paginate and display all items in a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] + + # Build a list of lines we want to show in the paginator + lines = [] + for content, metadata in result.items(): + line = f"• `{content}`" + + if comment := metadata.get("comment"): + line += f" - {comment}" + + lines.append(line) + lines = sorted(lines) + + # Build the embed + list_type_plural = list_type.lower().replace("_", " ").title() + "s" + embed = Embed( + title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", + colour=Colour.blue() + ) + log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") + + if result: + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + await ctx.message.add_reaction("❌") + + async def _sync_data(self, ctx: Context) -> None: + """Syncs the filterlists with the API.""" + try: + log.trace("Attempting to sync FilterList cache with data from the API.") + await self.bot.cache_filter_list_data() + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to sync FilterList cache data but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + + @staticmethod + async def _validate_guild_invite(ctx: Context, invite: str) -> dict: + """ + Validates a guild invite, and returns the guild info as a dict. + + Will raise a BadArgument if the guild invite is invalid. + """ + log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, invite) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's return a dict of guild information. + log.trace(f"{invite} validated as server invite. Converting to ID.") + return guild_data + + @group(aliases=("allowlist", "allow", "al", "wl")) + async def whitelist(self, ctx: Context) -> None: + """Group for whitelisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @group(aliases=("denylist", "deny", "bl", "dl")) + async def blacklist(self, ctx: Context) -> None: + """Group for blacklisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @whitelist.command(name="add", aliases=("a", "set")) + async def allow_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified allowlist.""" + await self._add_data(ctx, True, list_type, content, comment) + + @blacklist.command(name="add", aliases=("a", "set")) + async def deny_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified denylist.""" + await self._add_data(ctx, False, list_type, content, comment) + + @whitelist.command(name="remove", aliases=("delete", "rm",)) + async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified allowlist.""" + await self._delete_data(ctx, True, list_type, content) + + @blacklist.command(name="remove", aliases=("delete", "rm",)) + async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified denylist.""" + await self._delete_data(ctx, False, list_type, content) + + @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified allowlist.""" + await self._list_all_data(ctx, True, list_type) + + @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified denylist.""" + await self._list_all_data(ctx, False, list_type) + + @whitelist.command(name="sync", aliases=("s",)) + async def allow_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + @blacklist.command(name="sync", aliases=("s",)) + async def deny_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the FilterLists cog.""" + bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filters/filtering.py b/bot/cogs/filters/filtering.py new file mode 100644 index 000000000..93cc1c655 --- /dev/null +++ b/bot/cogs/filters/filtering.py @@ -0,0 +1,575 @@ +import asyncio +import logging +import re +from datetime import datetime, timedelta +from typing import List, Mapping, Optional, Tuple, Union + +import dateutil +import discord.errors +from dateutil.relativedelta import relativedelta +from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel +from discord.ext.commands import Cog +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import ( + Channels, Colours, + Filter, Icons, URLs +) +from bot.utils.redis_cache import RedisCache +from bot.utils.regex import INVITE_RE +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + +# Regular expressions +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) +URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) +ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") + +# Other constants. +DAYS_BETWEEN_ALERTS = 3 +OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) + + +class Filtering(Cog): + """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" + + # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent + name_alerts = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.name_lock = asyncio.Lock() + + staff_mistake_str = "If you believe this was a mistake, please let staff know!" + self.filters = { + "filter_zalgo": { + "enabled": Filter.filter_zalgo, + "function": self._has_zalgo, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_zalgo, + "notification_msg": ( + "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " + f"{staff_mistake_str}" + ), + "schedule_deletion": False + }, + "filter_invites": { + "enabled": Filter.filter_invites, + "function": self._has_invites, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_invites, + "notification_msg": ( + f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" + r"Our server rules can be found here: " + ), + "schedule_deletion": False + }, + "filter_domains": { + "enabled": Filter.filter_domains, + "function": self._has_urls, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_domains, + "notification_msg": ( + f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" + ), + "schedule_deletion": False + }, + "watch_regex": { + "enabled": Filter.watch_regex, + "function": self._has_watch_regex_match, + "type": "watchlist", + "content_only": True, + "schedule_deletion": True + }, + "watch_rich_embeds": { + "enabled": Filter.watch_rich_embeds, + "function": self._has_rich_embed, + "type": "watchlist", + "content_only": False, + "schedule_deletion": False + } + } + + self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: + """Fetch items from the filter_list_cache.""" + return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() + + @staticmethod + def _expand_spoilers(text: str) -> str: + """Return a string containing all interpretations of a spoilered message.""" + split_text = SPOILER_RE.split(text) + return ''.join( + split_text[0::2] + split_text[1::2] + split_text + ) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Invoke message filter for new messages.""" + await self._filter_message(msg) + + # Ignore webhook messages. + if msg.webhook_id is None: + await self.check_bad_words_in_name(msg.author) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Invoke message filter for message edits. + + If there have been multiple edits, calculate the time delta from the previous edit. + """ + if not before.edited_at: + delta = relativedelta(after.edited_at, before.created_at).microseconds + else: + delta = relativedelta(after.edited_at, before.edited_at).microseconds + await self._filter_message(after, delta) + + def get_name_matches(self, name: str) -> List[re.Match]: + """Check bad words from passed string (name). Return list of matches.""" + matches = [] + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + if match := re.search(pattern, name, flags=re.IGNORECASE): + matches.append(match) + return matches + + async def check_send_alert(self, member: Member) -> bool: + """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" + if last_alert := await self.name_alerts.get(member.id): + last_alert = datetime.utcfromtimestamp(last_alert) + if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: + log.trace(f"Last alert was too recent for {member}'s nickname.") + return False + + return True + + async def check_bad_words_in_name(self, member: Member) -> None: + """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" + # Use lock to avoid race conditions + async with self.name_lock: + # Check whether the users display name contains any words in our blacklist + matches = self.get_name_matches(member.display_name) + + if not matches or not await self.check_send_alert(member): + return + + log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") + + log_string = ( + f"**User:** {member.mention} (`{member.id}`)\n" + f"**Display Name:** {member.display_name}\n" + f"**Bad Matches:** {', '.join(match.group() for match in matches)}" + ) + + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colours.soft_red, + title="Username filtering alert", + text=log_string, + channel_id=Channels.mod_alerts, + thumbnail=member.avatar_url + ) + + # Update time when alert sent + await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) + + async def filter_eval(self, result: str, msg: Message) -> bool: + """ + Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + + Also requires the original message, to check whether to filter and for mod logs. + Returns whether a filter was triggered or not. + """ + filter_triggered = False + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + # We also do not need to worry about filters that take the full message, + # since all we have is an arbitrary string. + if _filter["enabled"] and _filter["content_only"]: + match = await _filter["function"](result) + + if match: + # If this is a filter (not a watchlist), we set the variable so we know + # that it has been triggered + if _filter["type"] == "filter": + filter_triggered = True + + # We do not have to check against DM channels since !eval cannot be used there. + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, result + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} using !eval with " + f"[the following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + return filter_triggered + + async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: + """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + if _filter["enabled"]: + # Double trigger check for the embeds filter + if filter_name == "watch_rich_embeds": + # If the edit delta is less than 0.001 seconds, then we're probably dealing + # with a double filter trigger. + if delta is not None and delta < 100: + continue + + # Does the filter only need the message content or the full message? + if _filter["content_only"]: + match = await _filter["function"](msg.content) + else: + match = await _filter["function"](msg) + + if match: + is_private = msg.channel.type is discord.ChannelType.private + + # If this is a filter (not a watchlist) and not in a DM, delete the message. + if _filter["type"] == "filter" and not is_private: + try: + # Embeds (can?) trigger both the `on_message` and `on_message_edit` + # event handlers, triggering filtering twice for the same message. + # + # If `on_message`-triggered filtering already deleted the message + # then `on_message_edit`-triggered filtering will raise exception + # since the message no longer exists. + # + # In addition, to avoid sending two notifications to the user, the + # logs, and mod_alert, we return if the message no longer exists. + await msg.delete() + except discord.errors.NotFound: + return + + # Notify the user if the filter specifies + if _filter["user_notification"]: + await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) + + # If the message is classed as offensive, we store it in the site db and + # it will be deleted it after one week. + if _filter["schedule_deletion"] and not is_private: + delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() + data = { + 'id': msg.id, + 'channel_id': msg.channel.id, + 'delete_date': delete_date + } + + await self.bot.api_client.post('bot/offensive-messages', json=data) + self.schedule_msg_delete(data) + log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") + + if is_private: + channel_str = "via DM" + else: + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, msg.content + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} with [the " + f"following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone if not is_private else False, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ + str, Optional[List[discord.Embed]], Optional[str] + ]: + """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" + # Word and match stats for watch_regex + if name == "watch_regex": + surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] + message_content = ( + f"**Match:** '{match[0]}'\n" + f"**Location:** '...{escape_markdown(surroundings)}...'\n" + f"\n**Original Message:**\n{escape_markdown(content)}" + ) + else: # Use original content + message_content = content + + additional_embeds = None + additional_embeds_msg = None + + self.bot.stats.incr(f"filters.{name}") + + # The function returns True for invalid invites. + # They have no data so additional embeds can't be created for them. + if name == "filter_invites" and match is not True: + additional_embeds = [] + for _, data in match.items(): + embed = discord.Embed(description=( + f"**Members:**\n{data['members']}\n" + f"**Active:**\n{data['active']}" + )) + embed.set_author(name=data["name"]) + embed.set_thumbnail(url=data["icon"]) + embed.set_footer(text=f"Guild ID: {data['id']}") + additional_embeds.append(embed) + additional_embeds_msg = "For the following guild(s):" + + elif name == "watch_rich_embeds": + additional_embeds = match + additional_embeds_msg = "With the following embed(s):" + + return message_content, additional_embeds, additional_embeds_msg + + @staticmethod + def _check_filter(msg: Message) -> bool: + """Check whitelists to see if we should filter this message.""" + role_whitelisted = False + + if type(msg.author) is Member: # Only Member has roles, not User. + for role in msg.author.roles: + if role.id in Filter.role_whitelist: + role_whitelisted = True + + return ( + msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist + and not role_whitelisted # Role not in whitelist + and not msg.author.bot # Author not a bot + ) + + async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: + """ + Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. + + `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is + matched as-is. Spoilers are expanded, if any, and URLs are ignored. + """ + if SPOILER_RE.search(text): + text = self._expand_spoilers(text) + + # Make sure it's not a URL + if URL_RE.search(text): + return False + + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + match = re.search(pattern, text, flags=re.IGNORECASE) + if match: + return match + + async def _has_urls(self, text: str) -> bool: + """Returns True if the text contains one of the blacklisted URLs from the config file.""" + if not URL_RE.search(text): + return False + + text = text.lower() + domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) + + for url in domain_blacklist: + if url.lower() in text: + return True + + return False + + @staticmethod + async def _has_zalgo(text: str) -> bool: + """ + Returns True if the text contains zalgo characters. + + Zalgo range is \u0300 – \u036F and \u0489. + """ + return bool(ZALGO_RE.search(text)) + + async def _has_invites(self, text: str) -> Union[dict, bool]: + """ + Checks if there's any invites in the text content that aren't in the guild whitelist. + + If any are detected, a dictionary of invite data is returned, with a key per invite. + If none are detected, False is returned. + + Attempts to catch some of common ways to try to cheat the system. + """ + # Remove backslashes to prevent escape character aroundfuckery like + # discord\.gg/gdudes-pony-farm + text = text.replace("\\", "") + + invites = INVITE_RE.findall(text) + invite_data = dict() + for invite in invites: + if invite in invite_data: + continue + + response = await self.bot.http_session.get( + f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} + ) + response = await response.json() + guild = response.get("guild") + if guild is None: + # Lack of a "guild" key in the JSON response indicates either an group DM invite, an + # expired invite, or an invalid invite. The API does not currently differentiate + # between invalid and expired invites + return True + + guild_id = guild.get("id") + guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) + guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) + + # Is this invite allowed? + guild_partnered_or_verified = ( + 'PARTNERED' in guild.get("features", []) + or 'VERIFIED' in guild.get("features", []) + ) + invite_not_allowed = ( + guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. + or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. + and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. + ) + + if invite_not_allowed: + guild_icon_hash = guild["icon"] + guild_icon = ( + "https://cdn.discordapp.com/icons/" + f"{guild_id}/{guild_icon_hash}.png?size=512" + ) + + invite_data[invite] = { + "name": guild["name"], + "id": guild['id'], + "icon": guild_icon, + "members": response["approximate_member_count"], + "active": response["approximate_presence_count"] + } + + return invite_data if invite_data else False + + @staticmethod + async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: + """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" + if msg.embeds: + for embed in msg.embeds: + if embed.type == "rich": + urls = URL_RE.findall(msg.content) + if not embed.url or embed.url not in urls: + # If `embed.url` does not exist or if `embed.url` is not part of the content + # of the message, it's unlikely to be an auto-generated embed by Discord. + return msg.embeds + else: + log.trace( + "Found a rich embed sent by a regular user account, " + "but it was likely just an automatic URL embed." + ) + return False + return False + + async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: + """ + Notify filtered_member about a moderation action with the reason str. + + First attempts to DM the user, fall back to in-channel notification if user has DMs disabled + """ + try: + await filtered_member.send(reason) + except discord.errors.Forbidden: + await channel.send(f"{filtered_member.mention} {reason}") + + def schedule_msg_delete(self, msg: dict) -> None: + """Delete an offensive message once its deletion date is reached.""" + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) + + async def reschedule_offensive_msg_deletion(self) -> None: + """Get all the pending message deletion from the API and reschedule them.""" + await self.bot.wait_until_ready() + response = await self.bot.api_client.get('bot/offensive-messages',) + + now = datetime.utcnow() + + for msg in response: + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + + if delete_at < now: + await self.delete_offensive_msg(msg) + else: + self.schedule_msg_delete(msg) + + async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: + """Delete an offensive message, and then delete it from the db.""" + try: + channel = self.bot.get_channel(msg['channel_id']) + if channel: + msg_obj = await channel.fetch_message(msg['id']) + await msg_obj.delete() + except NotFound: + log.info( + f"Tried to delete message {msg['id']}, but the message can't be found " + f"(it has been probably already deleted)." + ) + except HTTPException as e: + log.warning(f"Failed to delete message {msg['id']}: status {e.status}") + + await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') + log.info(f"Deleted the offensive message with id {msg['id']}.") + + +def setup(bot: Bot) -> None: + """Load the Filtering cog.""" + bot.add_cog(Filtering(bot)) diff --git a/bot/cogs/filters/security.py b/bot/cogs/filters/security.py new file mode 100644 index 000000000..c680c5e27 --- /dev/null +++ b/bot/cogs/filters/security.py @@ -0,0 +1,31 @@ +import logging + +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot + +log = logging.getLogger(__name__) + + +class Security(Cog): + """Security-related helpers.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all + self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM + + def check_not_bot(self, ctx: Context) -> bool: + """Check if the context is a bot user.""" + return not ctx.author.bot + + def check_on_guild(self, ctx: Context) -> bool: + """Check if the context is in a guild.""" + if ctx.guild is None: + raise NoPrivateMessage("This command cannot be used in private messages.") + return True + + +def setup(bot: Bot) -> None: + """Load the Security cog.""" + bot.add_cog(Security(bot)) diff --git a/bot/cogs/filters/token_remover.py b/bot/cogs/filters/token_remover.py new file mode 100644 index 000000000..ef979f222 --- /dev/null +++ b/bot/cogs/filters/token_remover.py @@ -0,0 +1,182 @@ +import base64 +import binascii +import logging +import re +import typing as t + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot import utils +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import Channels, Colours, Event, Icons + +log = logging.getLogger(__name__) + +LOG_MESSAGE = ( + "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " + "token was `{user_id}.{timestamp}.{hmac}`" +) +DELETION_MESSAGE_TEMPLATE = ( + "Hey {mention}! I noticed you posted a seemingly valid Discord API " + "token in your message and have removed your message. " + "This means that your token has been **compromised**. " + "Please change your token **immediately** at: " + "\n\n" + "Feel free to re-post it with the token removed. " + "If you believe this was a mistake, please let us know!" +) +DISCORD_EPOCH = 1_420_070_400 +TOKEN_EPOCH = 1_293_840_000 + +# Three parts delimited by dots: user ID, creation timestamp, HMAC. +# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. +# Each part only matches base64 URL-safe characters. +# Padding has never been observed, but the padding character '=' is matched just in case. +TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) + + +class Token(t.NamedTuple): + """A Discord Bot token.""" + + user_id: str + timestamp: str + hmac: str + + +class TokenRemover(Cog): + """Scans messages for potential discord.py bot tokens and removes them.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Check each message for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + found_token = self.find_token_in_message(msg) + if found_token: + await self.take_action(msg, found_token) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + await self.on_message(after) + + async def take_action(self, msg: Message, found_token: Token) -> None: + """Remove the `msg` containing the `found_token` and send a mod log message.""" + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") + return + + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + log_message = self.format_log_message(msg, found_token) + log.debug(log_message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=log_message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + self.bot.stats.incr("tokens.removed_tokens") + + @staticmethod + def format_log_message(msg: Message, token: Token) -> str: + """Return the log message to send for `token` being censored in `msg`.""" + return LOG_MESSAGE.format( + author=msg.author, + author_id=msg.author.id, + channel=msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac='x' * len(token.hmac), + ) + + @classmethod + def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: + """Return a seemingly valid token found in `msg` or `None` if no token is found.""" + # Use finditer rather than search to guard against method calls prematurely returning the + # token check (e.g. `message.channel.send` also matches our token pattern) + for match in TOKEN_RE.finditer(msg.content): + token = Token(*match.groups()) + if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): + # Short-circuit on first match + return token + + # No matching substring + return + + @staticmethod + def is_valid_user_id(b64_content: str) -> bool: + """ + Check potential token to see if it contains a valid Discord user ID. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + string = decoded_bytes.decode('utf-8') + + # isdigit on its own would match a lot of other Unicode characters, hence the isascii. + return string.isascii() and string.isdigit() + except (binascii.Error, ValueError): + return False + + @staticmethod + def is_valid_timestamp(b64_content: str) -> bool: + """ + Return True if `b64_content` decodes to a valid timestamp. + + If the timestamp is greater than the Discord epoch, it's probably valid. + See: https://i.imgur.com/7WdehGn.png + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + timestamp = int.from_bytes(decoded_bytes, byteorder="big") + except (binascii.Error, ValueError) as e: + log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") + return False + + # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound + # is not checked. + if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: + return True + else: + log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") + return False + + +def setup(bot: Bot) -> None: + """Load the TokenRemover cog.""" + bot.add_cog(TokenRemover(bot)) diff --git a/bot/cogs/filters/webhook_remover.py b/bot/cogs/filters/webhook_remover.py new file mode 100644 index 000000000..5812da87c --- /dev/null +++ b/bot/cogs/filters/webhook_remover.py @@ -0,0 +1,84 @@ +import logging +import re + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.constants import Channels, Colours, Event, Icons + +WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) + +ALERT_MESSAGE_TEMPLATE = ( + "{user}, looks like you posted a Discord webhook URL. Therefore, your " + "message has been removed. Your webhook may have been **compromised** so " + "please re-create the webhook **immediately**. If you believe this was " + "mistake, please let us know." +) + +log = logging.getLogger(__name__) + + +class WebhookRemover(Cog): + """Scan messages to detect Discord webhooks links.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get current instance of `ModLog`.""" + return self.bot.get_cog("ModLog") + + async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: + """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" + # Don't log this, due internal delete, not by user. Will make different entry. + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") + return + + await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) + + message = ( + f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " + f"to #{msg.channel}. Webhook URL was `{redacted_url}`" + ) + log.debug(message) + + # Send entry to moderation alerts. + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Discord webhook URL removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts + ) + + self.bot.stats.incr("tokens.removed_webhooks") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Check if a Discord webhook URL is in `message`.""" + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + matches = WEBHOOK_URL_RE.search(msg.content) + if matches: + await self.delete_and_respond(msg, matches[1] + "xxx") + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Check if a Discord webhook URL is in the edited message `after`.""" + await self.on_message(after) + + +def setup(bot: Bot) -> None: + """Load `WebhookRemover` cog.""" + bot.add_cog(WebhookRemover(bot)) diff --git a/bot/cogs/help.py b/bot/cogs/help.py deleted file mode 100644 index 3d1d6fd10..000000000 --- a/bot/cogs/help.py +++ /dev/null @@ -1,375 +0,0 @@ -import itertools -import logging -from asyncio import TimeoutError -from collections import namedtuple -from contextlib import suppress -from typing import List, Union - -from discord import Colour, Embed, Member, Message, NotFound, Reaction, User -from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand -from fuzzywuzzy import fuzz, process -from fuzzywuzzy.utils import full_process - -from bot import constants -from bot.constants import Channels, Emojis, STAFF_ROLES -from bot.decorators import redirect_output -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -COMMANDS_PER_PAGE = 8 -DELETE_EMOJI = Emojis.trashcan -PREFIX = constants.Bot.prefix - -Category = namedtuple("Category", ["name", "description", "cogs"]) - - -async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: - """ - Runs the cleanup for the help command. - - Adds the :trashcan: reaction that, when clicked, will delete the help message. - After a 300 second timeout, the reaction will be removed. - """ - def check(reaction: Reaction, user: User) -> bool: - """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" - return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id - - await message.add_reaction(DELETE_EMOJI) - - with suppress(NotFound): - try: - await bot.wait_for("reaction_add", check=check, timeout=300) - await message.delete() - except TimeoutError: - await message.remove_reaction(DELETE_EMOJI, bot.user) - - -class HelpQueryNotFound(ValueError): - """ - Raised when a HelpSession Query doesn't match a command or cog. - - Contains the custom attribute of ``possible_matches``. - - Instances of this object contain a dictionary of any command(s) that were close to matching the - query, where keys are the possible matched command names and values are the likeness match scores. - """ - - def __init__(self, arg: str, possible_matches: dict = None): - super().__init__(arg) - self.possible_matches = possible_matches - - -class CustomHelpCommand(HelpCommand): - """ - An interactive instance for the bot help command. - - Cogs can be grouped into custom categories. All cogs with the same category will be displayed - under a single category name in the help output. Custom categories are defined inside the cogs - as a class attribute named `category`. A description can also be specified with the attribute - `category_description`. If a description is not found in at least one cog, the default will be - the regular description (class docstring) of the first cog found in the category. - """ - - def __init__(self): - super().__init__(command_attrs={"help": "Shows help for bot commands"}) - - @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) - async def command_callback(self, ctx: Context, *, command: str = None) -> None: - """Attempts to match the provided query with a valid command or cog.""" - # the only reason we need to tamper with this is because d.py does not support "categories", - # so we need to deal with them ourselves. - - bot = ctx.bot - - if command is None: - # quick and easy, send bot help if command is none - mapping = self.get_bot_mapping() - await self.send_bot_help(mapping) - return - - cog_matches = [] - description = None - for cog in bot.cogs.values(): - if hasattr(cog, "category") and cog.category == command: - cog_matches.append(cog) - if hasattr(cog, "category_description"): - description = cog.category_description - - if cog_matches: - category = Category(name=command, description=description, cogs=cog_matches) - await self.send_category_help(category) - return - - # it's either a cog, group, command or subcommand; let the parent class deal with it - await super().command_callback(ctx, command=command) - - async def get_all_help_choices(self) -> set: - """ - Get all the possible options for getting help in the bot. - - This will only display commands the author has permission to run. - - These include: - - Category names - - Cog names - - Group command names (and aliases) - - Command names (and aliases) - - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) - - Options and choices are case sensitive. - """ - # first get all commands including subcommands and full command name aliases - choices = set() - for command in await self.filter_commands(self.context.bot.walk_commands()): - # the the command or group name - choices.add(str(command)) - - if isinstance(command, Command): - # all aliases if it's just a command - choices.update(command.aliases) - else: - # otherwise we need to add the parent name in - choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) - - # all cog names - choices.update(self.context.bot.cogs) - - # all category names - choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) - return choices - - async def command_not_found(self, string: str) -> "HelpQueryNotFound": - """ - Handles when a query does not match a valid command, group, cog or category. - - Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. - """ - choices = await self.get_all_help_choices() - - # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty - # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters - if (processed := full_process(string)): - result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) - else: - result = [] - - return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) - - async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": - """ - Redirects the error to `command_not_found`. - - `command_not_found` deals with searching and getting best choices for both commands and subcommands. - """ - return await self.command_not_found(f"{command.qualified_name} {string}") - - async def send_error_message(self, error: HelpQueryNotFound) -> None: - """Send the error message to the channel.""" - embed = Embed(colour=Colour.red(), title=str(error)) - - if getattr(error, "possible_matches", None): - matches = "\n".join(f"`{match}`" for match in error.possible_matches) - embed.description = f"**Did you mean:**\n{matches}" - - await self.context.send(embed=embed) - - async def command_formatting(self, command: Command) -> Embed: - """ - Takes a command and turns it into an embed. - - It will add an author, command signature + help, aliases and a note if the user can't run the command. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - parent = command.full_parent_name - - name = str(command) if not parent else f"{parent} {command.name}" - command_details = f"**```{PREFIX}{name} {command.signature}```**\n" - - # show command aliases - aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) - if aliases: - command_details += f"**Can also use:** {aliases}\n\n" - - # check if the user is allowed to run this command - if not await command.can_run(self.context): - command_details += "***You cannot run this command.***\n\n" - - command_details += f"*{command.help or 'No details provided.'}*\n" - embed.description = command_details - - return embed - - async def send_command_help(self, command: Command) -> None: - """Send help for a single command.""" - embed = await self.command_formatting(command) - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: - """ - Formats the prefix, command name and signature, and short doc for an iterable of commands. - - return_as_list is helpful for passing these command details into the paginator as a list of command details. - """ - details = [] - for command in commands_: - signature = f" {command.signature}" if command.signature else "" - details.append( - f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" - ) - if return_as_list: - return details - else: - return "".join(details) - - async def send_group_help(self, group: Group) -> None: - """Sends help for a group command.""" - subcommands = group.commands - - if len(subcommands) == 0: - # no subcommands, just treat it like a regular command - await self.send_command_help(group) - return - - # remove commands that the user can't run and are hidden, and sort by name - commands_ = await self.filter_commands(subcommands, sort=True) - - embed = await self.command_formatting(group) - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n**Subcommands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - async def send_cog_help(self, cog: Cog) -> None: - """Send help for a cog.""" - # sort commands by name, and remove any the user cant run or are hidden. - commands_ = await self.filter_commands(cog.get_commands(), sort=True) - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n\n**Commands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def _category_key(command: Command) -> str: - """ - Returns a cog name of a given command for use as a key for `sorted` and `groupby`. - - A zero width space is used as a prefix for results with no cogs to force them last in ordering. - """ - if command.cog: - with suppress(AttributeError): - if command.cog.category: - return f"**{command.cog.category}**" - return f"**{command.cog_name}**" - else: - return "**\u200bNo Category:**" - - async def send_category_help(self, category: Category) -> None: - """ - Sends help for a bot category. - - This sends a brief help for all commands in all cogs registered to the category. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - all_commands = [] - for cog in category.cogs: - all_commands.extend(cog.get_commands()) - - filtered_commands = await self.filter_commands(all_commands, sort=True) - - command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) - description = f"**{category.name}**\n*{category.description}*" - - if command_detail_lines: - description += "\n\n**Commands:**" - - await LinePaginator.paginate( - command_detail_lines, - self.context, - embed, - prefix=description, - max_lines=COMMANDS_PER_PAGE, - max_size=2000, - ) - - async def send_bot_help(self, mapping: dict) -> None: - """Sends help for all bot commands and cogs.""" - bot = self.context.bot - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) - - cog_or_category_pages = [] - - for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): - sorted_commands = sorted(_commands, key=lambda c: c.name) - - if len(sorted_commands) == 0: - continue - - command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) - - # Split cogs or categories which have too many commands to fit in one page. - # The length of commands is included for later use when aggregating into pages for the paginator. - for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): - truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] - joined_lines = "".join(truncated_lines) - cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) - - pages = [] - counter = 0 - page = "" - for page_details, length in cog_or_category_pages: - counter += length - if counter > COMMANDS_PER_PAGE: - # force a new page on paginator even if it falls short of the max pages - # since we still want to group categories/cogs. - counter = length - pages.append(page) - page = f"{page_details}\n\n" - else: - page += f"{page_details}\n\n" - - if page: - # add any remaining command help that didn't get added in the last iteration above. - pages.append(page) - - await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) - - -class Help(Cog): - """Custom Embed Pagination Help feature.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.old_help_command = bot.help_command - bot.help_command = CustomHelpCommand() - bot.help_command.cog = self - - def cog_unload(self) -> None: - """Reset the help command when the cog is unloaded.""" - self.bot.help_command = self.old_help_command - - -def setup(bot: Bot) -> None: - """Load the Help cog.""" - bot.add_cog(Help(bot)) - log.info("Cog loaded: Help") diff --git a/bot/cogs/info/__init__.py b/bot/cogs/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/info/doc.py b/bot/cogs/info/doc.py new file mode 100644 index 000000000..204cffb37 --- /dev/null +++ b/bot/cogs/info/doc.py @@ -0,0 +1,511 @@ +import asyncio +import functools +import logging +import re +import textwrap +from collections import OrderedDict +from contextlib import suppress +from types import SimpleNamespace +from typing import Any, Callable, Optional, Tuple + +import discord +from bs4 import BeautifulSoup +from bs4.element import PageElement, Tag +from discord.errors import NotFound +from discord.ext import commands +from markdownify import MarkdownConverter +from requests import ConnectTimeout, ConnectionError, HTTPError +from sphinx.ext import intersphinx +from urllib3.exceptions import ProtocolError + +from bot.bot import Bot +from bot.constants import MODERATION_ROLES, RedirectOutput +from bot.converters import ValidPythonIdentifier, ValidURL +from bot.decorators import with_role +from bot.pagination import LinePaginator + + +log = logging.getLogger(__name__) +logging.getLogger('urllib3').setLevel(logging.WARNING) + +# Since Intersphinx is intended to be used with Sphinx, +# we need to mock its configuration. +SPHINX_MOCK_APP = SimpleNamespace( + config=SimpleNamespace( + intersphinx_timeout=3, + tls_verify=True, + user_agent="python3:python-discord/bot:1.0.0" + ) +) + +NO_OVERRIDE_GROUPS = ( + "2to3fixer", + "token", + "label", + "pdbcommand", + "term", +) +NO_OVERRIDE_PACKAGES = ( + "python", +) + +SEARCH_END_TAG_ATTRS = ( + "data", + "function", + "class", + "exception", + "seealso", + "section", + "rubric", + "sphinxsidebar", +) +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") +WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") + +FAILED_REQUEST_RETRY_AMOUNT = 3 +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay + + +def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + # Assign the cache to the function itself so we can clear it from outside. + async_cache.cache = OrderedDict() + + def decorator(function: Callable) -> Callable: + """Define the async_cache decorator.""" + @functools.wraps(function) + async def wrapper(*args) -> Any: + """Decorator wrapper for the caching logic.""" + key = ':'.join(args[arg_offset:]) + + value = async_cache.cache.get(key) + if value is None: + if len(async_cache.cache) > max_size: + async_cache.cache.popitem(last=False) + + async_cache.cache[key] = await function(*args) + return async_cache.cache[key] + return wrapper + return decorator + + +class DocMarkdownConverter(MarkdownConverter): + """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" + + def convert_code(self, el: PageElement, text: str) -> str: + """Undo `markdownify`s underscore escaping.""" + return f"`{text}`".replace('\\', '') + + def convert_pre(self, el: PageElement, text: str) -> str: + """Wrap any codeblocks in `py` for syntax highlighting.""" + code = ''.join(el.strings) + return f"```py\n{code}```" + + +def markdownify(html: str) -> DocMarkdownConverter: + """Create a DocMarkdownConverter object from the input html.""" + return DocMarkdownConverter(bullets='•').convert(html) + + +class InventoryURL(commands.Converter): + """ + Represents an Intersphinx inventory URL. + + This converter checks whether intersphinx accepts the given inventory URL, and raises + `BadArgument` if that is not the case. + + Otherwise, it simply passes through the given URL. + """ + + @staticmethod + async def convert(ctx: commands.Context, url: str) -> str: + """Convert url to Intersphinx inventory URL.""" + try: + intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) + except AttributeError: + raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") + except ConnectionError: + if url.startswith('https'): + raise commands.BadArgument( + f"Cannot establish a connection to `{url}`. Does it support HTTPS?" + ) + raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") + except ValueError: + raise commands.BadArgument( + f"Failed to read Intersphinx inventory from URL `{url}`. " + "Are you sure that it's a valid inventory file?" + ) + return url + + +class Doc(commands.Cog): + """A set of commands for querying & displaying documentation.""" + + def __init__(self, bot: Bot): + self.base_urls = {} + self.bot = bot + self.inventories = {} + self.renamed_symbols = set() + + self.bot.loop.create_task(self.init_refresh_inventory()) + + async def init_refresh_inventory(self) -> None: + """Refresh documentation inventory on cog initialization.""" + await self.bot.wait_until_guild_available() + await self.refresh_inventory() + + async def update_single( + self, package_name: str, base_url: str, inventory_url: str + ) -> None: + """ + Rebuild the inventory for a single package. + + Where: + * `package_name` is the package name to use, appears in the log + * `base_url` is the root documentation URL for the specified package, used to build + absolute paths that link to specific symbols + * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running + `intersphinx.fetch_inventory` in an executor on the bot's event loop + """ + self.base_urls[package_name] = base_url + + package = await self._fetch_inventory(inventory_url) + if not package: + return None + + for group, value in package.items(): + for symbol, (package_name, _version, relative_doc_url, _) in value.items(): + absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + group_name = group.split(":")[1] + symbol_base_url = self.inventories[symbol].split("/", 3)[2] + if ( + group_name in NO_OVERRIDE_GROUPS + or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) + ): + + symbol = f"{group_name}.{symbol}" + # If renamed `symbol` already exists, add library name in front to differentiate between them. + if symbol in self.renamed_symbols: + # Split `package_name` because of packages like Pillow that have spaces in them. + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + + self.inventories[symbol] = absolute_doc_url + + log.trace(f"Fetched inventory for {package_name}.") + + async def refresh_inventory(self) -> None: + """Refresh internal documentation inventory.""" + log.debug("Refreshing documentation inventory...") + + # Clear the old base URLS and inventories to ensure + # that we start from a fresh local dataset. + # Also, reset the cache used for fetching documentation. + self.base_urls.clear() + self.inventories.clear() + self.renamed_symbols.clear() + async_cache.cache = OrderedDict() + + # Run all coroutines concurrently - since each of them performs a HTTP + # request, this speeds up fetching the inventory data heavily. + coros = [ + self.update_single( + package["package"], package["base_url"], package["inventory_url"] + ) for package in await self.bot.api_client.get('bot/documentation-links') + ] + await asyncio.gather(*coros) + + async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: + """ + Given a Python symbol, return its signature and description. + + The first tuple element is the signature of the given symbol as a markup-free string, and + the second tuple element is the description of the given symbol with HTML markup included. + + If the given symbol is a module, returns a tuple `(None, str)` + else if the symbol could not be found, returns `None`. + """ + url = self.inventories.get(symbol) + if url is None: + return None + + async with self.bot.http_session.get(url) as response: + html = await response.text(encoding='utf-8') + + # Find the signature header and parse the relevant parts. + symbol_id = url.split('#')[-1] + soup = BeautifulSoup(html, 'lxml') + symbol_heading = soup.find(id=symbol_id) + search_html = str(soup) + + if symbol_heading is None: + return None + + if symbol_id == f"module-{symbol}": + # Get page content from the module headerlink to the + # first tag that has its class in `SEARCH_END_TAG_ATTRS` + start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) + if start_tag is None: + return [], "" + + end_tag = start_tag.find_next(self._match_end_tag) + if end_tag is None: + return [], "" + + description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) + description_end_index = search_html.find(str(end_tag)) + description = search_html[description_start_index:description_end_index] + signatures = None + + else: + signatures = [] + description = str(symbol_heading.find_next_sibling("dd")) + description_pos = search_html.find(description) + # Get text of up to 3 signatures, remove unwanted symbols + for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): + signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + if signature and search_html.find(str(element)) < description_pos: + signatures.append(signature) + + return signatures, description.replace('¶', '') + + @async_cache(arg_offset=1) + async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: + """ + Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. + + If the symbol is known, an Embed with documentation about it is returned. + """ + scraped_html = await self.get_symbol_html(symbol) + if scraped_html is None: + return None + + signatures = scraped_html[0] + permalink = self.inventories[symbol] + description = markdownify(scraped_html[1]) + + # Truncate the description of the embed to the last occurrence + # of a double newline (interpreted as a paragraph) before index 1000. + if len(description) > 1000: + shortened = description[:1000] + description_cutoff = shortened.rfind('\n\n', 100) + if description_cutoff == -1: + # Search the shortened version for cutoff points in decreasing desirability, + # cutoff at 1000 if none are found. + for string in (". ", ", ", ",", " "): + description_cutoff = shortened.rfind(string) + if description_cutoff != -1: + break + else: + description_cutoff = 1000 + description = description[:description_cutoff] + + # If there is an incomplete code block, cut it out + if description.count("```") % 2: + codeblock_start = description.rfind('```py') + description = description[:codeblock_start].rstrip() + description += f"... [read more]({permalink})" + + description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) + if signatures is None: + # If symbol is a module, don't show signature. + embed_description = description + + elif not signatures: + # It's some "meta-page", for example: + # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views + embed_description = "This appears to be a generic page not tied to a specific symbol." + + else: + embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) + embed_description += f"\n{description}" + + embed = discord.Embed( + title=f'`{symbol}`', + url=permalink, + description=embed_description + ) + # Show all symbols with the same name that were renamed in the footer. + embed.set_footer( + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) + ) + return embed + + @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) + async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """Lookup documentation for Python symbols.""" + await ctx.invoke(self.get_command, symbol) + + @docs_group.command(name='get', aliases=('g',)) + async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """ + Return a documentation embed for a given symbol. + + If no symbol is given, return a list of all available inventories. + + Examples: + !docs + !docs aiohttp + !docs aiohttp.ClientSession + !docs get aiohttp.ClientSession + """ + if symbol is None: + inventory_embed = discord.Embed( + title=f"All inventories (`{len(self.base_urls)}` total)", + colour=discord.Colour.blue() + ) + + lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) + if self.base_urls: + await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) + + else: + inventory_embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=inventory_embed) + + else: + # Fetching documentation for a symbol (at least for the first time, since + # caching is used) takes quite some time, so let's send typing to indicate + # that we got the command, but are still working on it. + async with ctx.typing(): + doc_embed = await self.get_symbol_embed(symbol) + + if doc_embed is None: + error_embed = discord.Embed( + description=f"Sorry, I could not find any documentation for `{symbol}`.", + colour=discord.Colour.red() + ) + error_message = await ctx.send(embed=error_embed) + with suppress(NotFound): + await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) + await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) + else: + await ctx.send(embed=doc_embed) + + @docs_group.command(name='set', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def set_command( + self, ctx: commands.Context, package_name: ValidPythonIdentifier, + base_url: ValidURL, inventory_url: InventoryURL + ) -> None: + """ + Adds a new documentation metadata object to the site's database. + + The database will update the object, should an existing item with the specified `package_name` already exist. + + Example: + !docs set \ + python \ + https://docs.python.org/3/ \ + https://docs.python.org/3/objects.inv + """ + body = { + 'package': package_name, + 'base_url': base_url, + 'inventory_url': inventory_url + } + await self.bot.api_client.post('bot/documentation-links', json=body) + + log.info( + f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" + f"Package name: {package_name}\n" + f"Base url: {base_url}\n" + f"Inventory URL: {inventory_url}" + ) + + # Rebuilding the inventory can take some time, so lets send out a + # typing event to show that the Bot is still working. + async with ctx.typing(): + await self.refresh_inventory() + await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") + + @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: + """ + Removes the specified package from the database. + + Examples: + !docs delete aiohttp + """ + await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') + + async with ctx.typing(): + # Rebuild the inventory to ensure that everything + # that was from this package is properly deleted. + await self.refresh_inventory() + await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + + @docs_group.command(name="refresh", aliases=("rfsh", "r")) + @with_role(*MODERATION_ROLES) + async def refresh_command(self, ctx: commands.Context) -> None: + """Refresh inventories and send differences to channel.""" + old_inventories = set(self.base_urls) + with ctx.typing(): + await self.refresh_inventory() + # Get differences of added and removed inventories + added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) + if added: + added = f"+ {added}" + + removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) + if removed: + removed = f"- {removed}" + + embed = discord.Embed( + title="Inventories refreshed", + description=f"```diff\n{added}\n{removed}```" if added or removed else "" + ) + await ctx.send(embed=embed) + + async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: + """Get and return inventory from `inventory_url`. If fetching fails, return None.""" + fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) + for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): + try: + package = await self.bot.loop.run_in_executor(None, fetch_func) + except ConnectTimeout: + log.error( + f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except ProtocolError: + log.error( + f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except HTTPError as e: + log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") + return None + except ConnectionError: + log.error(f"Couldn't establish connection to inventory {inventory_url}.") + return None + else: + return package + log.error(f"Fetching of inventory {inventory_url} failed.") + return None + + @staticmethod + def _match_end_tag(tag: Tag) -> bool: + """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" + for attr in SEARCH_END_TAG_ATTRS: + if attr in tag.get("class", ()): + return True + + return tag.name == "table" + + +def setup(bot: Bot) -> None: + """Load the Doc cog.""" + bot.add_cog(Doc(bot)) diff --git a/bot/cogs/info/help.py b/bot/cogs/info/help.py new file mode 100644 index 000000000..3d1d6fd10 --- /dev/null +++ b/bot/cogs/info/help.py @@ -0,0 +1,375 @@ +import itertools +import logging +from asyncio import TimeoutError +from collections import namedtuple +from contextlib import suppress +from typing import List, Union + +from discord import Colour, Embed, Member, Message, NotFound, Reaction, User +from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand +from fuzzywuzzy import fuzz, process +from fuzzywuzzy.utils import full_process + +from bot import constants +from bot.constants import Channels, Emojis, STAFF_ROLES +from bot.decorators import redirect_output +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +COMMANDS_PER_PAGE = 8 +DELETE_EMOJI = Emojis.trashcan +PREFIX = constants.Bot.prefix + +Category = namedtuple("Category", ["name", "description", "cogs"]) + + +async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: + """ + Runs the cleanup for the help command. + + Adds the :trashcan: reaction that, when clicked, will delete the help message. + After a 300 second timeout, the reaction will be removed. + """ + def check(reaction: Reaction, user: User) -> bool: + """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" + return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id + + await message.add_reaction(DELETE_EMOJI) + + with suppress(NotFound): + try: + await bot.wait_for("reaction_add", check=check, timeout=300) + await message.delete() + except TimeoutError: + await message.remove_reaction(DELETE_EMOJI, bot.user) + + +class HelpQueryNotFound(ValueError): + """ + Raised when a HelpSession Query doesn't match a command or cog. + + Contains the custom attribute of ``possible_matches``. + + Instances of this object contain a dictionary of any command(s) that were close to matching the + query, where keys are the possible matched command names and values are the likeness match scores. + """ + + def __init__(self, arg: str, possible_matches: dict = None): + super().__init__(arg) + self.possible_matches = possible_matches + + +class CustomHelpCommand(HelpCommand): + """ + An interactive instance for the bot help command. + + Cogs can be grouped into custom categories. All cogs with the same category will be displayed + under a single category name in the help output. Custom categories are defined inside the cogs + as a class attribute named `category`. A description can also be specified with the attribute + `category_description`. If a description is not found in at least one cog, the default will be + the regular description (class docstring) of the first cog found in the category. + """ + + def __init__(self): + super().__init__(command_attrs={"help": "Shows help for bot commands"}) + + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) + async def command_callback(self, ctx: Context, *, command: str = None) -> None: + """Attempts to match the provided query with a valid command or cog.""" + # the only reason we need to tamper with this is because d.py does not support "categories", + # so we need to deal with them ourselves. + + bot = ctx.bot + + if command is None: + # quick and easy, send bot help if command is none + mapping = self.get_bot_mapping() + await self.send_bot_help(mapping) + return + + cog_matches = [] + description = None + for cog in bot.cogs.values(): + if hasattr(cog, "category") and cog.category == command: + cog_matches.append(cog) + if hasattr(cog, "category_description"): + description = cog.category_description + + if cog_matches: + category = Category(name=command, description=description, cogs=cog_matches) + await self.send_category_help(category) + return + + # it's either a cog, group, command or subcommand; let the parent class deal with it + await super().command_callback(ctx, command=command) + + async def get_all_help_choices(self) -> set: + """ + Get all the possible options for getting help in the bot. + + This will only display commands the author has permission to run. + + These include: + - Category names + - Cog names + - Group command names (and aliases) + - Command names (and aliases) + - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) + + Options and choices are case sensitive. + """ + # first get all commands including subcommands and full command name aliases + choices = set() + for command in await self.filter_commands(self.context.bot.walk_commands()): + # the the command or group name + choices.add(str(command)) + + if isinstance(command, Command): + # all aliases if it's just a command + choices.update(command.aliases) + else: + # otherwise we need to add the parent name in + choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) + + # all cog names + choices.update(self.context.bot.cogs) + + # all category names + choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) + return choices + + async def command_not_found(self, string: str) -> "HelpQueryNotFound": + """ + Handles when a query does not match a valid command, group, cog or category. + + Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. + """ + choices = await self.get_all_help_choices() + + # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty + # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters + if (processed := full_process(string)): + result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) + else: + result = [] + + return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) + + async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": + """ + Redirects the error to `command_not_found`. + + `command_not_found` deals with searching and getting best choices for both commands and subcommands. + """ + return await self.command_not_found(f"{command.qualified_name} {string}") + + async def send_error_message(self, error: HelpQueryNotFound) -> None: + """Send the error message to the channel.""" + embed = Embed(colour=Colour.red(), title=str(error)) + + if getattr(error, "possible_matches", None): + matches = "\n".join(f"`{match}`" for match in error.possible_matches) + embed.description = f"**Did you mean:**\n{matches}" + + await self.context.send(embed=embed) + + async def command_formatting(self, command: Command) -> Embed: + """ + Takes a command and turns it into an embed. + + It will add an author, command signature + help, aliases and a note if the user can't run the command. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + parent = command.full_parent_name + + name = str(command) if not parent else f"{parent} {command.name}" + command_details = f"**```{PREFIX}{name} {command.signature}```**\n" + + # show command aliases + aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) + if aliases: + command_details += f"**Can also use:** {aliases}\n\n" + + # check if the user is allowed to run this command + if not await command.can_run(self.context): + command_details += "***You cannot run this command.***\n\n" + + command_details += f"*{command.help or 'No details provided.'}*\n" + embed.description = command_details + + return embed + + async def send_command_help(self, command: Command) -> None: + """Send help for a single command.""" + embed = await self.command_formatting(command) + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: + """ + Formats the prefix, command name and signature, and short doc for an iterable of commands. + + return_as_list is helpful for passing these command details into the paginator as a list of command details. + """ + details = [] + for command in commands_: + signature = f" {command.signature}" if command.signature else "" + details.append( + f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" + ) + if return_as_list: + return details + else: + return "".join(details) + + async def send_group_help(self, group: Group) -> None: + """Sends help for a group command.""" + subcommands = group.commands + + if len(subcommands) == 0: + # no subcommands, just treat it like a regular command + await self.send_command_help(group) + return + + # remove commands that the user can't run and are hidden, and sort by name + commands_ = await self.filter_commands(subcommands, sort=True) + + embed = await self.command_formatting(group) + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n**Subcommands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + async def send_cog_help(self, cog: Cog) -> None: + """Send help for a cog.""" + # sort commands by name, and remove any the user cant run or are hidden. + commands_ = await self.filter_commands(cog.get_commands(), sort=True) + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n\n**Commands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def _category_key(command: Command) -> str: + """ + Returns a cog name of a given command for use as a key for `sorted` and `groupby`. + + A zero width space is used as a prefix for results with no cogs to force them last in ordering. + """ + if command.cog: + with suppress(AttributeError): + if command.cog.category: + return f"**{command.cog.category}**" + return f"**{command.cog_name}**" + else: + return "**\u200bNo Category:**" + + async def send_category_help(self, category: Category) -> None: + """ + Sends help for a bot category. + + This sends a brief help for all commands in all cogs registered to the category. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + all_commands = [] + for cog in category.cogs: + all_commands.extend(cog.get_commands()) + + filtered_commands = await self.filter_commands(all_commands, sort=True) + + command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) + description = f"**{category.name}**\n*{category.description}*" + + if command_detail_lines: + description += "\n\n**Commands:**" + + await LinePaginator.paginate( + command_detail_lines, + self.context, + embed, + prefix=description, + max_lines=COMMANDS_PER_PAGE, + max_size=2000, + ) + + async def send_bot_help(self, mapping: dict) -> None: + """Sends help for all bot commands and cogs.""" + bot = self.context.bot + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) + + cog_or_category_pages = [] + + for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): + sorted_commands = sorted(_commands, key=lambda c: c.name) + + if len(sorted_commands) == 0: + continue + + command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) + + # Split cogs or categories which have too many commands to fit in one page. + # The length of commands is included for later use when aggregating into pages for the paginator. + for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): + truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] + joined_lines = "".join(truncated_lines) + cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) + + pages = [] + counter = 0 + page = "" + for page_details, length in cog_or_category_pages: + counter += length + if counter > COMMANDS_PER_PAGE: + # force a new page on paginator even if it falls short of the max pages + # since we still want to group categories/cogs. + counter = length + pages.append(page) + page = f"{page_details}\n\n" + else: + page += f"{page_details}\n\n" + + if page: + # add any remaining command help that didn't get added in the last iteration above. + pages.append(page) + + await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) + + +class Help(Cog): + """Custom Embed Pagination Help feature.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.old_help_command = bot.help_command + bot.help_command = CustomHelpCommand() + bot.help_command.cog = self + + def cog_unload(self) -> None: + """Reset the help command when the cog is unloaded.""" + self.bot.help_command = self.old_help_command + + +def setup(bot: Bot) -> None: + """Load the Help cog.""" + bot.add_cog(Help(bot)) + log.info("Cog loaded: Help") diff --git a/bot/cogs/info/information.py b/bot/cogs/info/information.py new file mode 100644 index 000000000..8982196d1 --- /dev/null +++ b/bot/cogs/info/information.py @@ -0,0 +1,422 @@ +import colorsys +import logging +import pprint +import textwrap +from collections import Counter, defaultdict +from string import Template +from typing import Any, Mapping, Optional, Union + +from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils +from discord.abc import GuildChannel +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group +from discord.utils import escape_markdown + +from bot import constants +from bot.bot import Bot +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + + +class Information(Cog): + """A cog with commands for generating embeds with server info, such as server stats and user info.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @staticmethod + def role_can_read(channel: GuildChannel, role: Role) -> bool: + """Return True if `role` can read messages in `channel`.""" + overwrites = channel.overwrites_for(role) + return overwrites.read_messages is True + + def get_staff_channel_count(self, guild: Guild) -> int: + """ + Get the number of channels that are staff-only. + + We need to know two things about a channel: + - Does the @everyone role have explicit read deny permissions? + - Do staff roles have explicit read allow permissions? + + If the answer to both of these questions is yes, it's a staff channel. + """ + channel_ids = set() + for channel in guild.channels: + if channel.type is ChannelType.category: + continue + + everyone_can_read = self.role_can_read(channel, guild.default_role) + + for role in constants.STAFF_ROLES: + role_can_read = self.role_can_read(channel, guild.get_role(role)) + if role_can_read and not everyone_can_read: + channel_ids.add(channel.id) + break + + return len(channel_ids) + + @staticmethod + def get_channel_type_counts(guild: Guild) -> str: + """Return the total amounts of the various types of channels in `guild`.""" + channel_counter = Counter(c.type for c in guild.channels) + channel_type_list = [] + for channel, count in channel_counter.items(): + channel_type = str(channel).title() + channel_type_list.append(f"{channel_type} channels: {count}") + + channel_type_list = sorted(channel_type_list) + return "\n".join(channel_type_list) + + @with_role(*constants.MODERATION_ROLES) + @command(name="roles") + async def roles_info(self, ctx: Context) -> None: + """Returns a list of all roles and their corresponding IDs.""" + # Sort the roles alphabetically and remove the @everyone role + roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) + + # Build a list + role_list = [] + for role in roles: + role_list.append(f"`{role.id}` - {role.mention}") + + # Build an embed + embed = Embed( + title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", + colour=Colour.blurple() + ) + + await LinePaginator.paginate(role_list, ctx, embed, empty=False) + + @with_role(*constants.MODERATION_ROLES) + @command(name="role") + async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: + """ + Return information on a role or list of roles. + + To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. + """ + parsed_roles = [] + failed_roles = [] + + for role_name in roles: + if isinstance(role_name, Role): + # Role conversion has already succeeded + parsed_roles.append(role_name) + continue + + role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) + + if not role: + failed_roles.append(role_name) + continue + + parsed_roles.append(role) + + if failed_roles: + await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") + + for role in parsed_roles: + h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) + + embed = Embed( + title=f"{role.name} info", + colour=role.colour, + ) + embed.add_field(name="ID", value=role.id, inline=True) + embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) + embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) + embed.add_field(name="Member count", value=len(role.members), inline=True) + embed.add_field(name="Position", value=role.position) + embed.add_field(name="Permission code", value=role.permissions.value, inline=True) + + await ctx.send(embed=embed) + + @command(name="server", aliases=["server_info", "guild", "guild_info"]) + async def server_info(self, ctx: Context) -> None: + """Returns an embed full of server information.""" + created = time_since(ctx.guild.created_at, precision="days") + features = ", ".join(ctx.guild.features) + region = ctx.guild.region + + roles = len(ctx.guild.roles) + member_count = ctx.guild.member_count + channel_counts = self.get_channel_type_counts(ctx.guild) + + # How many of each user status? + statuses = Counter(member.status for member in ctx.guild.members) + embed = Embed(colour=Colour.blurple()) + + # How many staff members and staff channels do we have? + staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) + staff_channel_count = self.get_staff_channel_count(ctx.guild) + + # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the + # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting + # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts + # after the dedent is made. + embed.description = Template( + textwrap.dedent(f""" + **Server information** + Created: {created} + Voice region: {region} + Features: {features} + + **Channel counts** + $channel_counts + Staff channels: {staff_channel_count} + + **Member counts** + Members: {member_count:,} + Staff members: {staff_member_count} + Roles: {roles} + + **Member statuses** + {constants.Emojis.status_online} {statuses[Status.online]:,} + {constants.Emojis.status_idle} {statuses[Status.idle]:,} + {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} + {constants.Emojis.status_offline} {statuses[Status.offline]:,} + """) + ).substitute({"channel_counts": channel_counts}) + embed.set_thumbnail(url=ctx.guild.icon_url) + + await ctx.send(embed=embed) + + @command(name="user", aliases=["user_info", "member", "member_info"]) + async def user_info(self, ctx: Context, user: Member = None) -> None: + """Returns info about a user.""" + if user is None: + user = ctx.author + + # Do a role check if this is being executed on someone other than the caller + elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): + await ctx.send("You may not use this command on users other than yourself.") + return + + # Non-staff may only do this in #bot-commands + if not with_role_check(ctx, *constants.STAFF_ROLES): + if not ctx.channel.id == constants.Channels.bot_commands: + raise InWhitelistCheckFailure(constants.Channels.bot_commands) + + embed = await self.create_user_embed(ctx, user) + + await ctx.send(embed=embed) + + async def create_user_embed(self, ctx: Context, user: Member) -> Embed: + """Creates an embed containing information on the `user`.""" + created = time_since(user.created_at, max_units=3) + + # Custom status + custom_status = '' + for activity in user.activities: + # Check activity.state for None value if user has a custom status set + # This guards against a custom status with an emoji but no text, which will cause + # escape_markdown to raise an exception + # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class + if activity.name == 'Custom Status' and activity.state: + state = escape_markdown(activity.state) + custom_status = f'Status: {state}\n' + + name = str(user) + if user.nick: + name = f"{user.nick} ({name})" + + joined = time_since(user.joined_at, max_units=3) + roles = ", ".join(role.mention for role in user.roles[1:]) + + description = [ + textwrap.dedent(f""" + **User Information** + Created: {created} + Profile: {user.mention} + ID: {user.id} + {custom_status} + **Member Information** + Joined: {joined} + Roles: {roles or None} + """).strip() + ] + + # Show more verbose output in moderation channels for infractions and nominations + if ctx.channel.id in constants.MODERATION_CHANNELS: + description.append(await self.expanded_user_infraction_counts(user)) + description.append(await self.user_nomination_counts(user)) + else: + description.append(await self.basic_user_infraction_counts(user)) + + # Let's build the embed now + embed = Embed( + title=name, + description="\n\n".join(description) + ) + + embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) + embed.colour = user.top_role.colour if roles else Colour.blurple() + + return embed + + async def basic_user_infraction_counts(self, member: Member) -> str: + """Gets the total and active infraction counts for the given `member`.""" + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'hidden': 'False', + 'user__id': str(member.id) + } + ) + + total_infractions = len(infractions) + active_infractions = sum(infraction['active'] for infraction in infractions) + + infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" + + return infraction_output + + async def expanded_user_infraction_counts(self, member: Member) -> str: + """ + Gets expanded infraction counts for the given `member`. + + The counts will be split by infraction type and the number of active infractions for each type will indicated + in the output as well. + """ + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'user__id': str(member.id) + } + ) + + infraction_output = ["**Infractions**"] + if not infractions: + infraction_output.append("This user has never received an infraction.") + else: + # Count infractions split by `type` and `active` status for this user + infraction_types = set() + infraction_counter = defaultdict(int) + for infraction in infractions: + infraction_type = infraction["type"] + infraction_active = 'active' if infraction["active"] else 'inactive' + + infraction_types.add(infraction_type) + infraction_counter[f"{infraction_active} {infraction_type}"] += 1 + + # Format the output of the infraction counts + for infraction_type in sorted(infraction_types): + active_count = infraction_counter[f"active {infraction_type}"] + total_count = active_count + infraction_counter[f"inactive {infraction_type}"] + + line = f"{infraction_type.capitalize()}s: {total_count}" + if active_count: + line += f" ({active_count} active)" + + infraction_output.append(line) + + return "\n".join(infraction_output) + + async def user_nomination_counts(self, member: Member) -> str: + """Gets the active and historical nomination counts for the given `member`.""" + nominations = await self.bot.api_client.get( + 'bot/nominations', + params={ + 'user__id': str(member.id) + } + ) + + output = ["**Nominations**"] + + if not nominations: + output.append("This user has never been nominated.") + else: + count = len(nominations) + is_currently_nominated = any(nomination["active"] for nomination in nominations) + nomination_noun = "nomination" if count == 1 else "nominations" + + if is_currently_nominated: + output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") + else: + output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") + + return "\n".join(output) + + def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: + """Format a mapping to be readable to a human.""" + # sorting is technically superfluous but nice if you want to look for a specific field + fields = sorted(mapping.items(), key=lambda item: item[0]) + + if field_width is None: + field_width = len(max(mapping.keys(), key=len)) + + out = '' + + for key, val in fields: + if isinstance(val, dict): + # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries + inner_width = int(field_width * 1.6) + val = '\n' + self.format_fields(val, field_width=inner_width) + + elif isinstance(val, str): + # split up text since it might be long + text = textwrap.fill(val, width=100, replace_whitespace=False) + + # indent it, I guess you could do this with `wrap` and `join` but this is nicer + val = textwrap.indent(text, ' ' * (field_width + len(': '))) + + # the first line is already indented so we `str.lstrip` it + val = val.lstrip() + + if key == 'color': + # makes the base 10 representation of a hex number readable to humans + val = hex(val) + + out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) + + # remove trailing whitespace + return out.rstrip() + + @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) + @group(invoke_without_command=True) + @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) + async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: + """Shows information about the raw API response.""" + # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling + # doing this extra request is also much easier than trying to convert everything back into a dictionary again + raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) + + paginator = Paginator() + + def add_content(title: str, content: str) -> None: + paginator.add_line(f'== {title} ==\n') + # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. + # we hope it's not close to 2000 + paginator.add_line(content.replace('```', '`` `')) + paginator.close_page() + + if message.content: + add_content('Raw message', message.content) + + transformer = pprint.pformat if json else self.format_fields + for field_name in ('embeds', 'attachments'): + data = raw_data[field_name] + + if not data: + continue + + total = len(data) + for current, item in enumerate(data, start=1): + title = f'Raw {field_name} ({current}/{total})' + add_content(title, transformer(item)) + + for page in paginator.pages: + await ctx.send(page) + + @raw.command() + async def json(self, ctx: Context, message: Message) -> None: + """Shows information about the raw API response in a copy-pasteable Python format.""" + await ctx.invoke(self.raw, message=message, json=True) + + +def setup(bot: Bot) -> None: + """Load the Information cog.""" + bot.add_cog(Information(bot)) diff --git a/bot/cogs/info/python_news.py b/bot/cogs/info/python_news.py new file mode 100644 index 000000000..0ab5738a4 --- /dev/null +++ b/bot/cogs/info/python_news.py @@ -0,0 +1,232 @@ +import logging +import typing as t +from datetime import date, datetime + +import discord +import feedparser +from bs4 import BeautifulSoup +from discord.ext.commands import Cog +from discord.ext.tasks import loop + +from bot import constants +from bot.bot import Bot +from bot.utils.webhooks import send_webhook + +PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" + +RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" +THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" +MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" +THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" + +AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + +log = logging.getLogger(__name__) + + +class PythonNews(Cog): + """Post new PEPs and Python News to `#python-news`.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_names = {} + self.webhook: t.Optional[discord.Webhook] = None + + self.bot.loop.create_task(self.get_webhook_names()) + self.bot.loop.create_task(self.get_webhook_and_channel()) + + async def start_tasks(self) -> None: + """Start the tasks for fetching new PEPs and mailing list messages.""" + self.fetch_new_media.start() + + @loop(minutes=20) + async def fetch_new_media(self) -> None: + """Fetch new mailing list messages and then new PEPs.""" + await self.post_maillist_news() + await self.post_pep_news() + + async def sync_maillists(self) -> None: + """Sync currently in-use maillists with API.""" + # Wait until guild is available to avoid running before everything is ready + await self.bot.wait_until_guild_available() + + response = await self.bot.api_client.get("bot/bot-settings/news") + for mail in constants.PythonNews.mail_lists: + if mail not in response["data"]: + response["data"][mail] = [] + + # Because we are handling PEPs differently, we don't include it to mail lists + if "pep" not in response["data"]: + response["data"]["pep"] = [] + + await self.bot.api_client.put("bot/bot-settings/news", json=response) + + async def get_webhook_names(self) -> None: + """Get webhook author names from maillist API.""" + await self.bot.wait_until_guild_available() + + async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: + lists = await resp.json() + + for mail in lists: + if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: + self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] + + async def post_pep_news(self) -> None: + """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" + # Wait until everything is ready and http_session available + await self.bot.wait_until_guild_available() + await self.sync_maillists() + + async with self.bot.http_session.get(PEPS_RSS_URL) as resp: + data = feedparser.parse(await resp.text("utf-8")) + + news_listing = await self.bot.api_client.get("bot/bot-settings/news") + payload = news_listing.copy() + pep_numbers = news_listing["data"]["pep"] + + # Reverse entries to send oldest first + data["entries"].reverse() + for new in data["entries"]: + try: + new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") + except ValueError: + log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") + continue + pep_nr = new["title"].split(":")[0].split()[1] + if ( + pep_nr in pep_numbers + or new_datetime.date() < date.today() + ): + continue + + # Build an embed and send a webhook + embed = discord.Embed( + title=new["title"], + description=new["summary"], + timestamp=new_datetime, + url=new["link"], + colour=constants.Colours.soft_green + ) + embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) + msg = await send_webhook( + webhook=self.webhook, + username=data["feed"]["title"], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"]["pep"].append(pep_nr) + + # Increase overall PEP new stat + self.bot.stats.incr("python_news.posted.pep") + + if msg.channel.is_news(): + log.trace("Publishing PEP annnouncement because it was in a news channel") + await msg.publish() + + # Apply new sent news to DB to avoid duplicate sending + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def post_maillist_news(self) -> None: + """Send new maillist threads to #python-news that is listed in configuration.""" + await self.bot.wait_until_guild_available() + await self.sync_maillists() + existing_news = await self.bot.api_client.get("bot/bot-settings/news") + payload = existing_news.copy() + + for maillist in constants.PythonNews.mail_lists: + async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: + recents = BeautifulSoup(await resp.text(), features="lxml") + + # When a

element is present in the response then the mailing list + # has not had any activity during the current month, so therefore it + # can be ignored. + if recents.p: + continue + + for thread in recents.html.body.div.find_all("a", href=True): + # We want only these threads that have identifiers + if "latest" in thread["href"]: + continue + + thread_information, email_information = await self.get_thread_and_first_mail( + maillist, thread["href"].split("/")[-2] + ) + + try: + new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") + except ValueError: + log.warning(f"Invalid datetime from Thread email: {email_information['date']}") + continue + + if ( + thread_information["thread_id"] in existing_news["data"][maillist] + or 'Re: ' in thread_information["subject"] + or new_date.date() < date.today() + ): + continue + + content = email_information["content"] + link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) + + # Build an embed and send a message to the webhook + embed = discord.Embed( + title=thread_information["subject"], + description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, + timestamp=new_date, + url=link, + colour=constants.Colours.soft_green + ) + embed.set_author( + name=f"{email_information['sender_name']} ({email_information['sender']['address']})", + url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), + ) + embed.set_footer( + text=f"Posted to {self.webhook_names[maillist]}", + icon_url=AVATAR_URL, + ) + msg = await send_webhook( + webhook=self.webhook, + username=self.webhook_names[maillist], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"][maillist].append(thread_information["thread_id"]) + + # Increase this specific maillist counter in stats + self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") + + if msg.channel.is_news(): + log.trace("Publishing mailing list message because it was in a news channel") + await msg.publish() + + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: + """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" + async with self.bot.http_session.get( + THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) + ) as resp: + thread_information = await resp.json() + + async with self.bot.http_session.get(thread_information["starting_email"]) as resp: + email_information = await resp.json() + return thread_information, email_information + + async def get_webhook_and_channel(self) -> None: + """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" + await self.bot.wait_until_guild_available() + self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) + + await self.start_tasks() + + def cog_unload(self) -> None: + """Stop news posting tasks on cog unload.""" + self.fetch_new_media.cancel() + + +def setup(bot: Bot) -> None: + """Add `News` cog.""" + bot.add_cog(PythonNews(bot)) diff --git a/bot/cogs/info/reddit.py b/bot/cogs/info/reddit.py new file mode 100644 index 000000000..d853ab2ea --- /dev/null +++ b/bot/cogs/info/reddit.py @@ -0,0 +1,304 @@ +import asyncio +import logging +import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List + +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks +from bot.converters import Subreddit +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils.messages import sub_clyde + +log = logging.getLogger(__name__) + +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + + +class Reddit(Cog): + """Track subreddit posts and show detailed statistics about them.""" + + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 + + def __init__(self, bot: Bot): + self.bot = bot + + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() + + def cog_unload(self) -> None: + """Stop the loop task and revoke the access token when the cog is unloaded.""" + self.auto_poster_loop.cancel() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) + + async def init_reddit_ready(self) -> None: + """Sets the reddit webhook when the cog is loaded.""" + await self.bot.wait_until_guild_available() + if not self.webhook: + self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) + + @property + def channel(self) -> TextChannel: + """Get the #reddit channel object from the bot's cache.""" + return self.bot.get_channel(Channels.reddit) + + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: + """A helper method to fetch a certain amount of Reddit posts at a given route.""" + # Reddit's JSON responses only provide 25 posts at most. + if not 25 >= amount > 0: + raise ValueError("Invalid amount of subreddit posts requested.") + + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.get( + url=url, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + params=params + ) + if response.status == 200 and response.content_type == 'application/json': + # Got appropriate response - process and return. + content = await response.json() + posts = content["data"]["children"] + return posts[:amount] + + await asyncio.sleep(3) + + log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") + return list() # Failed to get appropriate response within allowed number of retries. + + async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: + """ + Get the top amount of posts for a given subreddit within a specified timeframe. + + A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top + weekly posts. + + The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. + """ + embed = Embed(description="") + + posts = await self.fetch_posts( + route=f"{subreddit}/top", + amount=amount, + params={"t": time} + ) + + if not posts: + embed.title = random.choice(ERROR_REPLIES) + embed.colour = Colour.red() + embed.description = ( + "Sorry! We couldn't find any posts from that subreddit. " + "If this problem persists, please let us know." + ) + + return embed + + for post in posts: + data = post["data"] + + text = data["selftext"] + if text: + text = textwrap.shorten(text, width=128, placeholder="...") + text += "\n" # Add newline to separate embed info + + ups = data["ups"] + comments = data["num_comments"] + author = data["author"] + + title = textwrap.shorten(data["title"], width=64, placeholder="...") + link = self.URL + data["permalink"] + + embed.description += ( + f"**[{title}]({link})**\n" + f"{text}" + f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" + ) + + embed.colour = Colour.blurple() + return embed + + @loop() + async def auto_poster_loop(self) -> None: + """Post the top 5 posts daily, and the top 5 posts weekly.""" + # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter + now = datetime.utcnow() + tomorrow = now + timedelta(days=1) + midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + seconds_until = (midnight_tomorrow - now).total_seconds() + + await asyncio.sleep(seconds_until) + + await self.bot.wait_until_guild_available() + if not self.webhook: + await self.bot.fetch_webhook(Webhooks.reddit) + + if datetime.utcnow().weekday() == 0: + await self.top_weekly_posts() + # if it's a monday send the top weekly posts + + for subreddit in RedditConfig.subreddits: + top_posts = await self.get_top_posts(subreddit=subreddit, time="day") + username = sub_clyde(f"{subreddit} Top Daily Posts") + message = await self.webhook.send(username=username, embed=top_posts, wait=True) + + if message.channel.is_news(): + await message.publish() + + async def top_weekly_posts(self) -> None: + """Post a summary of the top posts.""" + for subreddit in RedditConfig.subreddits: + # Send and pin the new weekly posts. + top_posts = await self.get_top_posts(subreddit=subreddit, time="week") + username = sub_clyde(f"{subreddit} Top Weekly Posts") + message = await self.webhook.send(wait=True, username=username, embed=top_posts) + + if subreddit.lower() == "r/python": + if not self.channel: + log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") + return + + # Remove the oldest pins so that only 12 remain at most. + pins = await self.channel.pins() + + while len(pins) >= 12: + await pins[-1].unpin() + del pins[-1] + + await message.pin() + + if message.channel.is_news(): + await message.publish() + + @group(name="reddit", invoke_without_command=True) + async def reddit_group(self, ctx: Context) -> None: + """View the top posts from various subreddits.""" + await ctx.send_help(ctx.command) + + @reddit_group.command(name="top") + async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of all time from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="all") + + await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) + + @reddit_group.command(name="daily") + async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of today from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="day") + + await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) + + @reddit_group.command(name="weekly") + async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of this week from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="week") + + await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) + + @with_role(*STAFF_ROLES) + @reddit_group.command(name="subreddits", aliases=("subs",)) + async def subreddits_command(self, ctx: Context) -> None: + """Send a paginated embed of all the subreddits we're relaying.""" + embed = Embed() + embed.title = "Relayed subreddits." + embed.colour = Colour.blurple() + + await LinePaginator.paginate( + RedditConfig.subreddits, + ctx, embed, + footer_text="Use the reddit commands along with these to view their posts.", + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return + bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/info/site.py b/bot/cogs/info/site.py new file mode 100644 index 000000000..ac29daa1d --- /dev/null +++ b/bot/cogs/info/site.py @@ -0,0 +1,146 @@ +import logging + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import URLs +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" + + +class Site(Cog): + """Commands for linking to different parts of the site.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="site", aliases=("s",), invoke_without_command=True) + async def site_group(self, ctx: Context) -> None: + """Commands for getting info about our website.""" + await ctx.send_help(ctx.command) + + @site_group.command(name="home", aliases=("about",)) + async def site_main(self, ctx: Context) -> None: + """Info about the website itself.""" + url = f"{URLs.site_schema}{URLs.site}/" + + embed = Embed(title="Python Discord website") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + f"[Our official website]({url}) is an open-source community project " + "created with Python and Django. It contains information about the server " + "itself, lets you sign up for upcoming events, has its own wiki, contains " + "a list of valuable learning resources, and much more." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="resources") + async def site_resources(self, ctx: Context) -> None: + """Info about the site's Resources page.""" + learning_url = f"{PAGES_URL}/resources" + + embed = Embed(title="Resources") + embed.set_footer(text=f"{learning_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Resources page]({learning_url}) on our website contains a " + "list of hand-selected learning resources that we regularly recommend " + f"to both beginners and experts." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="tools") + async def site_tools(self, ctx: Context) -> None: + """Info about the site's Tools page.""" + tools_url = f"{PAGES_URL}/resources/tools" + + embed = Embed(title="Tools") + embed.set_footer(text=f"{tools_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Tools page]({tools_url}) on our website contains a " + f"couple of the most popular tools for programming in Python." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="help") + async def site_help(self, ctx: Context) -> None: + """Info about the site's Getting Help page.""" + url = f"{PAGES_URL}/resources/guides/asking-good-questions" + + embed = Embed(title="Asking Good Questions") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "Asking the right question about something that's new to you can sometimes be tricky. " + f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " + "It contains everything you need to get the very best help from our community." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="faq") + async def site_faq(self, ctx: Context) -> None: + """Info about the site's FAQ page.""" + url = f"{PAGES_URL}/frequently-asked-questions" + + embed = Embed(title="FAQ") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "As the largest Python community on Discord, we get hundreds of questions every day. " + "Many of these questions have been asked before. We've compiled a list of the most " + "frequently asked questions along with their answers, which can be found on " + f"our [FAQ page]({url})." + ) + + await ctx.send(embed=embed) + + @site_group.command(aliases=['r', 'rule'], name='rules') + async def site_rules(self, ctx: Context, *rules: int) -> None: + """Provides a link to all rules or, if specified, displays specific rule(s).""" + rules_embed = Embed(title='Rules', color=Colour.blurple()) + rules_embed.url = f"{PAGES_URL}/rules" + + if not rules: + # Rules were not submitted. Return the default description. + rules_embed.description = ( + "The rules and guidelines that apply to this community can be found on" + f" our [rules page]({PAGES_URL}/rules). We expect" + " all members of the community to have read and understood these." + ) + + await ctx.send(embed=rules_embed) + return + + full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) + invalid_indices = tuple( + pick + for pick in rules + if pick < 1 or pick > len(full_rules) + ) + + if invalid_indices: + indices = ', '.join(map(str, invalid_indices)) + await ctx.send(f":x: Invalid rule indices: {indices}") + return + + for rule in rules: + self.bot.stats.incr(f"rule_uses.{rule}") + + final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) + + await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) + + +def setup(bot: Bot) -> None: + """Load the Site cog.""" + bot.add_cog(Site(bot)) diff --git a/bot/cogs/info/source.py b/bot/cogs/info/source.py new file mode 100644 index 000000000..205e0ba81 --- /dev/null +++ b/bot/cogs/info/source.py @@ -0,0 +1,141 @@ +import inspect +from pathlib import Path +from typing import Optional, Tuple, Union + +from discord import Embed +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import URLs + +SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] + + +class SourceConverter(commands.Converter): + """Convert an argument into a help command, tag, command, or cog.""" + + async def convert(self, ctx: commands.Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower().startswith("help"): + return ctx.bot.help_command + + cog = ctx.bot.get_cog(argument) + if cog: + return cog + + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd + + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + elif argument.lower() in tags_cog._cache: + return argument.lower() + + raise commands.BadArgument( + f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." + ) + + +class BotSource(commands.Cog): + """Displays information about the bot's source code.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command(name="source", aliases=("src",)) + async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: + """Display information and a GitHub link to the source code of a command, tag, or cog.""" + if not source_item: + embed = Embed(title="Bot's GitHub Repository") + embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") + embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") + await ctx.send(embed=embed) + return + + embed = await self.build_embed(source_item) + await ctx.send(embed=embed) + + def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: + """ + Build GitHub link of source item, return this link, file location and first line number. + + Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). + """ + if isinstance(source_item, commands.Command): + if source_item.cog_name == "Alias": + cmd_name = source_item.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + src = cmd.callback.__code__ + filename = src.co_filename + else: + src = source_item.callback.__code__ + filename = src.co_filename + elif isinstance(source_item, str): + tags_cog = self.bot.get_cog("Tags") + filename = tags_cog._cache[source_item]["location"] + else: + src = type(source_item) + try: + filename = inspect.getsourcefile(src) + except TypeError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + if not isinstance(source_item, str): + try: + lines, first_line_no = inspect.getsourcelines(src) + except OSError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" + else: + first_line_no = None + lines_extension = "" + + # Handle tag file location differently than others to avoid errors in some cases + if not first_line_no: + file_location = Path(filename).relative_to("/bot/") + else: + file_location = Path(filename).relative_to(Path.cwd()).as_posix() + + url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" + + return url, file_location, first_line_no or None + + async def build_embed(self, source_object: SourceType) -> Optional[Embed]: + """Build embed based on source object.""" + url, location, first_line = self.get_source_link(source_object) + + if isinstance(source_object, commands.HelpCommand): + title = "Help Command" + description = source_object.__doc__.splitlines()[1] + elif isinstance(source_object, commands.Command): + if source_object.cog_name == "Alias": + cmd_name = source_object.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + description = cmd.short_doc + else: + description = source_object.short_doc + + title = f"Command: {source_object.qualified_name}" + elif isinstance(source_object, str): + title = f"Tag: {source_object}" + description = "" + else: + title = f"Cog: {source_object.qualified_name}" + description = source_object.description.splitlines()[0] + + embed = Embed(title=title, description=description) + embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") + line_text = f":{first_line}" if first_line else "" + embed.set_footer(text=f"{location}{line_text}") + + return embed + + +def setup(bot: Bot) -> None: + """Load the BotSource cog.""" + bot.add_cog(BotSource(bot)) diff --git a/bot/cogs/info/stats.py b/bot/cogs/info/stats.py new file mode 100644 index 000000000..d42f55466 --- /dev/null +++ b/bot/cogs/info/stats.py @@ -0,0 +1,129 @@ +import string +from datetime import datetime + +from discord import Member, Message, Status +from discord.ext.commands import Cog, Context +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Categories, Channels, Guild, Stats as StatConf + + +CHANNEL_NAME_OVERRIDES = { + Channels.off_topic_0: "off_topic_0", + Channels.off_topic_1: "off_topic_1", + Channels.off_topic_2: "off_topic_2", + Channels.staff_lounge: "staff_lounge" +} + +ALLOWED_CHARS = string.ascii_letters + string.digits + "_" + + +class Stats(Cog): + """A cog which provides a way to hook onto Discord events and forward to stats.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.last_presence_update = None + self.update_guild_boost.start() + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Report message events in the server to statsd.""" + if message.guild is None: + return + + if message.guild.id != Guild.id: + return + + cat = getattr(message.channel, "category", None) + if cat is not None and cat.id == Categories.modmail: + if message.channel.id != Channels.incidents: + # Do not report modmail channels to stats, there are too many + # of them for interesting statistics to be drawn out of this. + return + + reformatted_name = message.channel.name.replace('-', '_') + + if CHANNEL_NAME_OVERRIDES.get(message.channel.id): + reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) + + reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) + + stat_name = f"channels.{reformatted_name}" + self.bot.stats.incr(stat_name) + + # Increment the total message count + self.bot.stats.incr("messages") + + @Cog.listener() + async def on_command_completion(self, ctx: Context) -> None: + """Report completed commands to statsd.""" + command_name = ctx.command.qualified_name.replace(" ", "_") + + self.bot.stats.incr(f"commands.{command_name}") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Update member count stat on member join.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_leave(self, member: Member) -> None: + """Update member count stat on member leave.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_update(self, _before: Member, after: Member) -> None: + """Update presence estimates on member update.""" + if after.guild.id != Guild.id: + return + + if self.last_presence_update: + if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: + return + + self.last_presence_update = datetime.now() + + online = 0 + idle = 0 + dnd = 0 + offline = 0 + + for member in after.guild.members: + if member.status is Status.online: + online += 1 + elif member.status is Status.dnd: + dnd += 1 + elif member.status is Status.idle: + idle += 1 + elif member.status is Status.offline: + offline += 1 + + self.bot.stats.gauge("guild.status.online", online) + self.bot.stats.gauge("guild.status.idle", idle) + self.bot.stats.gauge("guild.status.do_not_disturb", dnd) + self.bot.stats.gauge("guild.status.offline", offline) + + @loop(hours=1) + async def update_guild_boost(self) -> None: + """Post the server boost level and tier every hour.""" + await self.bot.wait_until_guild_available() + g = self.bot.get_guild(Guild.id) + self.bot.stats.gauge("boost.amount", g.premium_subscription_count) + self.bot.stats.gauge("boost.tier", g.premium_tier) + + def cog_unload(self) -> None: + """Stop the boost statistic task on unload of the Cog.""" + self.update_guild_boost.stop() + + +def setup(bot: Bot) -> None: + """Load the stats cog.""" + bot.add_cog(Stats(bot)) diff --git a/bot/cogs/info/tags.py b/bot/cogs/info/tags.py new file mode 100644 index 000000000..3d76c5c08 --- /dev/null +++ b/bot/cogs/info/tags.py @@ -0,0 +1,277 @@ +import logging +import re +import time +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot import constants +from bot.bot import Bot +from bot.converters import TagNameConverter +from bot.pagination import LinePaginator +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +TEST_CHANNELS = ( + constants.Channels.bot_commands, + constants.Channels.helpers +) + +REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) +FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." + + +class Tags(Cog): + """Save new tags and fetch existing tags.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.tag_cooldowns = {} + self._cache = self.get_tags() + + @staticmethod + def get_tags() -> dict: + """Get all tags.""" + cache = {} + + base_path = Path("bot", "resources", "tags") + for file in base_path.glob("**/*"): + if file.is_file(): + tag_title = file.stem + tag = { + "title": tag_title, + "embed": { + "description": file.read_text(encoding="utf8"), + }, + "restricted_to": "developers", + "location": f"/bot/{file}" + } + + # Convert to a list to allow negative indexing. + parents = list(file.relative_to(base_path).parents) + if len(parents) > 1: + # -1 would be '.' hence -2 is used as the index. + tag["restricted_to"] = parents[-2].name + + cache[tag_title] = tag + + return cache + + @staticmethod + def check_accessibility(user: Member, tag: dict) -> bool: + """Check if user can access a tag.""" + return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] + + @staticmethod + def _fuzzy_search(search: str, target: str) -> float: + """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" + current, index = 0, 0 + _search = REGEX_NON_ALPHABET.sub('', search.lower()) + _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + _target = next(_targets) + try: + while True: + while index < len(_target) and _search[current] == _target[index]: + current += 1 + index += 1 + index, _target = 0, next(_targets) + except (StopIteration, IndexError): + pass + return current / len(_search) * 100 + + def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: + """Return a list of suggested tags.""" + scores: Dict[str, int] = { + tag_title: Tags._fuzzy_search(tag_name, tag['title']) + for tag_title, tag in self._cache.items() + } + + thresholds = thresholds or [100, 90, 80, 70, 60] + + for threshold in thresholds: + suggestions = [ + self._cache[tag_title] + for tag_title, matching_score in scores.items() + if matching_score >= threshold + ] + if suggestions: + return suggestions + + return [] + + def _get_tag(self, tag_name: str) -> list: + """Get a specific tag.""" + found = [self._cache.get(tag_name.lower(), None)] + if not found[0]: + return self._get_suggestions(tag_name) + return found + + def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: + """ + Search for tags via contents. + + `predicate` will be the built-in any, all, or a custom callable. Must return a bool. + """ + keywords_processed: List[str] = [] + for keyword in keywords.split(','): + keyword_sanitized = keyword.strip().casefold() + if not keyword_sanitized: + # this happens when there are leading / trailing / consecutive comma. + continue + keywords_processed.append(keyword_sanitized) + + if not keywords_processed: + # after sanitizing, we can end up with an empty list, for example when keywords is ',' + # in that case, we simply want to search for such keywords directly instead. + keywords_processed = [keywords] + + matching_tags = [] + for tag in self._cache.values(): + matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) + if self.check_accessibility(user, tag) and check(matches): + matching_tags.append(tag) + + return matching_tags + + async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: + """Send the result of matching tags to user.""" + if not matching_tags: + pass + elif len(matching_tags) == 1: + await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) + else: + is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 + embed = Embed( + title=f"Here are the tags containing the given keyword{'s' * is_plural}:", + description='\n'.join(tag['title'] for tag in matching_tags[:10]) + ) + await LinePaginator.paginate( + sorted(f"**»** {tag['title']}" for tag in matching_tags), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) + async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Show all known tags, a single tag, or run a subcommand.""" + await ctx.invoke(self.get_command, tag_name=tag_name) + + @tags_group.group(name='search', invoke_without_command=True) + async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Only search for tags that has ALL the keywords. + """ + matching_tags = self._get_tags_via_content(all, keywords, ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @search_tag_content.command(name='any') + async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Search for tags that has ANY of the keywords. + """ + matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @tags_group.command(name='get', aliases=('show', 'g')) + async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Get a specified tag, or a list of all tags if no tag is specified.""" + + def _command_on_cooldown(tag_name: str) -> bool: + """ + Check if the command is currently on cooldown, on a per-tag, per-channel basis. + + The cooldown duration is set in constants.py. + """ + now = time.time() + + cooldown_conditions = ( + tag_name + and tag_name in self.tag_cooldowns + and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags + and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id + ) + + if cooldown_conditions: + return True + return False + + if _command_on_cooldown(tag_name): + time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] + time_left = constants.Cooldowns.tags - time_elapsed + log.info( + f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " + f"Cooldown ends in {time_left:.1f} seconds." + ) + return + + if tag_name is not None: + temp_founds = self._get_tag(tag_name) + + founds = [] + + for found_tag in temp_founds: + if self.check_accessibility(ctx.author, found_tag): + founds.append(found_tag) + + if len(founds) == 1: + tag = founds[0] + if ctx.channel.id not in TEST_CHANNELS: + self.tag_cooldowns[tag_name] = { + "time": time.time(), + "channel": ctx.channel.id + } + + self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") + + await wait_for_deletion( + await ctx.send(embed=Embed.from_dict(tag['embed'])), + [ctx.author.id], + client=self.bot + ) + elif founds and len(tag_name) >= 3: + await wait_for_deletion( + await ctx.send( + embed=Embed( + title='Did you mean ...', + description='\n'.join(tag['title'] for tag in founds[:10]) + ) + ), + [ctx.author.id], + client=self.bot + ) + + else: + tags = self._cache.values() + if not tags: + await ctx.send(embed=Embed( + description="**There are no tags in the database!**", + colour=Colour.red() + )) + else: + embed: Embed = Embed(title="**Current tags**") + await LinePaginator.paginate( + sorted( + f"**»** {tag['title']}" for tag in tags + if self.check_accessibility(ctx.author, tag) + ), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Tags cog.""" + bot.add_cog(Tags(bot)) diff --git a/bot/cogs/info/wolfram.py b/bot/cogs/info/wolfram.py new file mode 100644 index 000000000..e6cae3bb8 --- /dev/null +++ b/bot/cogs/info/wolfram.py @@ -0,0 +1,280 @@ +import logging +from io import BytesIO +from typing import Callable, List, Optional, Tuple +from urllib import parse + +import discord +from dateutil.relativedelta import relativedelta +from discord import Embed +from discord.ext import commands +from discord.ext.commands import BucketType, Cog, Context, check, group + +from bot.bot import Bot +from bot.constants import Colours, STAFF_ROLES, Wolfram +from bot.pagination import ImagePaginator +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +APPID = Wolfram.key +DEFAULT_OUTPUT_FORMAT = "JSON" +QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" +WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" + +MAX_PODS = 20 + +# Allows for 10 wolfram calls pr user pr day +usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) + +# Allows for max api requests / days in month per day for the entire guild (Temporary) +guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) + + +async def send_embed( + ctx: Context, + message_txt: str, + colour: int = Colours.soft_red, + footer: str = None, + img_url: str = None, + f: discord.File = None +) -> None: + """Generate & send a response embed with Wolfram as the author.""" + embed = Embed(colour=colour) + embed.description = message_txt + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + if footer: + embed.set_footer(text=footer) + + if img_url: + embed.set_image(url=img_url) + + await ctx.send(embed=embed, file=f) + + +def custom_cooldown(*ignore: List[int]) -> Callable: + """ + Implement per-user and per-guild cooldowns for requests to the Wolfram API. + + A list of roles may be provided to ignore the per-user cooldown + """ + async def predicate(ctx: Context) -> bool: + if ctx.invoked_with == 'help': + # if the invoked command is help we don't want to increase the ratelimits since it's not actually + # invoking the command/making a request, so instead just check if the user/guild are on cooldown. + guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown + if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored + return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 + return guild_cooldown + + user_bucket = usercd.get_bucket(ctx.message) + + if all(role.id not in ignore for role in ctx.author.roles): + user_rate = user_bucket.update_rate_limit() + + if user_rate: + # Can't use api; cause: member limit + delta = relativedelta(seconds=int(user_rate)) + cooldown = humanize_delta(delta) + message = ( + "You've used up your limit for Wolfram|Alpha requests.\n" + f"Cooldown: {cooldown}" + ) + await send_embed(ctx, message) + return False + + guild_bucket = guildcd.get_bucket(ctx.message) + guild_rate = guild_bucket.update_rate_limit() + + # Repr has a token attribute to read requests left + log.debug(guild_bucket) + + if guild_rate: + # Can't use api; cause: guild limit + message = ( + "The max limit of requests for the server has been reached for today.\n" + f"Cooldown: {int(guild_rate)}" + ) + await send_embed(ctx, message) + return False + + return True + return check(predicate) + + +async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: + """Get the Wolfram API pod pages for the provided query.""" + async with ctx.channel.typing(): + url_str = parse.urlencode({ + "input": query, + "appid": APPID, + "output": DEFAULT_OUTPUT_FORMAT, + "format": "image,plaintext" + }) + request_url = QUERY.format(request="query", data=url_str) + + async with bot.http_session.get(request_url) as response: + json = await response.json(content_type='text/plain') + + result = json["queryresult"] + + if result["error"]: + # API key not set up correctly + if result["error"]["msg"] == "Invalid appid": + message = "Wolfram API key is invalid or missing." + log.warning( + "API key seems to be missing, or invalid when " + f"processing a wolfram request: {url_str}, Response: {json}" + ) + await send_embed(ctx, message) + return + + message = "Something went wrong internally with your request, please notify staff!" + log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") + await send_embed(ctx, message) + return + + if not result["success"]: + message = f"I couldn't find anything for {query}." + await send_embed(ctx, message) + return + + if not result["numpods"]: + message = "Could not find any results." + await send_embed(ctx, message) + return + + pods = result["pods"] + pages = [] + for pod in pods[:MAX_PODS]: + subs = pod.get("subpods") + + for sub in subs: + title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") + img = sub["img"]["src"] + pages.append((title, img)) + return pages + + +class Wolfram(Cog): + """Commands for interacting with the Wolfram|Alpha API.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_command(self, ctx: Context, *, query: str) -> None: + """Requests all answers on a single image, sends an image of all related pods.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="simple", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + image_bytes = await response.read() + + f = discord.File(BytesIO(image_bytes), filename="image.png") + image_url = "attachment://image.png" + + if status == 501: + message = "Failed to get response" + footer = "" + color = Colours.soft_red + elif status == 400: + message = "No input found" + footer = "" + color = Colours.soft_red + elif status == 403: + message = "Wolfram API key is invalid or missing." + footer = "" + color = Colours.soft_red + else: + message = "" + footer = "View original for a bigger picture." + color = Colours.soft_orange + + # Sends a "blank" embed if no request is received, unsure how to fix + await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) + + @wolfram_command.command(name="page", aliases=("pa", "p")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + embed = Embed() + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + embed.colour = Colours.soft_orange + + await ImagePaginator.paginate(pages, ctx, embed) + + @wolfram_command.command(name="cut", aliases=("c",)) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + if len(pages) >= 2: + page = pages[1] + else: + page = pages[0] + + await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) + + @wolfram_command.command(name="short", aliases=("sh", "s")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: + """Requests an answer to a simple question.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="result", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + response_text = await response.text() + + if status == 501: + message = "Failed to get response" + color = Colours.soft_red + elif status == 400: + message = "No input found" + color = Colours.soft_red + elif response_text == "Error 1: Invalid appid": + message = "Wolfram API key is invalid or missing." + color = Colours.soft_red + else: + message = response_text + color = Colours.soft_orange + + await send_embed(ctx, message, color) + + +def setup(bot: Bot) -> None: + """Load the Wolfram cog.""" + bot.add_cog(Wolfram(bot)) diff --git a/bot/cogs/information.py b/bot/cogs/information.py deleted file mode 100644 index 8982196d1..000000000 --- a/bot/cogs/information.py +++ /dev/null @@ -1,422 +0,0 @@ -import colorsys -import logging -import pprint -import textwrap -from collections import Counter, defaultdict -from string import Template -from typing import Any, Mapping, Optional, Union - -from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils -from discord.abc import GuildChannel -from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group -from discord.utils import escape_markdown - -from bot import constants -from bot.bot import Bot -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - - -class Information(Cog): - """A cog with commands for generating embeds with server info, such as server stats and user info.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @staticmethod - def role_can_read(channel: GuildChannel, role: Role) -> bool: - """Return True if `role` can read messages in `channel`.""" - overwrites = channel.overwrites_for(role) - return overwrites.read_messages is True - - def get_staff_channel_count(self, guild: Guild) -> int: - """ - Get the number of channels that are staff-only. - - We need to know two things about a channel: - - Does the @everyone role have explicit read deny permissions? - - Do staff roles have explicit read allow permissions? - - If the answer to both of these questions is yes, it's a staff channel. - """ - channel_ids = set() - for channel in guild.channels: - if channel.type is ChannelType.category: - continue - - everyone_can_read = self.role_can_read(channel, guild.default_role) - - for role in constants.STAFF_ROLES: - role_can_read = self.role_can_read(channel, guild.get_role(role)) - if role_can_read and not everyone_can_read: - channel_ids.add(channel.id) - break - - return len(channel_ids) - - @staticmethod - def get_channel_type_counts(guild: Guild) -> str: - """Return the total amounts of the various types of channels in `guild`.""" - channel_counter = Counter(c.type for c in guild.channels) - channel_type_list = [] - for channel, count in channel_counter.items(): - channel_type = str(channel).title() - channel_type_list.append(f"{channel_type} channels: {count}") - - channel_type_list = sorted(channel_type_list) - return "\n".join(channel_type_list) - - @with_role(*constants.MODERATION_ROLES) - @command(name="roles") - async def roles_info(self, ctx: Context) -> None: - """Returns a list of all roles and their corresponding IDs.""" - # Sort the roles alphabetically and remove the @everyone role - roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) - - # Build a list - role_list = [] - for role in roles: - role_list.append(f"`{role.id}` - {role.mention}") - - # Build an embed - embed = Embed( - title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", - colour=Colour.blurple() - ) - - await LinePaginator.paginate(role_list, ctx, embed, empty=False) - - @with_role(*constants.MODERATION_ROLES) - @command(name="role") - async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: - """ - Return information on a role or list of roles. - - To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. - """ - parsed_roles = [] - failed_roles = [] - - for role_name in roles: - if isinstance(role_name, Role): - # Role conversion has already succeeded - parsed_roles.append(role_name) - continue - - role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) - - if not role: - failed_roles.append(role_name) - continue - - parsed_roles.append(role) - - if failed_roles: - await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") - - for role in parsed_roles: - h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) - - embed = Embed( - title=f"{role.name} info", - colour=role.colour, - ) - embed.add_field(name="ID", value=role.id, inline=True) - embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) - embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) - embed.add_field(name="Member count", value=len(role.members), inline=True) - embed.add_field(name="Position", value=role.position) - embed.add_field(name="Permission code", value=role.permissions.value, inline=True) - - await ctx.send(embed=embed) - - @command(name="server", aliases=["server_info", "guild", "guild_info"]) - async def server_info(self, ctx: Context) -> None: - """Returns an embed full of server information.""" - created = time_since(ctx.guild.created_at, precision="days") - features = ", ".join(ctx.guild.features) - region = ctx.guild.region - - roles = len(ctx.guild.roles) - member_count = ctx.guild.member_count - channel_counts = self.get_channel_type_counts(ctx.guild) - - # How many of each user status? - statuses = Counter(member.status for member in ctx.guild.members) - embed = Embed(colour=Colour.blurple()) - - # How many staff members and staff channels do we have? - staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) - staff_channel_count = self.get_staff_channel_count(ctx.guild) - - # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the - # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting - # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts - # after the dedent is made. - embed.description = Template( - textwrap.dedent(f""" - **Server information** - Created: {created} - Voice region: {region} - Features: {features} - - **Channel counts** - $channel_counts - Staff channels: {staff_channel_count} - - **Member counts** - Members: {member_count:,} - Staff members: {staff_member_count} - Roles: {roles} - - **Member statuses** - {constants.Emojis.status_online} {statuses[Status.online]:,} - {constants.Emojis.status_idle} {statuses[Status.idle]:,} - {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} - {constants.Emojis.status_offline} {statuses[Status.offline]:,} - """) - ).substitute({"channel_counts": channel_counts}) - embed.set_thumbnail(url=ctx.guild.icon_url) - - await ctx.send(embed=embed) - - @command(name="user", aliases=["user_info", "member", "member_info"]) - async def user_info(self, ctx: Context, user: Member = None) -> None: - """Returns info about a user.""" - if user is None: - user = ctx.author - - # Do a role check if this is being executed on someone other than the caller - elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): - await ctx.send("You may not use this command on users other than yourself.") - return - - # Non-staff may only do this in #bot-commands - if not with_role_check(ctx, *constants.STAFF_ROLES): - if not ctx.channel.id == constants.Channels.bot_commands: - raise InWhitelistCheckFailure(constants.Channels.bot_commands) - - embed = await self.create_user_embed(ctx, user) - - await ctx.send(embed=embed) - - async def create_user_embed(self, ctx: Context, user: Member) -> Embed: - """Creates an embed containing information on the `user`.""" - created = time_since(user.created_at, max_units=3) - - # Custom status - custom_status = '' - for activity in user.activities: - # Check activity.state for None value if user has a custom status set - # This guards against a custom status with an emoji but no text, which will cause - # escape_markdown to raise an exception - # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class - if activity.name == 'Custom Status' and activity.state: - state = escape_markdown(activity.state) - custom_status = f'Status: {state}\n' - - name = str(user) - if user.nick: - name = f"{user.nick} ({name})" - - joined = time_since(user.joined_at, max_units=3) - roles = ", ".join(role.mention for role in user.roles[1:]) - - description = [ - textwrap.dedent(f""" - **User Information** - Created: {created} - Profile: {user.mention} - ID: {user.id} - {custom_status} - **Member Information** - Joined: {joined} - Roles: {roles or None} - """).strip() - ] - - # Show more verbose output in moderation channels for infractions and nominations - if ctx.channel.id in constants.MODERATION_CHANNELS: - description.append(await self.expanded_user_infraction_counts(user)) - description.append(await self.user_nomination_counts(user)) - else: - description.append(await self.basic_user_infraction_counts(user)) - - # Let's build the embed now - embed = Embed( - title=name, - description="\n\n".join(description) - ) - - embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) - embed.colour = user.top_role.colour if roles else Colour.blurple() - - return embed - - async def basic_user_infraction_counts(self, member: Member) -> str: - """Gets the total and active infraction counts for the given `member`.""" - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'hidden': 'False', - 'user__id': str(member.id) - } - ) - - total_infractions = len(infractions) - active_infractions = sum(infraction['active'] for infraction in infractions) - - infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" - - return infraction_output - - async def expanded_user_infraction_counts(self, member: Member) -> str: - """ - Gets expanded infraction counts for the given `member`. - - The counts will be split by infraction type and the number of active infractions for each type will indicated - in the output as well. - """ - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'user__id': str(member.id) - } - ) - - infraction_output = ["**Infractions**"] - if not infractions: - infraction_output.append("This user has never received an infraction.") - else: - # Count infractions split by `type` and `active` status for this user - infraction_types = set() - infraction_counter = defaultdict(int) - for infraction in infractions: - infraction_type = infraction["type"] - infraction_active = 'active' if infraction["active"] else 'inactive' - - infraction_types.add(infraction_type) - infraction_counter[f"{infraction_active} {infraction_type}"] += 1 - - # Format the output of the infraction counts - for infraction_type in sorted(infraction_types): - active_count = infraction_counter[f"active {infraction_type}"] - total_count = active_count + infraction_counter[f"inactive {infraction_type}"] - - line = f"{infraction_type.capitalize()}s: {total_count}" - if active_count: - line += f" ({active_count} active)" - - infraction_output.append(line) - - return "\n".join(infraction_output) - - async def user_nomination_counts(self, member: Member) -> str: - """Gets the active and historical nomination counts for the given `member`.""" - nominations = await self.bot.api_client.get( - 'bot/nominations', - params={ - 'user__id': str(member.id) - } - ) - - output = ["**Nominations**"] - - if not nominations: - output.append("This user has never been nominated.") - else: - count = len(nominations) - is_currently_nominated = any(nomination["active"] for nomination in nominations) - nomination_noun = "nomination" if count == 1 else "nominations" - - if is_currently_nominated: - output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") - else: - output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") - - return "\n".join(output) - - def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: - """Format a mapping to be readable to a human.""" - # sorting is technically superfluous but nice if you want to look for a specific field - fields = sorted(mapping.items(), key=lambda item: item[0]) - - if field_width is None: - field_width = len(max(mapping.keys(), key=len)) - - out = '' - - for key, val in fields: - if isinstance(val, dict): - # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries - inner_width = int(field_width * 1.6) - val = '\n' + self.format_fields(val, field_width=inner_width) - - elif isinstance(val, str): - # split up text since it might be long - text = textwrap.fill(val, width=100, replace_whitespace=False) - - # indent it, I guess you could do this with `wrap` and `join` but this is nicer - val = textwrap.indent(text, ' ' * (field_width + len(': '))) - - # the first line is already indented so we `str.lstrip` it - val = val.lstrip() - - if key == 'color': - # makes the base 10 representation of a hex number readable to humans - val = hex(val) - - out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) - - # remove trailing whitespace - return out.rstrip() - - @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) - @group(invoke_without_command=True) - @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) - async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: - """Shows information about the raw API response.""" - # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling - # doing this extra request is also much easier than trying to convert everything back into a dictionary again - raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) - - paginator = Paginator() - - def add_content(title: str, content: str) -> None: - paginator.add_line(f'== {title} ==\n') - # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. - # we hope it's not close to 2000 - paginator.add_line(content.replace('```', '`` `')) - paginator.close_page() - - if message.content: - add_content('Raw message', message.content) - - transformer = pprint.pformat if json else self.format_fields - for field_name in ('embeds', 'attachments'): - data = raw_data[field_name] - - if not data: - continue - - total = len(data) - for current, item in enumerate(data, start=1): - title = f'Raw {field_name} ({current}/{total})' - add_content(title, transformer(item)) - - for page in paginator.pages: - await ctx.send(page) - - @raw.command() - async def json(self, ctx: Context, message: Message) -> None: - """Shows information about the raw API response in a copy-pasteable Python format.""" - await ctx.invoke(self.raw, message=message, json=True) - - -def setup(bot: Bot) -> None: - """Load the Information cog.""" - bot.add_cog(Information(bot)) diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py deleted file mode 100644 index b3102db2f..000000000 --- a/bot/cogs/jams.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import typing as t - -from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role -from discord.ext import commands -from more_itertools import unique_everseen - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -MAX_CHANNELS = 50 -CATEGORY_NAME = "Code Jam" - - -class CodeJams(commands.Cog): - """Manages the code-jam related parts of our server.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command() - @with_role(Roles.admins) - async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: - """ - Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. - - The first user passed will always be the team leader. - """ - # Ignore duplicate members - members = list(unique_everseen(members)) - - # We had a little issue during Code Jam 4 here, the greedy converter did it's job - # and ignored anything which wasn't a valid argument which left us with teams of - # two members or at some times even 1 member. This fixes that by checking that there - # are always 3 members in the members list. - if len(members) < 3: - await ctx.send( - ":no_entry_sign: One of your arguments was invalid\n" - f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" - " members" - ) - return - - team_channel = await self.create_channels(ctx.guild, team_name, members) - await self.add_roles(ctx.guild, members) - - await ctx.send( - f":ok_hand: Team created: {team_channel}\n" - f"**Team Leader:** {members[0].mention}\n" - f"**Team Members:** {' '.join(member.mention for member in members[1:])}" - ) - - async def get_category(self, guild: Guild) -> CategoryChannel: - """ - Return a code jam category. - - If all categories are full or none exist, create a new category. - """ - for category in guild.categories: - # Need 2 available spaces: one for the text channel and one for voice. - if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: - return category - - return await self.create_category(guild) - - @staticmethod - async def create_category(guild: Guild) -> CategoryChannel: - """Create a new code jam category and return it.""" - log.info("Creating a new code jam category.") - - category_overwrites = { - guild.default_role: PermissionOverwrite(read_messages=False), - guild.me: PermissionOverwrite(read_messages=True) - } - - return await guild.create_category_channel( - CATEGORY_NAME, - overwrites=category_overwrites, - reason="It's code jam time!" - ) - - @staticmethod - def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: - """Get code jam team channels permission overwrites.""" - # First member is always the team leader - team_channel_overwrites = { - members[0]: PermissionOverwrite( - manage_messages=True, - read_messages=True, - manage_webhooks=True, - connect=True - ), - guild.default_role: PermissionOverwrite(read_messages=False, connect=False), - guild.get_role(Roles.verified): PermissionOverwrite( - read_messages=False, - connect=False - ) - } - - # Rest of members should just have read_messages - for member in members[1:]: - team_channel_overwrites[member] = PermissionOverwrite( - read_messages=True, - connect=True - ) - - return team_channel_overwrites - - async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: - """Create team text and voice channels. Return the mention for the text channel.""" - # Get permission overwrites and category - team_channel_overwrites = self.get_overwrites(members, guild) - code_jam_category = await self.get_category(guild) - - # Create a text channel for the team - team_channel = await guild.create_text_channel( - team_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - # Create a voice channel for the team - team_voice_name = " ".join(team_name.split("-")).title() - - await guild.create_voice_channel( - team_voice_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - return team_channel.mention - - @staticmethod - async def add_roles(guild: Guild, members: t.List[Member]) -> None: - """Assign team leader and jammer roles.""" - # Assign team leader role - await members[0].add_roles(guild.get_role(Roles.team_leaders)) - - # Assign rest of roles - jammer_role = guild.get_role(Roles.jammers) - for member in members: - await member.add_roles(jammer_role) - - -def setup(bot: Bot) -> None: - """Load the CodeJams cog.""" - bot.add_cog(CodeJams(bot)) diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py deleted file mode 100644 index 94fa2b139..000000000 --- a/bot/cogs/logging.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging - -from discord import Embed -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, DEBUG_MODE - - -log = logging.getLogger(__name__) - - -class Logging(Cog): - """Debug logging module.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.bot.loop.create_task(self.startup_greeting()) - - async def startup_greeting(self) -> None: - """Announce our presence to the configured devlog channel.""" - await self.bot.wait_until_guild_available() - log.info("Bot connected!") - - embed = Embed(description="Connected!") - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=( - "https://raw.githubusercontent.com/" - "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" - ) - ) - - if not DEBUG_MODE: - await self.bot.get_channel(Channels.dev_log).send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the Logging cog.""" - bot.add_cog(Logging(bot)) diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 995187ef0..aad1f3c26 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,11 +1,11 @@ from bot.bot import Bot from .incidents import Incidents -from .infractions import Infractions -from .management import ModManagement +from .infraction.infractions import Infractions +from .infraction.management import ModManagement +from .infraction.superstarify import Superstarify from .modlog import ModLog from .silence import Silence from .slowmode import Slowmode -from .superstarify import Superstarify def setup(bot: Bot) -> None: diff --git a/bot/cogs/moderation/defcon.py b/bot/cogs/moderation/defcon.py new file mode 100644 index 000000000..4c0ad5914 --- /dev/null +++ b/bot/cogs/moderation/defcon.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import logging +from collections import namedtuple +from datetime import datetime, timedelta +from enum import Enum + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles +from bot.decorators import with_role + +log = logging.getLogger(__name__) + +REJECTION_MESSAGE = """ +Hi, {user} - Thanks for your interest in our server! + +Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since +your account is relatively new, we're unable to provide access to the server at this time. + +Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation +will be resolved soon. In the meantime, please feel free to peruse the resources on our site at +, and have a nice day! +""" + +BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" + + +class Action(Enum): + """Defcon Action.""" + + ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) + + ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") + DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") + UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") + + +class Defcon(Cog): + """Time-sensitive server defense mechanisms.""" + + days = None # type: timedelta + enabled = False # type: bool + + def __init__(self, bot: Bot): + self.bot = bot + self.channel = None + self.days = timedelta(days=0) + + self.bot.loop.create_task(self.sync_settings()) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def sync_settings(self) -> None: + """On cog load, try to synchronize DEFCON settings to the API.""" + await self.bot.wait_until_guild_available() + self.channel = await self.bot.fetch_channel(Channels.defcon) + + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + except Exception: # Yikes! + log.exception("Unable to get DEFCON settings!") + await self.bot.get_channel(Channels.dev_log).send( + f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" + ) + + else: + if data["enabled"]: + self.enabled = True + self.days = timedelta(days=data["days"]) + log.info(f"DEFCON enabled: {self.days.days} days") + + else: + self.enabled = False + self.days = timedelta(days=0) + log.info("DEFCON disabled") + + await self.update_channel_topic() + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" + if self.enabled and self.days.days > 0: + now = datetime.utcnow() + + if now - member.created_at < self.days: + log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") + + message_sent = False + + try: + await member.send(REJECTION_MESSAGE.format(user=member.mention)) + + message_sent = True + except Exception: + log.exception(f"Unable to send rejection message to user: {member}") + + await member.kick(reason="DEFCON active, user is too new") + self.bot.stats.incr("defcon.leaves") + + message = ( + f"{member} (`{member.id}`) was denied entry because their account is too new." + ) + + if not message_sent: + message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." + + await self.mod_log.send_log_message( + Icons.defcon_denied, Colours.soft_red, "Entry denied", + message, member.avatar_url_as(static_format="png") + ) + + @group(name='defcon', aliases=('dc',), invoke_without_command=True) + @with_role(Roles.admins, Roles.owners) + async def defcon_group(self, ctx: Context) -> None: + """Check the DEFCON status or run a subcommand.""" + await ctx.send_help(ctx.command) + + async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: + """Providing a structured way to do an defcon action.""" + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + if "enable_date" in data and action is Action.DISABLED: + enabled = datetime.fromisoformat(data["enable_date"]) + + delta = datetime.now() - enabled + + self.bot.stats.timing("defcon.enabled", delta) + except Exception: + pass + + error = None + try: + await self.bot.api_client.put( + 'bot/bot-settings/defcon', + json={ + 'name': 'defcon', + 'data': { + # TODO: retrieve old days count + 'days': days, + 'enabled': action is not Action.DISABLED, + 'enable_date': datetime.now().isoformat() + } + } + ) + except Exception as err: + log.exception("Unable to update DEFCON settings.") + error = err + finally: + await ctx.send(self.build_defcon_msg(action, error)) + await self.send_defcon_log(action, ctx.author, error) + + self.bot.stats.gauge("defcon.threshold", days) + + @defcon_group.command(name='enable', aliases=('on', 'e')) + @with_role(Roles.admins, Roles.owners) + async def enable_command(self, ctx: Context) -> None: + """ + Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! + + Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, + in days. + """ + self.enabled = True + await self._defcon_action(ctx, days=0, action=Action.ENABLED) + await self.update_channel_topic() + + @defcon_group.command(name='disable', aliases=('off', 'd')) + @with_role(Roles.admins, Roles.owners) + async def disable_command(self, ctx: Context) -> None: + """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" + self.enabled = False + await self._defcon_action(ctx, days=0, action=Action.DISABLED) + await self.update_channel_topic() + + @defcon_group.command(name='status', aliases=('s',)) + @with_role(Roles.admins, Roles.owners) + async def status_command(self, ctx: Context) -> None: + """Check the current status of DEFCON mode.""" + embed = Embed( + colour=Colour.blurple(), title="DEFCON Status", + description=f"**Enabled:** {self.enabled}\n" + f"**Days:** {self.days.days}" + ) + + await ctx.send(embed=embed) + + @defcon_group.command(name='days') + @with_role(Roles.admins, Roles.owners) + async def days_command(self, ctx: Context, days: int) -> None: + """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" + self.days = timedelta(days=days) + self.enabled = True + await self._defcon_action(ctx, days=days, action=Action.UPDATED) + await self.update_channel_topic() + + async def update_channel_topic(self) -> None: + """Update the #defcon channel topic with the current DEFCON status.""" + if self.enabled: + day_str = "days" if self.days.days > 1 else "day" + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" + else: + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" + + self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) + await self.channel.edit(topic=new_topic) + + def build_defcon_msg(self, action: Action, e: Exception = None) -> str: + """Build in-channel response string for DEFCON action.""" + if action is Action.ENABLED: + msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" + elif action is Action.DISABLED: + msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" + elif action is Action.UPDATED: + msg = ( + f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " + f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" + ) + + if e: + msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + return msg + + async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: + """Send log message for DEFCON action.""" + info = action.value + log_msg: str = ( + f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" + f"{info.template.format(days=self.days.days)}" + ) + status_msg = f"DEFCON {action.name.lower()}" + + if e: + log_msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) + + +def setup(bot: Bot) -> None: + """Load the Defcon cog.""" + bot.add_cog(Defcon(bot)) diff --git a/bot/cogs/moderation/infraction/__init__.py b/bot/cogs/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/moderation/infraction/infractions.py b/bot/cogs/moderation/infraction/infractions.py new file mode 100644 index 000000000..8df642428 --- /dev/null +++ b/bot/cogs/moderation/infraction/infractions.py @@ -0,0 +1,370 @@ +import logging +import textwrap +import typing as t + +import discord +from discord import Member +from discord.ext import commands +from discord.ext.commands import Context, command + +from bot import constants +from bot.bot import Bot +from bot.constants import Event +from bot.converters import Expiry, FetchedMember +from bot.decorators import respect_role_hierarchy +from bot.utils.checks import with_role_check +from . import utils +from .scheduler import InfractionScheduler +from .utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class Infractions(InfractionScheduler, commands.Cog): + """Apply and pardon infractions on users for moderation purposes.""" + + category = "Moderation" + category_description = "Server moderation tools." + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) + + self.category = "Moderation" + self._muted_role = discord.Object(constants.Roles.muted) + + @commands.Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active mute infractions for returning members.""" + active_mutes = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "mute", + "user__id": member.id + } + ) + + if active_mutes: + reason = f"Re-applying active mute: {active_mutes[0]['id']}" + action = member.add_roles(self._muted_role, reason=reason) + + await self.reapply_infraction(active_mutes[0], action) + + # region: Permanent infractions + + @command() + async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Warn a user for the given reason.""" + infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command() + async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason.""" + await self.apply_kick(ctx, user, reason) + + @command() + async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason and stop watching them with Big Brother.""" + await self.apply_ban(ctx, user, reason) + + # endregion + # region: Temporary infractions + + @command(aliases=["mute"]) + async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: + """ + Temporarily mute a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration) + + @command() + async def tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration) + + # endregion + # region: Permanent shadow infractions + + @command(hidden=True) + async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Create a private note for a user with the given reason without notifying the user.""" + infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command(hidden=True, aliases=['shadowkick', 'skick']) + async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason without notifying the user.""" + await self.apply_kick(ctx, user, reason, hidden=True) + + @command(hidden=True, aliases=['shadowban', 'sban']) + async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason without notifying the user.""" + await self.apply_ban(ctx, user, reason, hidden=True) + + # endregion + # region: Temporary shadow infractions + + @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) + async def shadow_tempmute( + self, ctx: Context, + user: Member, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily mute a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) + + @command(hidden=True, aliases=["shadowtempban, stempban"]) + async def shadow_tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) + + # endregion + # region: Remove infractions (un- commands) + + @command() + async def unmute(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active mute infraction for the user.""" + await self.pardon_infraction(ctx, "mute", user) + + @command() + async def unban(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active ban infraction for the user.""" + await self.pardon_infraction(ctx, "ban", user) + + # endregion + # region: Base apply functions + + async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a mute infraction with kwargs passed to `post_infraction`.""" + if await utils.get_active_infraction(ctx, user, "mute"): + return + + infraction = await utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_update, user.id) + + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) + + @respect_role_hierarchy() + async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a kick infraction with kwargs passed to `post_infraction`.""" + infraction = await utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = user.kick(reason=reason) + await self.apply_infraction(ctx, infraction, user, action) + + @respect_role_hierarchy() + async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: + """ + Apply a ban infraction with kwargs passed to `post_infraction`. + + Will also remove the banned user from the Big Brother watch list if applicable. + """ + # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active + is_temporary = kwargs.get("expires_at") is not None + active_infraction = await utils.get_active_infraction(ctx, user, "ban", is_temporary) + + if active_infraction: + if is_temporary: + log.trace("Tempban ignored as it cannot overwrite an active ban.") + return + + if active_infraction.get('expires_at') is None: + log.trace("Permaban already exists, notify.") + await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") + return + + log.trace("Old tempban is being replaced by new permaban.") + await self.pardon_infraction(ctx, "ban", user, is_temporary) + + infraction = await utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = ctx.guild.ban(user, reason=reason, delete_message_days=0) + await self.apply_infraction(ctx, infraction, user, action) + + if infraction.get('expires_at') is not None: + log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") + return + + bb_cog = self.bot.get_cog("Big Brother") + if not bb_cog: + log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") + return + + log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") + + bb_reason = "User has been permanently banned from the server. Automatically removed." + await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) + + # endregion + # region: Base pardon functions + + async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's muted role, DM them a notification, and return a log dict.""" + user = guild.get_member(user_id) + log_text = {} + + if user: + # Remove the muted role. + self.mod_log.ignore(Event.member_update, user.id) + await user.remove_roles(self._muted_role, reason=reason) + + # DM the user about the expiration. + notified = await utils.notify_pardon( + user=user, + title="You have been unmuted", + content="You may now send messages in the server.", + icon_url=utils.INFRACTION_ICONS["mute"][1] + ) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["DM"] = "Sent" if notified else "**Failed**" + else: + log.info(f"Failed to unmute user {user_id}: user not found") + log_text["Failure"] = "User was not found in the guild." + + return log_text + + async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's ban on the Discord guild and return a log dict.""" + user = discord.Object(user_id) + log_text = {} + + self.mod_log.ignore(Event.member_unban, user_id) + + try: + await guild.unban(user, reason=reason) + except discord.NotFound: + log.info(f"Failed to unban user {user_id}: no active ban found on Discord") + log_text["Note"] = "No active ban found on Discord." + + return log_text + + async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + guild = self.bot.get_guild(constants.Guild.id) + user_id = infraction["user"] + reason = f"Infraction #{infraction['id']} expired or was pardoned." + + if infraction["type"] == "mute": + return await self.pardon_mute(user_id, guild, reason) + elif infraction["type"] == "ban": + return await self.pardon_ban(user_id, guild, reason) + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters or discord.Member in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True diff --git a/bot/cogs/moderation/infraction/management.py b/bot/cogs/moderation/infraction/management.py new file mode 100644 index 000000000..791585b6e --- /dev/null +++ b/bot/cogs/moderation/infraction/management.py @@ -0,0 +1,305 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext import commands +from discord.ext.commands import Context + +from bot import constants +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user +from bot.pagination import LinePaginator +from bot.utils import time +from bot.utils.checks import in_whitelist_check, with_role_check +from . import utils +from .infractions import Infractions + +log = logging.getLogger(__name__) + + +class ModManagement(commands.Cog): + """Management of infractions.""" + + category = "Moderation" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @property + def infractions_cog(self) -> Infractions: + """Get currently loaded Infractions cog instance.""" + return self.bot.get_cog("Infractions") + + # region: Edit infraction commands + + @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) + async def infraction_group(self, ctx: Context) -> None: + """Infraction manipulation commands.""" + await ctx.send_help(ctx.command) + + @infraction_group.command(name='edit') + async def infraction_edit( + self, + ctx: Context, + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 + duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 + *, + reason: str = None + ) -> None: + """ + Edit the duration and/or the reason of an infraction. + + Durations are relative to the time of updating and should be appended with a unit of time. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. + """ + if duration is None and reason is None: + # Unlike UserInputError, the error handler will show a specified message for BadArgument + raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") + + # Retrieve the previous infraction for its information. + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get("bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + ":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") + + request_data = {} + confirm_messages = [] + log_text = "" + + if duration is not None and not old_infraction['active']: + if reason is None: + await ctx.send(":x: Cannot edit the expiration of an expired infraction.") + return + confirm_messages.append("expiry unchanged (infraction already expired)") + elif isinstance(duration, str): + request_data['expires_at'] = None + confirm_messages.append("marked as permanent") + elif duration is not None: + request_data['expires_at'] = duration.isoformat() + expiry = time.format_infraction_with_duration(request_data['expires_at']) + confirm_messages.append(f"set to expire on {expiry}") + else: + confirm_messages.append("expiry unchanged") + + if reason: + request_data['reason'] = reason + confirm_messages.append("set a new reason") + log_text += f""" + Previous reason: {old_infraction['reason']} + New reason: {reason} + """.rstrip() + else: + confirm_messages.append("reason unchanged") + + # Update the infraction + new_infraction = await self.bot.api_client.patch( + f'bot/infractions/{infraction_id}', + json=request_data, + ) + + # Re-schedule infraction if the expiration has been updated + if 'expires_at' in request_data: + # A scheduled task should only exist if the old infraction wasn't permanent + if old_infraction['expires_at']: + self.infractions_cog.scheduler.cancel(new_infraction['id']) + + # If the infraction was not marked as permanent, schedule a new expiration task + if request_data['expires_at']: + self.infractions_cog.schedule_expiration(new_infraction) + + log_text += f""" + Previous expiry: {old_infraction['expires_at'] or "Permanent"} + New expiry: {new_infraction['expires_at'] or "Permanent"} + """.rstrip() + + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") + + # Get information about the infraction's user + user_id = new_infraction['user'] + user = ctx.guild.get_member(user_id) + + if user: + user_text = f"{user.mention} (`{user.id}`)" + thumbnail = user.avatar_url_as(static_format="png") + else: + user_text = f"`{user_id}`" + thumbnail = None + + # The infraction's actor + actor_id = new_infraction['actor'] + actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" + + await self.mod_log.send_log_message( + icon_url=constants.Icons.pencil, + colour=discord.Colour.blurple(), + title="Infraction edited", + thumbnail=thumbnail, + text=textwrap.dedent(f""" + Member: {user_text} + Actor: {actor} + Edited by: {ctx.message.author}{log_text} + """) + ) + + # endregion + # region: Search infractions + + @infraction_group.group(name="search", invoke_without_command=True) + async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: + """Searches for infractions in the database.""" + if isinstance(query, discord.User): + await ctx.invoke(self.search_user, query) + else: + await ctx.invoke(self.search_reason, query) + + @infraction_search_group.command(name="user", aliases=("member", "id")) + async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: + """Search for infractions by member.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'user__id': str(user.id)} + ) + embed = discord.Embed( + title=f"Infractions for {user} ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) + async def search_reason(self, ctx: Context, reason: str) -> None: + """Search for infractions by their reason. Use Re2 for matching.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'search': reason} + ) + embed = discord.Embed( + title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + # endregion + # region: Utility functions + + async def send_infraction_list( + self, + ctx: Context, + embed: discord.Embed, + infractions: t.Iterable[utils.Infraction] + ) -> None: + """Send a paginated embed of infractions for the specified user.""" + if not infractions: + await ctx.send(":warning: No infractions could be found for that query.") + return + + lines = tuple( + self.infraction_to_string(infraction) + for infraction in infractions + ) + + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + def infraction_to_string(self, infraction: utils.Infraction) -> str: + """Convert the infraction object to a string representation.""" + actor_id = infraction["actor"] + guild = self.bot.get_guild(constants.Guild.id) + actor = guild.get_member(actor_id) + active = infraction["active"] + user_id = infraction["user"] + hidden = infraction["hidden"] + created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + + if infraction["expires_at"] is None: + expires = "*Permanent*" + else: + date_from = datetime.strptime(created, time.INFRACTION_FORMAT) + expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) + + lines = textwrap.dedent(f""" + {"**===============**" if active else "==============="} + Status: {"__**Active**__" if active else "Inactive"} + User: {self.bot.get_user(user_id)} (`{user_id}`) + Type: **{infraction["type"]}** + Shadow: {hidden} + Created: {created} + Expires: {expires} + Remaining: {remaining} + Actor: {actor.mention if actor else actor_id} + ID: `{infraction["id"]}` + Reason: {infraction["reason"] or "*None*"} + {"**===============**" if active else "==============="} + """) + + return lines.strip() + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators inside moderator channels to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=constants.MODERATION_CHANNELS, + categories=[constants.Categories.modmail], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True diff --git a/bot/cogs/moderation/infraction/scheduler.py b/bot/cogs/moderation/infraction/scheduler.py new file mode 100644 index 000000000..b3d27fe76 --- /dev/null +++ b/bot/cogs/moderation/infraction/scheduler.py @@ -0,0 +1,463 @@ +import logging +import textwrap +import typing as t +from abc import abstractmethod +from datetime import datetime +from gettext import ngettext + +import dateutil.parser +import discord +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.constants import Colours, STAFF_CHANNELS +from bot.utils import time +from bot.utils.scheduling import Scheduler +from . import utils +from .utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class InfractionScheduler: + """Handles the application, pardoning, and expiration of infractions.""" + + def __init__(self, bot: Bot, supported_infractions: t.Container[str]): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + @property + def mod_log(self) -> ModLog: + """Get the currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: + """Schedule expiration for previous infractions.""" + await self.bot.wait_until_guild_available() + + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={'active': 'true'} + ) + for infraction in infractions: + if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: + self.schedule_expiration(infraction) + + async def reapply_infraction( + self, + infraction: utils.Infraction, + apply_coro: t.Optional[t.Awaitable] + ) -> None: + """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" + # Calculate the time remaining, in seconds, for the mute. + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + delta = (expiry - datetime.utcnow()).total_seconds() + + # Mark as inactive if less than a minute remains. + if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) + await self.deactivate_infraction(infraction) + return + + # Allowing mod log since this is a passive action that should be logged. + await apply_coro + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + + async def apply_infraction( + self, + ctx: Context, + infraction: utils.Infraction, + user: UserSnowflake, + action_coro: t.Optional[t.Awaitable] = None + ) -> None: + """Apply an infraction to the user, log the infraction, and optionally notify the user.""" + infr_type = infraction["type"] + icon = utils.INFRACTION_ICONS[infr_type][0] + reason = infraction["reason"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") + + # Default values for the confirmation message and mod log. + confirm_msg = ":ok_hand: applied" + + # Specifying an expiry for a note or warning makes no sense. + if infr_type in ("note", "warning"): + expiry_msg = "" + else: + expiry_msg = f" until {expiry}" if expiry else " permanently" + + dm_result = "" + dm_log_text = "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" + log_title = "applied" + log_content = None + failed = False + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # This needs to happen before we apply the infraction, as the bot cannot + # send DMs to user that it doesn't share a guild with. If we were to + # apply kick/ban infractions first, this would mean that we'd make it + # impossible for us to deliver a DM. See python-discord/bot#982. + if not infraction["hidden"]: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + end_msg = "" + if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + else: + log.trace(f"Fetching total infraction count for {user}.") + + infractions = await self.bot.api_client.get( + "bot/infractions", + params={"user__id": str(user.id)} + ) + total = len(infractions) + end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" + + # Execute the necessary actions to apply the infraction on Discord. + if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") + try: + await action_coro + if expiry: + # Schedule the expiration of the infraction. + self.schedule_expiration(infraction) + except discord.HTTPException as e: + # Accordingly display that applying the infraction failed. + confirm_msg = ":x: failed to apply" + expiry_msg = "" + log_content = ctx.author.mention + log_title = "failed to apply" + + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + failed = True + + if failed: + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" + + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") + + # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=icon, + colour=Colours.soft_red, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {user.mention} (`{user.id}`) + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} + Reason: {reason} + """), + content=log_content, + footer=f"ID {infraction['id']}" + ) + + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + + async def pardon_infraction( + self, + ctx: Context, + infr_type: str, + user: UserSnowflake, + send_msg: bool = True + ) -> None: + """ + Prematurely end an infraction for a user and log the action in the mod log. + + If `send_msg` is True, then a pardoning confirmation message will be sent to + the context channel. Otherwise, no such message will be sent. + """ + log.trace(f"Pardoning {infr_type} infraction for {user}.") + + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") + response = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': user.id + } + ) + + if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") + await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") + return + + # Deactivate the infraction and cancel its scheduled expiration task. + log_text = await self.deactivate_infraction(response[0], send_log=False) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["Actor"] = str(ctx.message.author) + log_content = None + id_ = response[0]['id'] + footer = f"ID: {id_}" + + # If multiple active infractions were found, mark them as inactive in the database + # and cancel their expiration tasks. + if len(response) > 1: + log.info( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) + + footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" + + log_note = f"Found multiple **active** {infr_type} infractions in the database." + if "Note" in log_text: + log_text["Note"] = f" {log_note}" + else: + log_text["Note"] = log_note + + # deactivate_infraction() is not called again because: + # 1. Discord cannot store multiple active bans or assign multiples of the same role + # 2. It would send a pardon DM for each active infraction, which is redundant + for infraction in response[1:]: + id_ = infraction['id'] + try: + # Mark infraction as inactive in the database. + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError: + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") + # This is simpler and cleaner than trying to concatenate all the errors. + log_text["Failure"] = "See bot's logs for details." + + # Cancel pending expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Accordingly display whether the user was successfully notified via DM. + dm_emoji = "" + if log_text.get("DM") == "Sent": + dm_emoji = ":incoming_envelope: " + elif "DM" in log_text: + dm_emoji = f"{constants.Emojis.failmail} " + + # Accordingly display whether the pardon failed. + if "Failure" in log_text: + confirm_msg = ":x: failed to pardon" + log_title = "pardon failed" + log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") + else: + confirm_msg = ":ok_hand: pardoned" + log_title = "pardoned" + + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + + # Send a confirmation message to the invoking context. + if send_msg: + log.trace(f"Sending infraction #{id_} pardon confirmation message.") + await ctx.send( + f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " + f"{log_text.get('Failure', '')}" + ) + + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + + # Send a log message to the mod log. + await self.mod_log.send_log_message( + icon_url=utils.INFRACTION_ICONS[infr_type][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=footer, + content=log_content, + ) + + async def deactivate_infraction( + self, + infraction: utils.Infraction, + send_log: bool = True + ) -> t.Dict[str, str]: + """ + Deactivate an active infraction and return a dictionary of lines to send in a mod log. + + The infraction is removed from Discord, marked as inactive in the database, and has its + expiration task cancelled. If `send_log` is True, a mod log is sent for the + deactivation of the infraction. + + Infractions of unsupported types will raise a ValueError. + """ + guild = self.bot.get_guild(constants.Guild.id) + mod_role = guild.get_role(constants.Roles.moderators) + user_id = infraction["user"] + actor = infraction["actor"] + type_ = infraction["type"] + id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] + + log.info(f"Marking infraction #{id_} as inactive (expired).") + + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + + log_content = None + log_text = { + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, + } + + try: + log.trace("Awaiting the pardon action coroutine.") + returned_log = await self._pardon_action(infraction) + + if returned_log is not None: + log_text = {**log_text, **returned_log} # Merge the logs together + else: + raise ValueError( + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" + ) + except discord.Forbidden: + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") + log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" + log_content = mod_role.mention + except discord.HTTPException as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." + log_content = mod_role.mention + + # Check if the user is currently being watched by Big Brother. + try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + + active_watch = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "watch", + "user__id": user_id + } + ) + + log_text["Watching"] = "Yes" if active_watch else "No" + except ResponseCodeError: + log.exception(f"Failed to fetch watch status for user {user_id}") + log_text["Watching"] = "Unknown - failed to fetch watch status." + + try: + # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_line = f"API request failed with code {e.status}." + log_content = mod_role.mention + + # Append to an existing failure message if possible + if "Failure" in log_text: + log_text["Failure"] += f" {log_line}" + else: + log_text["Failure"] = log_line + + # Cancel the expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Send a log message to the mod log. + if send_log: + log_title = "expiration failed" if "Failure" in log_text else "expired" + + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + + log.trace(f"Sending deactivation mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=utils.INFRACTION_ICONS[type_][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=f"ID: {id_}", + content=log_content, + ) + + return log_text + + @abstractmethod + async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + raise NotImplementedError + + def schedule_expiration(self, infraction: utils.Infraction) -> None: + """ + Marks an infraction expired after the delay from time of scheduling to time of expiration. + + At the time of expiration, the infraction is marked as inactive on the website and the + expiration task is cancelled. + """ + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/superstarify.py b/bot/cogs/moderation/infraction/superstarify.py new file mode 100644 index 000000000..867de815a --- /dev/null +++ b/bot/cogs/moderation/infraction/superstarify.py @@ -0,0 +1,239 @@ +import json +import logging +import random +import textwrap +import typing as t +from pathlib import Path + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.converters import Expiry +from bot.utils.checks import with_role_check +from bot.utils.time import format_infraction +from . import utils +from .scheduler import InfractionScheduler + +log = logging.getLogger(__name__) +NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" + +with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: + STAR_NAMES = json.load(stars_file) + + +class Superstarify(InfractionScheduler, Cog): + """A set of commands to moderate terrible nicknames.""" + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"superstar"}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Revert nickname edits if the user has an active superstarify infraction.""" + if before.display_name == after.display_name: + return # User didn't change their nickname. Abort! + + log.trace( + f"{before} ({before.display_name}) is trying to change their nickname to " + f"{after.display_name}. Checking if the user is in superstar-prison..." + ) + + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": str(before.id) + } + ) + + if not active_superstarifies: + log.trace(f"{before} has no active superstar infractions.") + return + + infraction = active_superstarifies[0] + forced_nick = self.get_nick(infraction["id"], before.id) + if after.display_name == forced_nick: + return # Nick change was triggered by this event. Ignore. + + log.info( + f"{after.display_name} ({after.id}) tried to escape superstar prison. " + f"Changing the nick back to {before.display_name}." + ) + await after.edit( + nick=forced_nick, + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + notified = await utils.notify_infraction( + user=after, + infr_type="Superstarify", + expires_at=format_infraction(infraction["expires_at"]), + reason=( + "You have tried to change your nickname on the **Python Discord** server " + f"from **{before.display_name}** to **{after.display_name}**, but as you " + "are currently in superstar-prison, you do not have permission to do so." + ), + icon_url=utils.INFRACTION_ICONS["superstar"][0] + ) + + if not notified: + log.info("Failed to DM user about why they cannot change their nickname.") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active superstar infractions for returning members.""" + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": member.id + } + ) + + if active_superstarifies: + infraction = active_superstarifies[0] + action = member.edit( + nick=self.get_nick(infraction["id"], member.id), + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + await self.reapply_infraction(infraction, action) + + @command(name="superstarify", aliases=("force_nick", "star")) + async def superstarify( + self, + ctx: Context, + member: Member, + duration: Expiry, + *, + reason: str = None, + ) -> None: + """ + Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + + An optional reason can be provided. If no reason is given, the original name will be shown + in a generated reason. + """ + if await utils.get_active_infraction(ctx, member, "superstar"): + return + + # Post the infraction to the API + reason = reason or f"old nick: {member.display_name}" + infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) + id_ = infraction["id"] + + old_nick = member.display_name + forced_nick = self.get_nick(id_, member.id) + expiry_str = format_infraction(infraction["expires_at"]) + + # Apply the infraction and schedule the expiration task. + log.debug(f"Changing nickname of {member} to {forced_nick}.") + self.mod_log.ignore(constants.Event.member_update, member.id) + await member.edit(nick=forced_nick, reason=reason) + self.schedule_expiration(infraction) + + # Send a DM to the user to notify them of their new infraction. + await utils.notify_infraction( + user=member, + infr_type="Superstarify", + expires_at=expiry_str, + icon_url=utils.INFRACTION_ICONS["superstar"][0], + reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." + ) + + # Send an embed with the infraction information to the invoking context. + log.trace(f"Sending superstar #{id_} embed.") + embed = Embed( + title="Congratulations!", + colour=constants.Colours.soft_orange, + description=( + f"Your previous nickname, **{old_nick}**, " + f"was so bad that we have decided to change it. " + f"Your new nickname will be **{forced_nick}**.\n\n" + f"You will be unable to change your nickname until **{expiry_str}**.\n\n" + "If you're confused by this, please read our " + f"[official nickname policy]({NICKNAME_POLICY_URL})." + ) + ) + await ctx.send(embed=embed) + + # Log to the mod log channel. + log.trace(f"Sending apply mod log for superstar #{id_}.") + await self.mod_log.send_log_message( + icon_url=utils.INFRACTION_ICONS["superstar"][0], + colour=Colour.gold(), + title="Member achieved superstardom", + thumbnail=member.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {member.mention} (`{member.id}`) + Actor: {ctx.message.author} + Expires: {expiry_str} + Old nickname: `{old_nick}` + New nickname: `{forced_nick}` + Reason: {reason} + """), + footer=f"ID {id_}" + ) + + @command(name="unsuperstarify", aliases=("release_nick", "unstar")) + async def unsuperstarify(self, ctx: Context, member: Member) -> None: + """Remove the superstarify infraction and allow the user to change their nickname.""" + await self.pardon_infraction(ctx, "superstar", member) + + async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """Pardon a superstar infraction and return a log dict.""" + if infraction["type"] != "superstar": + return + + guild = self.bot.get_guild(constants.Guild.id) + user = guild.get_member(infraction["user"]) + + # Don't bother sending a notification if the user left the guild. + if not user: + log.debug( + "User left the guild and therefore won't be notified about superstar " + f"{infraction['id']} pardon." + ) + return {} + + # DM the user about the expiration. + notified = await utils.notify_pardon( + user=user, + title="You are no longer superstarified", + content="You may now change your nickname on the server.", + icon_url=utils.INFRACTION_ICONS["superstar"][1] + ) + + return { + "Member": f"{user.mention}(`{user.id}`)", + "DM": "Sent" if notified else "**Failed**" + } + + @staticmethod + def get_nick(infraction_id: int, member_id: int) -> str: + """Randomly select a nickname from the Superstarify nickname list.""" + log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") + + rng = random.Random(str(infraction_id) + str(member_id)) + return rng.choice(STAR_NAMES) + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) diff --git a/bot/cogs/moderation/infraction/utils.py b/bot/cogs/moderation/infraction/utils.py new file mode 100644 index 000000000..fb55287b6 --- /dev/null +++ b/bot/cogs/moderation/infraction/utils.py @@ -0,0 +1,201 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext.commands import Context + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons + +log = logging.getLogger(__name__) + +# apply icon, pardon icon +INFRACTION_ICONS = { + "ban": (Icons.user_ban, Icons.user_unban), + "kick": (Icons.sign_out, None), + "mute": (Icons.user_mute, Icons.user_unmute), + "note": (Icons.user_warn, None), + "superstar": (Icons.superstarify, Icons.unsuperstarify), + "warning": (Icons.user_warn, None), +} +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEALABLE_INFRACTIONS = ("ban", "mute") + +# Type aliases +UserObject = t.Union[discord.Member, discord.User] +UserSnowflake = t.Union[UserObject, discord.Object] +Infraction = t.Dict[str, t.Union[str, int, bool]] + + +async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: + """ + Create a new user in the database. + + Used when an infraction needs to be applied on a user absent in the guild. + """ + log.trace(f"Attempting to add user {user.id} to the database.") + + if not isinstance(user, (discord.Member, discord.User)): + log.debug("The user being added to the DB is not a Member or User object.") + + payload = { + 'discriminator': int(getattr(user, 'discriminator', 0)), + 'id': user.id, + 'in_guild': False, + 'name': getattr(user, 'name', 'Name unknown'), + 'roles': [] + } + + try: + response = await ctx.bot.api_client.post('bot/users', json=payload) + log.info(f"User {user.id} added to the DB.") + return response + except ResponseCodeError as e: + log.error(f"Failed to add user {user.id} to the DB. {e}") + await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") + + +async def post_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + reason: str, + expires_at: datetime = None, + hidden: bool = False, + active: bool = True +) -> t.Optional[dict]: + """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + + payload = { + "actor": ctx.message.author.id, + "hidden": hidden, + "reason": reason, + "type": infr_type, + "user": user.id, + "active": active + } + if expires_at: + payload['expires_at'] = expires_at.isoformat() + + # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. + for should_post_user in (True, False): + try: + response = await ctx.bot.api_client.post('bot/infractions', json=payload) + return response + except ResponseCodeError as e: + if e.status == 400 and 'user' in e.response_json: + # Only one attempt to add the user to the database, not two: + if not should_post_user or await post_user(ctx, user) is None: + return + else: + log.exception(f"Unexpected error while adding an infraction for {user}:") + await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") + return + + +async def get_active_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + send_msg: bool = True +) -> t.Optional[dict]: + """ + Retrieves an active infraction of the given type for the user. + + If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, + then a message for the moderator will be sent to the context channel letting them know. + Otherwise, no message will be sent. + """ + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + + active_infractions = await ctx.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': str(user.id) + } + ) + if active_infractions: + # Checks to see if the moderator should be told there is an active infraction + if send_msg: + log.trace(f"{user} has active infractions of type {infr_type}.") + await ctx.send( + f":x: According to my records, this user already has a {infr_type} infraction. " + f"See infraction **#{active_infractions[0]['id']}**." + ) + return active_infractions[0] + else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") + + +async def notify_infraction( + user: UserObject, + infr_type: str, + expires_at: t.Optional[str] = None, + reason: t.Optional[str] = None, + icon_url: str = Icons.token_removed +) -> bool: + """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + + embed = discord.Embed( + description=textwrap.shorten(text, width=2048, placeholder="..."), + colour=Colours.soft_red + ) + + embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) + embed.title = f"Please review our rules over at {RULES_URL}" + embed.url = RULES_URL + + if infr_type in APPEALABLE_INFRACTIONS: + embed.set_footer( + text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + ) + + return await send_private_embed(user, embed) + + +async def notify_pardon( + user: UserObject, + title: str, + content: str, + icon_url: str = Icons.user_verified +) -> bool: + """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + + embed = discord.Embed( + description=content, + colour=Colours.soft_green + ) + + embed.set_author(name=title, icon_url=icon_url) + + return await send_private_embed(user, embed) + + +async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: + """ + A helper method for sending an embed to a user's DMs. + + Returns a boolean indicator of DM success. + """ + try: + await user.send(embed=embed) + return True + except (discord.HTTPException, discord.Forbidden, discord.NotFound): + log.debug( + f"Infraction-related information could not be sent to user {user} ({user.id}). " + "The user either could not be retrieved or probably disabled their DMs." + ) + return False diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py deleted file mode 100644 index 8df642428..000000000 --- a/bot/cogs/moderation/infractions.py +++ /dev/null @@ -1,370 +0,0 @@ -import logging -import textwrap -import typing as t - -import discord -from discord import Member -from discord.ext import commands -from discord.ext.commands import Context, command - -from bot import constants -from bot.bot import Bot -from bot.constants import Event -from bot.converters import Expiry, FetchedMember -from bot.decorators import respect_role_hierarchy -from bot.utils.checks import with_role_check -from . import utils -from .scheduler import InfractionScheduler -from .utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class Infractions(InfractionScheduler, commands.Cog): - """Apply and pardon infractions on users for moderation purposes.""" - - category = "Moderation" - category_description = "Server moderation tools." - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) - - self.category = "Moderation" - self._muted_role = discord.Object(constants.Roles.muted) - - @commands.Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active mute infractions for returning members.""" - active_mutes = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "mute", - "user__id": member.id - } - ) - - if active_mutes: - reason = f"Re-applying active mute: {active_mutes[0]['id']}" - action = member.add_roles(self._muted_role, reason=reason) - - await self.reapply_infraction(active_mutes[0], action) - - # region: Permanent infractions - - @command() - async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Warn a user for the given reason.""" - infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command() - async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason.""" - await self.apply_kick(ctx, user, reason) - - @command() - async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason and stop watching them with Big Brother.""" - await self.apply_ban(ctx, user, reason) - - # endregion - # region: Temporary infractions - - @command(aliases=["mute"]) - async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: - """ - Temporarily mute a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration) - - @command() - async def tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration) - - # endregion - # region: Permanent shadow infractions - - @command(hidden=True) - async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Create a private note for a user with the given reason without notifying the user.""" - infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command(hidden=True, aliases=['shadowkick', 'skick']) - async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason without notifying the user.""" - await self.apply_kick(ctx, user, reason, hidden=True) - - @command(hidden=True, aliases=['shadowban', 'sban']) - async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason without notifying the user.""" - await self.apply_ban(ctx, user, reason, hidden=True) - - # endregion - # region: Temporary shadow infractions - - @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) - async def shadow_tempmute( - self, ctx: Context, - user: Member, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily mute a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) - - @command(hidden=True, aliases=["shadowtempban, stempban"]) - async def shadow_tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) - - # endregion - # region: Remove infractions (un- commands) - - @command() - async def unmute(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active mute infraction for the user.""" - await self.pardon_infraction(ctx, "mute", user) - - @command() - async def unban(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active ban infraction for the user.""" - await self.pardon_infraction(ctx, "ban", user) - - # endregion - # region: Base apply functions - - async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await utils.get_active_infraction(ctx, user, "mute"): - return - - infraction = await utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_update, user.id) - - async def action() -> None: - await user.add_roles(self._muted_role, reason=reason) - - log.trace(f"Attempting to kick {user} from voice because they've been muted.") - await user.move_to(None, reason=reason) - - await self.apply_infraction(ctx, infraction, user, action()) - - @respect_role_hierarchy() - async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a kick infraction with kwargs passed to `post_infraction`.""" - infraction = await utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = user.kick(reason=reason) - await self.apply_infraction(ctx, infraction, user, action) - - @respect_role_hierarchy() - async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: - """ - Apply a ban infraction with kwargs passed to `post_infraction`. - - Will also remove the banned user from the Big Brother watch list if applicable. - """ - # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active - is_temporary = kwargs.get("expires_at") is not None - active_infraction = await utils.get_active_infraction(ctx, user, "ban", is_temporary) - - if active_infraction: - if is_temporary: - log.trace("Tempban ignored as it cannot overwrite an active ban.") - return - - if active_infraction.get('expires_at') is None: - log.trace("Permaban already exists, notify.") - await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") - return - - log.trace("Old tempban is being replaced by new permaban.") - await self.pardon_infraction(ctx, "ban", user, is_temporary) - - infraction = await utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = ctx.guild.ban(user, reason=reason, delete_message_days=0) - await self.apply_infraction(ctx, infraction, user, action) - - if infraction.get('expires_at') is not None: - log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") - return - - bb_cog = self.bot.get_cog("Big Brother") - if not bb_cog: - log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") - return - - log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") - - bb_reason = "User has been permanently banned from the server. Automatically removed." - await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) - - # endregion - # region: Base pardon functions - - async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's muted role, DM them a notification, and return a log dict.""" - user = guild.get_member(user_id) - log_text = {} - - if user: - # Remove the muted role. - self.mod_log.ignore(Event.member_update, user.id) - await user.remove_roles(self._muted_role, reason=reason) - - # DM the user about the expiration. - notified = await utils.notify_pardon( - user=user, - title="You have been unmuted", - content="You may now send messages in the server.", - icon_url=utils.INFRACTION_ICONS["mute"][1] - ) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["DM"] = "Sent" if notified else "**Failed**" - else: - log.info(f"Failed to unmute user {user_id}: user not found") - log_text["Failure"] = "User was not found in the guild." - - return log_text - - async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's ban on the Discord guild and return a log dict.""" - user = discord.Object(user_id) - log_text = {} - - self.mod_log.ignore(Event.member_unban, user_id) - - try: - await guild.unban(user, reason=reason) - except discord.NotFound: - log.info(f"Failed to unban user {user_id}: no active ban found on Discord") - log_text["Note"] = "No active ban found on Discord." - - return log_text - - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - guild = self.bot.get_guild(constants.Guild.id) - user_id = infraction["user"] - reason = f"Infraction #{infraction['id']} expired or was pardoned." - - if infraction["type"] == "mute": - return await self.pardon_mute(user_id, guild, reason) - elif infraction["type"] == "ban": - return await self.pardon_ban(user_id, guild, reason) - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters or discord.Member in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py deleted file mode 100644 index 672bb0e9c..000000000 --- a/bot/cogs/moderation/management.py +++ /dev/null @@ -1,305 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext import commands -from discord.ext.commands import Context - -from bot import constants -from bot.bot import Bot -from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user -from bot.pagination import LinePaginator -from bot.utils import time -from bot.utils.checks import in_whitelist_check, with_role_check -from . import utils -from .infractions import Infractions -from .modlog import ModLog - -log = logging.getLogger(__name__) - - -class ModManagement(commands.Cog): - """Management of infractions.""" - - category = "Moderation" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @property - def infractions_cog(self) -> Infractions: - """Get currently loaded Infractions cog instance.""" - return self.bot.get_cog("Infractions") - - # region: Edit infraction commands - - @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) - async def infraction_group(self, ctx: Context) -> None: - """Infraction manipulation commands.""" - await ctx.send_help(ctx.command) - - @infraction_group.command(name='edit') - async def infraction_edit( - self, - ctx: Context, - infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 - duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 - *, - reason: str = None - ) -> None: - """ - Edit the duration and/or the reason of an infraction. - - Durations are relative to the time of updating and should be appended with a unit of time. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction - authored by the command invoker should be edited. - - Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 - timestamp can be provided for the duration. - """ - if duration is None and reason is None: - # Unlike UserInputError, the error handler will show a specified message for BadArgument - raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") - - # Retrieve the previous infraction for its information. - if isinstance(infraction_id, str): - params = { - "actor__id": ctx.author.id, - "ordering": "-inserted_at" - } - infractions = await self.bot.api_client.get("bot/infractions", params=params) - - if infractions: - old_infraction = infractions[0] - infraction_id = old_infraction["id"] - else: - await ctx.send( - ":x: Couldn't find most recent infraction; you have never given an infraction." - ) - return - else: - old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") - - request_data = {} - confirm_messages = [] - log_text = "" - - if duration is not None and not old_infraction['active']: - if reason is None: - await ctx.send(":x: Cannot edit the expiration of an expired infraction.") - return - confirm_messages.append("expiry unchanged (infraction already expired)") - elif isinstance(duration, str): - request_data['expires_at'] = None - confirm_messages.append("marked as permanent") - elif duration is not None: - request_data['expires_at'] = duration.isoformat() - expiry = time.format_infraction_with_duration(request_data['expires_at']) - confirm_messages.append(f"set to expire on {expiry}") - else: - confirm_messages.append("expiry unchanged") - - if reason: - request_data['reason'] = reason - confirm_messages.append("set a new reason") - log_text += f""" - Previous reason: {old_infraction['reason']} - New reason: {reason} - """.rstrip() - else: - confirm_messages.append("reason unchanged") - - # Update the infraction - new_infraction = await self.bot.api_client.patch( - f'bot/infractions/{infraction_id}', - json=request_data, - ) - - # Re-schedule infraction if the expiration has been updated - if 'expires_at' in request_data: - # A scheduled task should only exist if the old infraction wasn't permanent - if old_infraction['expires_at']: - self.infractions_cog.scheduler.cancel(new_infraction['id']) - - # If the infraction was not marked as permanent, schedule a new expiration task - if request_data['expires_at']: - self.infractions_cog.schedule_expiration(new_infraction) - - log_text += f""" - Previous expiry: {old_infraction['expires_at'] or "Permanent"} - New expiry: {new_infraction['expires_at'] or "Permanent"} - """.rstrip() - - changes = ' & '.join(confirm_messages) - await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") - - # Get information about the infraction's user - user_id = new_infraction['user'] - user = ctx.guild.get_member(user_id) - - if user: - user_text = f"{user.mention} (`{user.id}`)" - thumbnail = user.avatar_url_as(static_format="png") - else: - user_text = f"`{user_id}`" - thumbnail = None - - # The infraction's actor - actor_id = new_infraction['actor'] - actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" - - await self.mod_log.send_log_message( - icon_url=constants.Icons.pencil, - colour=discord.Colour.blurple(), - title="Infraction edited", - thumbnail=thumbnail, - text=textwrap.dedent(f""" - Member: {user_text} - Actor: {actor} - Edited by: {ctx.message.author}{log_text} - """) - ) - - # endregion - # region: Search infractions - - @infraction_group.group(name="search", invoke_without_command=True) - async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: - """Searches for infractions in the database.""" - if isinstance(query, discord.User): - await ctx.invoke(self.search_user, query) - else: - await ctx.invoke(self.search_reason, query) - - @infraction_search_group.command(name="user", aliases=("member", "id")) - async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: - """Search for infractions by member.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'user__id': str(user.id)} - ) - embed = discord.Embed( - title=f"Infractions for {user} ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) - async def search_reason(self, ctx: Context, reason: str) -> None: - """Search for infractions by their reason. Use Re2 for matching.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'search': reason} - ) - embed = discord.Embed( - title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - # endregion - # region: Utility functions - - async def send_infraction_list( - self, - ctx: Context, - embed: discord.Embed, - infractions: t.Iterable[utils.Infraction] - ) -> None: - """Send a paginated embed of infractions for the specified user.""" - if not infractions: - await ctx.send(":warning: No infractions could be found for that query.") - return - - lines = tuple( - self.infraction_to_string(infraction) - for infraction in infractions - ) - - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - def infraction_to_string(self, infraction: utils.Infraction) -> str: - """Convert the infraction object to a string representation.""" - actor_id = infraction["actor"] - guild = self.bot.get_guild(constants.Guild.id) - actor = guild.get_member(actor_id) - active = infraction["active"] - user_id = infraction["user"] - hidden = infraction["hidden"] - created = time.format_infraction(infraction["inserted_at"]) - - if active: - remaining = time.until_expiration(infraction["expires_at"]) or "Expired" - else: - remaining = "Inactive" - - if infraction["expires_at"] is None: - expires = "*Permanent*" - else: - date_from = datetime.strptime(created, time.INFRACTION_FORMAT) - expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) - - lines = textwrap.dedent(f""" - {"**===============**" if active else "==============="} - Status: {"__**Active**__" if active else "Inactive"} - User: {self.bot.get_user(user_id)} (`{user_id}`) - Type: **{infraction["type"]}** - Shadow: {hidden} - Created: {created} - Expires: {expires} - Remaining: {remaining} - Actor: {actor.mention if actor else actor_id} - ID: `{infraction["id"]}` - Reason: {infraction["reason"] or "*None*"} - {"**===============**" if active else "==============="} - """) - - return lines.strip() - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators inside moderator channels to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=constants.MODERATION_CHANNELS, - categories=[constants.Categories.modmail], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py deleted file mode 100644 index 75028d851..000000000 --- a/bot/cogs/moderation/scheduler.py +++ /dev/null @@ -1,463 +0,0 @@ -import logging -import textwrap -import typing as t -from abc import abstractmethod -from datetime import datetime -from gettext import ngettext - -import dateutil.parser -import discord -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Colours, STAFF_CHANNELS -from bot.utils import time -from bot.utils.scheduling import Scheduler -from . import utils -from .modlog import ModLog -from .utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class InfractionScheduler: - """Handles the application, pardoning, and expiration of infractions.""" - - def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - @property - def mod_log(self) -> ModLog: - """Get the currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: - """Schedule expiration for previous infractions.""" - await self.bot.wait_until_guild_available() - - log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") - - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={'active': 'true'} - ) - for infraction in infractions: - if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_expiration(infraction) - - async def reapply_infraction( - self, - infraction: utils.Infraction, - apply_coro: t.Optional[t.Awaitable] - ) -> None: - """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" - # Calculate the time remaining, in seconds, for the mute. - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - delta = (expiry - datetime.utcnow()).total_seconds() - - # Mark as inactive if less than a minute remains. - if delta < 60: - log.info( - "Infraction will be deactivated instead of re-applied " - "because less than 1 minute remains." - ) - await self.deactivate_infraction(infraction) - return - - # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") - - async def apply_infraction( - self, - ctx: Context, - infraction: utils.Infraction, - user: UserSnowflake, - action_coro: t.Optional[t.Awaitable] = None - ) -> None: - """Apply an infraction to the user, log the infraction, and optionally notify the user.""" - infr_type = infraction["type"] - icon = utils.INFRACTION_ICONS[infr_type][0] - reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) - id_ = infraction['id'] - - log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - - # Default values for the confirmation message and mod log. - confirm_msg = ":ok_hand: applied" - - # Specifying an expiry for a note or warning makes no sense. - if infr_type in ("note", "warning"): - expiry_msg = "" - else: - expiry_msg = f" until {expiry}" if expiry else " permanently" - - dm_result = "" - dm_log_text = "" - expiry_log_text = f"\nExpires: {expiry}" if expiry else "" - log_title = "applied" - log_content = None - failed = False - - # DM the user about the infraction if it's not a shadow/hidden infraction. - # This needs to happen before we apply the infraction, as the bot cannot - # send DMs to user that it doesn't share a guild with. If we were to - # apply kick/ban infractions first, this would mean that we'd make it - # impossible for us to deliver a DM. See python-discord/bot#982. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" - - end_msg = "" - if infraction["actor"] == self.bot.user.id: - log.trace( - f"Infraction #{id_} actor is bot; including the reason in the confirmation message." - ) - if reason: - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" - elif ctx.channel.id not in STAFF_CHANNELS: - log.trace( - f"Infraction #{id_} context is not in a staff channel; omitting infraction count." - ) - else: - log.trace(f"Fetching total infraction count for {user}.") - - infractions = await self.bot.api_client.get( - "bot/infractions", - params={"user__id": str(user.id)} - ) - total = len(infractions) - end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" - - # Execute the necessary actions to apply the infraction on Discord. - if action_coro: - log.trace(f"Awaiting the infraction #{id_} application action coroutine.") - try: - await action_coro - if expiry: - # Schedule the expiration of the infraction. - self.schedule_expiration(infraction) - except discord.HTTPException as e: - # Accordingly display that applying the infraction failed. - confirm_msg = ":x: failed to apply" - expiry_msg = "" - log_content = ctx.author.mention - log_title = "failed to apply" - - log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): - log.warning(f"{log_msg}: bot lacks permissions.") - else: - log.exception(log_msg) - failed = True - - if failed: - log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") - try: - await self.bot.api_client.delete(f"bot/infractions/{id_}") - except ResponseCodeError as e: - confirm_msg += " and failed to delete" - log_title += " and failed to delete" - log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") - infr_message = "" - else: - infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" - - # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") - - # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=icon, - colour=Colours.soft_red, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} - Reason: {reason} - """), - content=log_content, - footer=f"ID {infraction['id']}" - ) - - log.info(f"Applied {infr_type} infraction #{id_} to {user}.") - - async def pardon_infraction( - self, - ctx: Context, - infr_type: str, - user: UserSnowflake, - send_msg: bool = True - ) -> None: - """ - Prematurely end an infraction for a user and log the action in the mod log. - - If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. - """ - log.trace(f"Pardoning {infr_type} infraction for {user}.") - - # Check the current active infraction - log.trace(f"Fetching active {infr_type} infractions for {user}.") - response = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': user.id - } - ) - - if not response: - log.debug(f"No active {infr_type} infraction found for {user}.") - await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") - return - - # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["Actor"] = str(ctx.message.author) - log_content = None - id_ = response[0]['id'] - footer = f"ID: {id_}" - - # If multiple active infractions were found, mark them as inactive in the database - # and cancel their expiration tasks. - if len(response) > 1: - log.info( - f"Found more than one active {infr_type} infraction for user {user.id}; " - "deactivating the extra active infractions too." - ) - - footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" - - log_note = f"Found multiple **active** {infr_type} infractions in the database." - if "Note" in log_text: - log_text["Note"] = f" {log_note}" - else: - log_text["Note"] = log_note - - # deactivate_infraction() is not called again because: - # 1. Discord cannot store multiple active bans or assign multiples of the same role - # 2. It would send a pardon DM for each active infraction, which is redundant - for infraction in response[1:]: - id_ = infraction['id'] - try: - # Mark infraction as inactive in the database. - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") - # This is simpler and cleaner than trying to concatenate all the errors. - log_text["Failure"] = "See bot's logs for details." - - # Cancel pending expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Accordingly display whether the user was successfully notified via DM. - dm_emoji = "" - if log_text.get("DM") == "Sent": - dm_emoji = ":incoming_envelope: " - elif "DM" in log_text: - dm_emoji = f"{constants.Emojis.failmail} " - - # Accordingly display whether the pardon failed. - if "Failure" in log_text: - confirm_msg = ":x: failed to pardon" - log_title = "pardon failed" - log_content = ctx.author.mention - - log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") - else: - confirm_msg = ":ok_hand: pardoned" - log_title = "pardoned" - - log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") - - # Send a confirmation message to the invoking context. - if send_msg: - log.trace(f"Sending infraction #{id_} pardon confirmation message.") - await ctx.send( - f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " - f"{log_text.get('Failure', '')}" - ) - - # Move reason to end of entry to avoid cutting out some keys - log_text["Reason"] = log_text.pop("Reason") - - # Send a log message to the mod log. - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[infr_type][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=footer, - content=log_content, - ) - - async def deactivate_infraction( - self, - infraction: utils.Infraction, - send_log: bool = True - ) -> t.Dict[str, str]: - """ - Deactivate an active infraction and return a dictionary of lines to send in a mod log. - - The infraction is removed from Discord, marked as inactive in the database, and has its - expiration task cancelled. If `send_log` is True, a mod log is sent for the - deactivation of the infraction. - - Infractions of unsupported types will raise a ValueError. - """ - guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderators) - user_id = infraction["user"] - actor = infraction["actor"] - type_ = infraction["type"] - id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] - - log.info(f"Marking infraction #{id_} as inactive (expired).") - - expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) - - log_content = None - log_text = { - "Member": f"<@{user_id}>", - "Actor": str(self.bot.get_user(actor) or actor), - "Reason": infraction["reason"], - "Created": created, - } - - try: - log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) - - if returned_log is not None: - log_text = {**log_text, **returned_log} # Merge the logs together - else: - raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" - ) - except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") - log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" - log_content = mod_role.mention - except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." - log_content = mod_role.mention - - # Check if the user is currently being watched by Big Brother. - try: - log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") - - active_watch = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "watch", - "user__id": user_id - } - ) - - log_text["Watching"] = "Yes" if active_watch else "No" - except ResponseCodeError: - log.exception(f"Failed to fetch watch status for user {user_id}") - log_text["Watching"] = "Unknown - failed to fetch watch status." - - try: - # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{id_} as inactive in the database.") - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_line = f"API request failed with code {e.status}." - log_content = mod_role.mention - - # Append to an existing failure message if possible - if "Failure" in log_text: - log_text["Failure"] += f" {log_line}" - else: - log_text["Failure"] = log_line - - # Cancel the expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Send a log message to the mod log. - if send_log: - log_title = "expiration failed" if "Failure" in log_text else "expired" - - user = self.bot.get_user(user_id) - avatar = user.avatar_url_as(static_format="png") if user else None - - # Move reason to end so when reason is too long, this is not gonna cut out required items. - log_text["Reason"] = log_text.pop("Reason") - - log.trace(f"Sending deactivation mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[type_][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {type_}", - thumbnail=avatar, - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {id_}", - content=log_content, - ) - - return log_text - - @abstractmethod - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - raise NotImplementedError - - def schedule_expiration(self, infraction: utils.Infraction) -> None: - """ - Marks an infraction expired after the delay from time of scheduling to time of expiration. - - At the time of expiration, the infraction is marked as inactive on the website and the - expiration task is cancelled. - """ - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py deleted file mode 100644 index 867de815a..000000000 --- a/bot/cogs/moderation/superstarify.py +++ /dev/null @@ -1,239 +0,0 @@ -import json -import logging -import random -import textwrap -import typing as t -from pathlib import Path - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.converters import Expiry -from bot.utils.checks import with_role_check -from bot.utils.time import format_infraction -from . import utils -from .scheduler import InfractionScheduler - -log = logging.getLogger(__name__) -NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" - -with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: - STAR_NAMES = json.load(stars_file) - - -class Superstarify(InfractionScheduler, Cog): - """A set of commands to moderate terrible nicknames.""" - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"superstar"}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Revert nickname edits if the user has an active superstarify infraction.""" - if before.display_name == after.display_name: - return # User didn't change their nickname. Abort! - - log.trace( - f"{before} ({before.display_name}) is trying to change their nickname to " - f"{after.display_name}. Checking if the user is in superstar-prison..." - ) - - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": str(before.id) - } - ) - - if not active_superstarifies: - log.trace(f"{before} has no active superstar infractions.") - return - - infraction = active_superstarifies[0] - forced_nick = self.get_nick(infraction["id"], before.id) - if after.display_name == forced_nick: - return # Nick change was triggered by this event. Ignore. - - log.info( - f"{after.display_name} ({after.id}) tried to escape superstar prison. " - f"Changing the nick back to {before.display_name}." - ) - await after.edit( - nick=forced_nick, - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - notified = await utils.notify_infraction( - user=after, - infr_type="Superstarify", - expires_at=format_infraction(infraction["expires_at"]), - reason=( - "You have tried to change your nickname on the **Python Discord** server " - f"from **{before.display_name}** to **{after.display_name}**, but as you " - "are currently in superstar-prison, you do not have permission to do so." - ), - icon_url=utils.INFRACTION_ICONS["superstar"][0] - ) - - if not notified: - log.info("Failed to DM user about why they cannot change their nickname.") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active superstar infractions for returning members.""" - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": member.id - } - ) - - if active_superstarifies: - infraction = active_superstarifies[0] - action = member.edit( - nick=self.get_nick(infraction["id"], member.id), - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - await self.reapply_infraction(infraction, action) - - @command(name="superstarify", aliases=("force_nick", "star")) - async def superstarify( - self, - ctx: Context, - member: Member, - duration: Expiry, - *, - reason: str = None, - ) -> None: - """ - Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - - An optional reason can be provided. If no reason is given, the original name will be shown - in a generated reason. - """ - if await utils.get_active_infraction(ctx, member, "superstar"): - return - - # Post the infraction to the API - reason = reason or f"old nick: {member.display_name}" - infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) - id_ = infraction["id"] - - old_nick = member.display_name - forced_nick = self.get_nick(id_, member.id) - expiry_str = format_infraction(infraction["expires_at"]) - - # Apply the infraction and schedule the expiration task. - log.debug(f"Changing nickname of {member} to {forced_nick}.") - self.mod_log.ignore(constants.Event.member_update, member.id) - await member.edit(nick=forced_nick, reason=reason) - self.schedule_expiration(infraction) - - # Send a DM to the user to notify them of their new infraction. - await utils.notify_infraction( - user=member, - infr_type="Superstarify", - expires_at=expiry_str, - icon_url=utils.INFRACTION_ICONS["superstar"][0], - reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." - ) - - # Send an embed with the infraction information to the invoking context. - log.trace(f"Sending superstar #{id_} embed.") - embed = Embed( - title="Congratulations!", - colour=constants.Colours.soft_orange, - description=( - f"Your previous nickname, **{old_nick}**, " - f"was so bad that we have decided to change it. " - f"Your new nickname will be **{forced_nick}**.\n\n" - f"You will be unable to change your nickname until **{expiry_str}**.\n\n" - "If you're confused by this, please read our " - f"[official nickname policy]({NICKNAME_POLICY_URL})." - ) - ) - await ctx.send(embed=embed) - - # Log to the mod log channel. - log.trace(f"Sending apply mod log for superstar #{id_}.") - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS["superstar"][0], - colour=Colour.gold(), - title="Member achieved superstardom", - thumbnail=member.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {member.mention} (`{member.id}`) - Actor: {ctx.message.author} - Expires: {expiry_str} - Old nickname: `{old_nick}` - New nickname: `{forced_nick}` - Reason: {reason} - """), - footer=f"ID {id_}" - ) - - @command(name="unsuperstarify", aliases=("release_nick", "unstar")) - async def unsuperstarify(self, ctx: Context, member: Member) -> None: - """Remove the superstarify infraction and allow the user to change their nickname.""" - await self.pardon_infraction(ctx, "superstar", member) - - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """Pardon a superstar infraction and return a log dict.""" - if infraction["type"] != "superstar": - return - - guild = self.bot.get_guild(constants.Guild.id) - user = guild.get_member(infraction["user"]) - - # Don't bother sending a notification if the user left the guild. - if not user: - log.debug( - "User left the guild and therefore won't be notified about superstar " - f"{infraction['id']} pardon." - ) - return {} - - # DM the user about the expiration. - notified = await utils.notify_pardon( - user=user, - title="You are no longer superstarified", - content="You may now change your nickname on the server.", - icon_url=utils.INFRACTION_ICONS["superstar"][1] - ) - - return { - "Member": f"{user.mention}(`{user.id}`)", - "DM": "Sent" if notified else "**Failed**" - } - - @staticmethod - def get_nick(infraction_id: int, member_id: int) -> str: - """Randomly select a nickname from the Superstarify nickname list.""" - log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") - - rng = random.Random(str(infraction_id) + str(member_id)) - return rng.choice(STAR_NAMES) - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py deleted file mode 100644 index fb55287b6..000000000 --- a/bot/cogs/moderation/utils.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext.commands import Context - -from bot.api import ResponseCodeError -from bot.constants import Colours, Icons - -log = logging.getLogger(__name__) - -# apply icon, pardon icon -INFRACTION_ICONS = { - "ban": (Icons.user_ban, Icons.user_unban), - "kick": (Icons.sign_out, None), - "mute": (Icons.user_mute, Icons.user_unmute), - "note": (Icons.user_warn, None), - "superstar": (Icons.superstarify, Icons.unsuperstarify), - "warning": (Icons.user_warn, None), -} -RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") - -# Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object] -Infraction = t.Dict[str, t.Union[str, int, bool]] - - -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: - """ - Create a new user in the database. - - Used when an infraction needs to be applied on a user absent in the guild. - """ - log.trace(f"Attempting to add user {user.id} to the database.") - - if not isinstance(user, (discord.Member, discord.User)): - log.debug("The user being added to the DB is not a Member or User object.") - - payload = { - 'discriminator': int(getattr(user, 'discriminator', 0)), - 'id': user.id, - 'in_guild': False, - 'name': getattr(user, 'name', 'Name unknown'), - 'roles': [] - } - - try: - response = await ctx.bot.api_client.post('bot/users', json=payload) - log.info(f"User {user.id} added to the DB.") - return response - except ResponseCodeError as e: - log.error(f"Failed to add user {user.id} to the DB. {e}") - await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") - - -async def post_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - reason: str, - expires_at: datetime = None, - hidden: bool = False, - active: bool = True -) -> t.Optional[dict]: - """Posts an infraction to the API.""" - log.trace(f"Posting {infr_type} infraction for {user} to the API.") - - payload = { - "actor": ctx.message.author.id, - "hidden": hidden, - "reason": reason, - "type": infr_type, - "user": user.id, - "active": active - } - if expires_at: - payload['expires_at'] = expires_at.isoformat() - - # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. - for should_post_user in (True, False): - try: - response = await ctx.bot.api_client.post('bot/infractions', json=payload) - return response - except ResponseCodeError as e: - if e.status == 400 and 'user' in e.response_json: - # Only one attempt to add the user to the database, not two: - if not should_post_user or await post_user(ctx, user) is None: - return - else: - log.exception(f"Unexpected error while adding an infraction for {user}:") - await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") - return - - -async def get_active_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - send_msg: bool = True -) -> t.Optional[dict]: - """ - Retrieves an active infraction of the given type for the user. - - If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, - then a message for the moderator will be sent to the context channel letting them know. - Otherwise, no message will be sent. - """ - log.trace(f"Checking if {user} has active infractions of type {infr_type}.") - - active_infractions = await ctx.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': str(user.id) - } - ) - if active_infractions: - # Checks to see if the moderator should be told there is an active infraction - if send_msg: - log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) - return active_infractions[0] - else: - log.trace(f"{user} does not have active infractions of type {infr_type}.") - - -async def notify_infraction( - user: UserObject, - infr_type: str, - expires_at: t.Optional[str] = None, - reason: t.Optional[str] = None, - icon_url: str = Icons.token_removed -) -> bool: - """DM a user about their new infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) - - embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), - colour=Colours.soft_red - ) - - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" - embed.url = RULES_URL - - if infr_type in APPEALABLE_INFRACTIONS: - embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" - ) - - return await send_private_embed(user, embed) - - -async def notify_pardon( - user: UserObject, - title: str, - content: str, - icon_url: str = Icons.user_verified -) -> bool: - """DM a user about their pardoned infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their pardoned infraction.") - - embed = discord.Embed( - description=content, - colour=Colours.soft_green - ) - - embed.set_author(name=title, icon_url=icon_url) - - return await send_private_embed(user, embed) - - -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: - """ - A helper method for sending an embed to a user's DMs. - - Returns a boolean indicator of DM success. - """ - try: - await user.send(embed=embed) - return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): - log.debug( - f"Infraction-related information could not be sent to user {user} ({user.id}). " - "The user either could not be retrieved or probably disabled their DMs." - ) - return False diff --git a/bot/cogs/moderation/verification.py b/bot/cogs/moderation/verification.py new file mode 100644 index 000000000..ae156cf70 --- /dev/null +++ b/bot/cogs/moderation/verification.py @@ -0,0 +1,191 @@ +import logging +from contextlib import suppress + +from discord import Colour, Forbidden, Message, NotFound, Object +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.decorators import in_whitelist, without_role +from bot.utils.checks import InWhitelistCheckFailure, without_role_check + +log = logging.getLogger(__name__) + +WELCOME_MESSAGE = f""" +Hello! Welcome to the server, and thanks for verifying yourself! + +For your records, these are the documents you accepted: + +`1)` Our rules, here: +`2)` Our privacy policy, here: - you can find information on how to have \ +your information removed here as well. + +Feel free to review them at any point! + +Additionally, if you'd like to receive notifications for the announcements \ +we post in <#{constants.Channels.announcements}> +from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ +to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. + +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{constants.Channels.bot_commands}>. +""" + +BOT_MESSAGE_DELETE_DELAY = 10 + + +class Verification(Cog): + """User verification and role self-management.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Check new message event for messages to the checkpoint channel & process.""" + if message.channel.id != constants.Channels.verification: + return # Only listen for #checkpoint messages + + if message.author.bot: + # They're a bot, delete their message after the delay. + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return + + # if a user mentions a role or guild member + # alert the mods in mod-alerts channel + if message.mentions or message.role_mentions: + log.debug( + f"{message.author} mentioned one or more users " + f"and/or roles in {message.channel.name}" + ) + + embed_text = ( + f"{message.author.mention} sent a message in " + f"{message.channel.mention} that contained user and/or role mentions." + f"\n\n**Original message:**\n>>> {message.content}" + ) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=constants.Icons.filtering, + colour=Colour(constants.Colours.soft_red), + title=f"User/Role mentioned in {message.channel.name}", + text=embed_text, + thumbnail=message.author.avatar_url_as(static_format="png"), + channel_id=constants.Channels.mod_alerts, + ) + + ctx: Context = await self.bot.get_context(message) + if ctx.command is not None and ctx.command.name == "accept": + return + + if any(r.id == constants.Roles.verified for r in ctx.author.roles): + log.info( + f"{ctx.author} posted '{ctx.message.content}' " + "in the verification channel, but is already verified." + ) + return + + log.debug( + f"{ctx.author} posted '{ctx.message.content}' in the verification " + "channel. We are providing instructions how to verify." + ) + await ctx.send( + f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " + f"and gain access to the rest of the server.", + delete_after=20 + ) + + log.trace(f"Deleting the message posted by {ctx.author}") + with suppress(NotFound): + await ctx.message.delete() + + @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) + @without_role(constants.Roles.verified) + @in_whitelist(channels=(constants.Channels.verification,)) + async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Accept our rules and gain access to the rest of the server.""" + log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") + await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") + try: + await ctx.author.send(WELCOME_MESSAGE) + except Forbidden: + log.info(f"Sending welcome message failed for {ctx.author}.") + finally: + log.trace(f"Deleting accept message by {ctx.author}.") + with suppress(NotFound): + self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) + await ctx.message.delete() + + @command(name='subscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Subscribe to announcement notifications by assigning yourself the role.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if has_role: + await ctx.send(f"{ctx.author.mention} You're already subscribed!") + return + + log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") + await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", + ) + + @command(name='unsubscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Unsubscribe from announcement notifications by removing the role from yourself.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if not has_role: + await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") + return + + log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") + await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." + ) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Check for & ignore any InWhitelistCheckFailure.""" + if isinstance(error, InWhitelistCheckFailure): + error.handled = True + + @staticmethod + def bot_check(ctx: Context) -> bool: + """Block any command within the verification channel that is not !accept.""" + if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): + return ctx.command.name == "accept" + else: + return True + + +def setup(bot: Bot) -> None: + """Load the Verification cog.""" + bot.add_cog(Verification(bot)) diff --git a/bot/cogs/moderation/watchchannels/__init__.py b/bot/cogs/moderation/watchchannels/__init__.py new file mode 100644 index 000000000..69d118df6 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/__init__.py @@ -0,0 +1,9 @@ +from bot.bot import Bot +from .bigbrother import BigBrother +from .talentpool import TalentPool + + +def setup(bot: Bot) -> None: + """Load the BigBrother and TalentPool cogs.""" + bot.add_cog(BigBrother(bot)) + bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/moderation/watchchannels/bigbrother.py b/bot/cogs/moderation/watchchannels/bigbrother.py new file mode 100644 index 000000000..0c72e88f7 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/bigbrother.py @@ -0,0 +1,165 @@ +import logging +import textwrap +from collections import ChainMap + +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.cogs.moderation.infraction.utils import post_infraction +from bot.constants import Channels, MODERATION_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from .watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class BigBrother(WatchChannel, Cog, name="Big Brother"): + """Monitors users by relaying their messages to a watch channel to assist with moderation.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.big_brother_logs, + webhook_id=Webhooks.big_brother, + api_endpoint='bot/infractions', + api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, + logger=log + ) + + @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def bigbrother_group(self, ctx: Context) -> None: + """Monitors users by relaying their messages to the Big Brother watch channel.""" + await ctx.send_help(ctx.command) + + @bigbrother_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored by Big Brother. + + The optional kwarg `oldest_first` can be used to order the list by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @bigbrother_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows Big Brother monitored users ordered by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @bigbrother_group.command(name='watch', aliases=('w',)) + @with_role(*MODERATION_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#big-brother` channel. + + A `reason` for adding the user to Big Brother is required and will be displayed + in the header when relaying messages of this user to the watchchannel. + """ + await self.apply_watch(ctx, user, reason) + + @bigbrother_group.command(name='unwatch', aliases=('uw',)) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Stop relaying messages by the given `user`.""" + await self.apply_unwatch(ctx, user, reason) + + async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: + """ + Add `user` to watched users and apply a watch infraction with `reason`. + + A message indicating the result of the operation is sent to `ctx`. + The message will include `user`'s previous watch infraction history, if it exists. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched.") + return + + response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) + + if response is not None: + self.watched_users[user.id] = response + msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + 'type': 'watch', + 'ordering': '-inserted_at' + } + ) + + if len(history) > 1: + total = f"({len(history) // 2} previous infractions in total)" + end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") + start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + else: + msg = ":x: Failed to post the infraction: response was empty." + + await ctx.send(msg) + + async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: + """ + Remove `user` from watched users and mark their infraction as inactive with `reason`. + + If `send_message` is True, a message indicating the result of the operation is sent to + `ctx`. + """ + active_watches = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + if active_watches: + log.trace("Active watches for user found. Attempting to remove.") + [infraction] = active_watches + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{infraction['id']}", + json={'active': False} + ) + + await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) + + self._remove_user(user.id) + + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"Perma-banned user {user} was unwatched.") + return + log.trace("User is not banned. Sending message to channel") + message = f":white_check_mark: Messages sent by {user} will no longer be relayed." + + else: + log.trace("No active watches found for user.") + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"{user} was not on the watch list; no removal necessary.") + return + log.trace("User is not perma banned. Send the error message.") + message = ":x: The specified user is currently not being watched." + + await ctx.send(message) diff --git a/bot/cogs/moderation/watchchannels/talentpool.py b/bot/cogs/moderation/watchchannels/talentpool.py new file mode 100644 index 000000000..89256e92e --- /dev/null +++ b/bot/cogs/moderation/watchchannels/talentpool.py @@ -0,0 +1,264 @@ +import logging +import textwrap +from collections import ChainMap + +from discord import Color, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils import time +from .watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class TalentPool(WatchChannel, Cog, name="Talentpool"): + """Relays messages of helper candidates to a watch channel to observe them.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.talent_pool, + webhook_id=Webhooks.talent_pool, + api_endpoint='bot/nominations', + api_default_params={'active': 'true', 'ordering': '-inserted_at'}, + logger=log, + ) + + @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_group(self, ctx: Context) -> None: + """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" + await ctx.send_help(ctx.command) + + @nomination_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored in the talent pool. + + The optional kwarg `oldest_first` can be used to order the list by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @nomination_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows talent pool monitored users ordered by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) + @with_role(*STAFF_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#talent-pool` channel. + + A `reason` for adding the user to the talent pool is required and will be displayed + in the header when relaying messages of this user to the channel. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): + await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update the user cache; can't add {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched in the talent pool") + return + + # Manual request with `raise_for_status` as False because we want the actual response + session = self.bot.api_client.session + url = self.bot.api_client._url_for(self.api_endpoint) + kwargs = { + 'json': { + 'actor': ctx.author.id, + 'reason': reason, + 'user': user.id + }, + 'raise_for_status': False, + } + async with session.post(url, **kwargs) as resp: + response_data = await resp.json() + + if resp.status == 400 and response_data.get('user', False): + await ctx.send(":x: The specified user can't be found in the database tables") + return + else: + resp.raise_for_status() + + self.watched_users[user.id] = response_data + msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + "ordering": "-inserted_at" + } + ) + + if history: + total = f"({len(history)} previous nominations in total)" + start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" + end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + + await ctx.send(msg) + + @nomination_group.command(name='history', aliases=('info', 'search')) + @with_role(*MODERATION_ROLES) + async def history_command(self, ctx: Context, user: FetchedMember) -> None: + """Shows the specified user's nomination history.""" + result = await self.bot.api_client.get( + self.api_endpoint, + params={ + 'user__id': str(user.id), + 'ordering': "-active,-inserted_at" + } + ) + if not result: + await ctx.send(":warning: This user has never been nominated") + return + + embed = Embed( + title=f"Nominations for {user.display_name} `({user.id})`", + color=Color.blue() + ) + lines = [self._nomination_to_string(nomination) for nomination in result] + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + @nomination_group.command(name='unwatch', aliases=('end', )) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Ends the active nomination of the specified user with the given reason. + + Providing a `reason` is required. + """ + active_nomination = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + + if not active_nomination: + await ctx.send(":x: The specified user does not have an active nomination") + return + + [nomination] = active_nomination + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination['id']}", + json={'end_reason': reason, 'active': False} + ) + await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") + self._remove_user(user.id) + + @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_edit_group(self, ctx: Context) -> None: + """Commands to edit nominations.""" + await ctx.send_help(ctx.command) + + @nomination_edit_group.command(name='reason') + @with_role(*MODERATION_ROLES) + async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: + """ + Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. + + If the nomination is active, the reason for nominating the user will be edited; + If the nomination is no longer active, the reason for ending the nomination will be edited instead. + """ + try: + nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") + except ResponseCodeError as e: + if e.response.status == 404: + self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") + await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") + return + else: + raise + + field = "reason" if nomination["active"] else "end_reason" + + self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination_id}", + json={field: reason} + ) + + await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") + + def _nomination_to_string(self, nomination_object: dict) -> str: + """Creates a string representation of a nomination.""" + guild = self.bot.get_guild(Guild.id) + + actor_id = nomination_object["actor"] + actor = guild.get_member(actor_id) + + active = nomination_object["active"] + log.debug(active) + log.debug(type(nomination_object["inserted_at"])) + + start_date = time.format_infraction(nomination_object["inserted_at"]) + if active: + lines = textwrap.dedent( + f""" + =============== + Status: **Active** + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + else: + end_date = time.format_infraction(nomination_object["ended_at"]) + lines = textwrap.dedent( + f""" + =============== + Status: Inactive + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + + End date: {end_date} + Unwatch reason: {nomination_object["end_reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + + return lines.strip() diff --git a/bot/cogs/moderation/watchchannels/watchchannel.py b/bot/cogs/moderation/watchchannels/watchchannel.py new file mode 100644 index 000000000..044077350 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/watchchannel.py @@ -0,0 +1,348 @@ +import asyncio +import logging +import re +import textwrap +from abc import abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Optional + +import dateutil.parser +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons +from bot.pagination import LinePaginator +from bot.utils import CogABCMeta, messages +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + +URL_RE = re.compile(r"(https?://[^\s]+)") + + +@dataclass +class MessageHistory: + """Represents a watch channel's message history.""" + + last_author: Optional[int] = None + last_channel: Optional[int] = None + message_count: int = 0 + + +class WatchChannel(metaclass=CogABCMeta): + """ABC with functionality for relaying users' messages to a certain channel.""" + + @abstractmethod + def __init__( + self, + bot: Bot, + destination: int, + webhook_id: int, + api_endpoint: str, + api_default_params: dict, + logger: logging.Logger + ) -> None: + self.bot = bot + + self.destination = destination # E.g., Channels.big_brother_logs + self.webhook_id = webhook_id # E.g., Webhooks.big_brother + self.api_endpoint = api_endpoint # E.g., 'bot/infractions' + self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} + self.log = logger # Logger of the child cog for a correct name in the logs + + self._consume_task = None + self.watched_users = defaultdict(dict) + self.message_queue = defaultdict(lambda: defaultdict(deque)) + self.consumption_queue = {} + self.retries = 5 + self.retry_delay = 10 + self.channel = None + self.webhook = None + self.message_history = MessageHistory() + + self._start = self.bot.loop.create_task(self.start_watchchannel()) + + @property + def modlog(self) -> ModLog: + """Provides access to the ModLog cog for alert purposes.""" + return self.bot.get_cog("ModLog") + + @property + def consuming_messages(self) -> bool: + """Checks if a consumption task is currently running.""" + if self._consume_task is None: + return False + + if self._consume_task.done(): + exc = self._consume_task.exception() + if exc: + self.log.exception( + "The message queue consume task has failed with:", + exc_info=exc + ) + return False + + return True + + async def start_watchchannel(self) -> None: + """Starts the watch channel by getting the channel, webhook, and user cache ready.""" + await self.bot.wait_until_guild_available() + + try: + self.channel = await self.bot.fetch_channel(self.destination) + except HTTPException: + self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + if self.channel is None or self.webhook is None: + self.log.error("Failed to start the watch channel; unloading the cog.") + + message = textwrap.dedent( + f""" + An error occurred while loading the text channel or webhook. + + TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} + Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} + + The Cog has been unloaded. + """ + ) + + await self.modlog.send_log_message( + title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", + text=message, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + if not await self.fetch_user_cache(): + await self.modlog.send_log_message( + title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", + text="Could not retrieve the list of watched users from the API and messages will not be relayed.", + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + async def fetch_user_cache(self) -> bool: + """ + Fetches watched users from the API and updates the watched user cache accordingly. + + This function returns `True` if the update succeeded. + """ + try: + data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) + except ResponseCodeError as err: + self.log.exception("Failed to fetch the watched users from the API", exc_info=err) + return False + + self.watched_users = defaultdict(dict) + + for entry in data: + user_id = entry.pop('user') + self.watched_users[user_id] = entry + + return True + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Queues up messages sent by watched users.""" + if msg.author.id in self.watched_users: + if not self.consuming_messages: + self._consume_task = self.bot.loop.create_task(self.consume_messages()) + + self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") + self.message_queue[msg.author.id][msg.channel.id].append(msg) + + async def consume_messages(self, delay_consumption: bool = True) -> None: + """Consumes the message queues to log watched users' messages.""" + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) + + self.log.trace("Started consuming the message queue") + + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() + + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() + + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) + + self.consumption_queue.clear() + + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + + async def webhook_send( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Sends a message to the webhook with the specified kwargs.""" + username = messages.sub_clyde(username) + try: + await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send a message to the webhook", + exc_info=exc + ) + + async def relay_message(self, msg: Message) -> None: + """Relays the message to the relevant watch channel.""" + limit = BigBrotherConfig.header_message_limit + + if ( + msg.author.id != self.message_history.last_author + or msg.channel.id != self.message_history.last_channel + or self.message_history.message_count >= limit + ): + self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) + + await self.send_header(msg) + + cleaned_content = msg.clean_content + + if cleaned_content: + # Put all non-media URLs in a code block to prevent embeds + media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} + for url in URL_RE.findall(cleaned_content): + if url not in media_urls: + cleaned_content = cleaned_content.replace(url, f"`{url}`") + await self.webhook_send( + cleaned_content, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + + if msg.attachments: + try: + await messages.send_attachments(msg, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.webhook_send( + embed=e, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send an attachment to the webhook", + exc_info=exc + ) + + self.message_history.message_count += 1 + + async def send_header(self, msg: Message) -> None: + """Sends a header embed with information about the relayed messages to the watch channel.""" + user_id = msg.author.id + + guild = self.bot.get_guild(GuildConfig.id) + actor = guild.get_member(self.watched_users[user_id]['actor']) + actor = actor.display_name if actor else self.watched_users[user_id]['actor'] + + inserted_at = self.watched_users[user_id]['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + + reason = self.watched_users[user_id]['reason'] + + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + footer = f"Added {time_delta} by {actor} | Reason: {reason}" + embed = Embed(description=f"{msg.author.mention} {message_jump}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) + + await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) + + async def list_watched_users( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Gives an overview of the watched user list for this channel. + + The optional kwarg `oldest_first` orders the list by oldest entry. + + The optional kwarg `update_cache` specifies whether the cache should + be refreshed by polling the API. + """ + if update_cache: + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") + update_cache = False + + lines = [] + for user_id, user_data in self.watched_users.items(): + inserted_at = user_data['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + lines.append(f"• <@{user_id}> (added {time_delta})") + + if oldest_first: + lines.reverse() + + lines = lines or ("There's nothing here yet.",) + + embed = Embed( + title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", + color=Color.blue() + ) + await LinePaginator.paginate(lines, ctx, embed, empty=False) + + @staticmethod + def _get_time_delta(time_string: str) -> str: + """Returns the time in human-readable time delta format.""" + date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) + time_delta = time_since(date_time, precision="minutes", max_units=1) + + return time_delta + + def _remove_user(self, user_id: int) -> None: + """Removes a user from a watch channel.""" + self.watched_users.pop(user_id, None) + self.message_queue.pop(user_id, None) + self.consumption_queue.pop(user_id, None) + + def cog_unload(self) -> None: + """Takes care of unloading the cog and canceling the consumption task.""" + self.log.trace("Unloading the cog") + if self._consume_task and not self._consume_task.done(): + self._consume_task.cancel() + try: + self._consume_task.result() + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) diff --git a/bot/cogs/python_news.py b/bot/cogs/python_news.py deleted file mode 100644 index 0ab5738a4..000000000 --- a/bot/cogs/python_news.py +++ /dev/null @@ -1,232 +0,0 @@ -import logging -import typing as t -from datetime import date, datetime - -import discord -import feedparser -from bs4 import BeautifulSoup -from discord.ext.commands import Cog -from discord.ext.tasks import loop - -from bot import constants -from bot.bot import Bot -from bot.utils.webhooks import send_webhook - -PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" - -RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" -THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" -MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" -THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" - -AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - -log = logging.getLogger(__name__) - - -class PythonNews(Cog): - """Post new PEPs and Python News to `#python-news`.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_names = {} - self.webhook: t.Optional[discord.Webhook] = None - - self.bot.loop.create_task(self.get_webhook_names()) - self.bot.loop.create_task(self.get_webhook_and_channel()) - - async def start_tasks(self) -> None: - """Start the tasks for fetching new PEPs and mailing list messages.""" - self.fetch_new_media.start() - - @loop(minutes=20) - async def fetch_new_media(self) -> None: - """Fetch new mailing list messages and then new PEPs.""" - await self.post_maillist_news() - await self.post_pep_news() - - async def sync_maillists(self) -> None: - """Sync currently in-use maillists with API.""" - # Wait until guild is available to avoid running before everything is ready - await self.bot.wait_until_guild_available() - - response = await self.bot.api_client.get("bot/bot-settings/news") - for mail in constants.PythonNews.mail_lists: - if mail not in response["data"]: - response["data"][mail] = [] - - # Because we are handling PEPs differently, we don't include it to mail lists - if "pep" not in response["data"]: - response["data"]["pep"] = [] - - await self.bot.api_client.put("bot/bot-settings/news", json=response) - - async def get_webhook_names(self) -> None: - """Get webhook author names from maillist API.""" - await self.bot.wait_until_guild_available() - - async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: - lists = await resp.json() - - for mail in lists: - if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: - self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] - - async def post_pep_news(self) -> None: - """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" - # Wait until everything is ready and http_session available - await self.bot.wait_until_guild_available() - await self.sync_maillists() - - async with self.bot.http_session.get(PEPS_RSS_URL) as resp: - data = feedparser.parse(await resp.text("utf-8")) - - news_listing = await self.bot.api_client.get("bot/bot-settings/news") - payload = news_listing.copy() - pep_numbers = news_listing["data"]["pep"] - - # Reverse entries to send oldest first - data["entries"].reverse() - for new in data["entries"]: - try: - new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") - except ValueError: - log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") - continue - pep_nr = new["title"].split(":")[0].split()[1] - if ( - pep_nr in pep_numbers - or new_datetime.date() < date.today() - ): - continue - - # Build an embed and send a webhook - embed = discord.Embed( - title=new["title"], - description=new["summary"], - timestamp=new_datetime, - url=new["link"], - colour=constants.Colours.soft_green - ) - embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) - msg = await send_webhook( - webhook=self.webhook, - username=data["feed"]["title"], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"]["pep"].append(pep_nr) - - # Increase overall PEP new stat - self.bot.stats.incr("python_news.posted.pep") - - if msg.channel.is_news(): - log.trace("Publishing PEP annnouncement because it was in a news channel") - await msg.publish() - - # Apply new sent news to DB to avoid duplicate sending - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def post_maillist_news(self) -> None: - """Send new maillist threads to #python-news that is listed in configuration.""" - await self.bot.wait_until_guild_available() - await self.sync_maillists() - existing_news = await self.bot.api_client.get("bot/bot-settings/news") - payload = existing_news.copy() - - for maillist in constants.PythonNews.mail_lists: - async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: - recents = BeautifulSoup(await resp.text(), features="lxml") - - # When a

element is present in the response then the mailing list - # has not had any activity during the current month, so therefore it - # can be ignored. - if recents.p: - continue - - for thread in recents.html.body.div.find_all("a", href=True): - # We want only these threads that have identifiers - if "latest" in thread["href"]: - continue - - thread_information, email_information = await self.get_thread_and_first_mail( - maillist, thread["href"].split("/")[-2] - ) - - try: - new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") - except ValueError: - log.warning(f"Invalid datetime from Thread email: {email_information['date']}") - continue - - if ( - thread_information["thread_id"] in existing_news["data"][maillist] - or 'Re: ' in thread_information["subject"] - or new_date.date() < date.today() - ): - continue - - content = email_information["content"] - link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) - - # Build an embed and send a message to the webhook - embed = discord.Embed( - title=thread_information["subject"], - description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, - timestamp=new_date, - url=link, - colour=constants.Colours.soft_green - ) - embed.set_author( - name=f"{email_information['sender_name']} ({email_information['sender']['address']})", - url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), - ) - embed.set_footer( - text=f"Posted to {self.webhook_names[maillist]}", - icon_url=AVATAR_URL, - ) - msg = await send_webhook( - webhook=self.webhook, - username=self.webhook_names[maillist], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"][maillist].append(thread_information["thread_id"]) - - # Increase this specific maillist counter in stats - self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") - - if msg.channel.is_news(): - log.trace("Publishing mailing list message because it was in a news channel") - await msg.publish() - - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: - """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" - async with self.bot.http_session.get( - THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) - ) as resp: - thread_information = await resp.json() - - async with self.bot.http_session.get(thread_information["starting_email"]) as resp: - email_information = await resp.json() - return thread_information, email_information - - async def get_webhook_and_channel(self) -> None: - """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" - await self.bot.wait_until_guild_available() - self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) - - await self.start_tasks() - - def cog_unload(self) -> None: - """Stop news posting tasks on cog unload.""" - self.fetch_new_media.cancel() - - -def setup(bot: Bot) -> None: - """Add `News` cog.""" - bot.add_cog(PythonNews(bot)) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py deleted file mode 100644 index d853ab2ea..000000000 --- a/bot/cogs/reddit.py +++ /dev/null @@ -1,304 +0,0 @@ -import asyncio -import logging -import random -import textwrap -from collections import namedtuple -from datetime import datetime, timedelta -from typing import List - -from aiohttp import BasicAuth, ClientError -from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks -from bot.converters import Subreddit -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - - -class Reddit(Cog): - """Track subreddit posts and show detailed statistics about them.""" - - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} - URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 - - def __init__(self, bot: Bot): - self.bot = bot - - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - - bot.loop.create_task(self.init_reddit_ready()) - self.auto_poster_loop.start() - - def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" - self.auto_poster_loop.cancel() - if self.access_token and self.access_token.expires_at > datetime.utcnow(): - asyncio.create_task(self.revoke_access_token()) - - async def init_reddit_ready(self) -> None: - """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_guild_available() - if not self.webhook: - self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) - - @property - def channel(self) -> TextChannel: - """Get the #reddit channel object from the bot's cache.""" - return self.bot.get_channel(Channels.reddit) - - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: - """A helper method to fetch a certain amount of Reddit posts at a given route.""" - # Reddit's JSON responses only provide 25 posts at most. - if not 25 >= amount > 0: - raise ValueError("Invalid amount of subreddit posts requested.") - - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): - response = await self.bot.http_session.get( - url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, - params=params - ) - if response.status == 200 and response.content_type == 'application/json': - # Got appropriate response - process and return. - content = await response.json() - posts = content["data"]["children"] - return posts[:amount] - - await asyncio.sleep(3) - - log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") - return list() # Failed to get appropriate response within allowed number of retries. - - async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: - """ - Get the top amount of posts for a given subreddit within a specified timeframe. - - A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top - weekly posts. - - The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. - """ - embed = Embed(description="") - - posts = await self.fetch_posts( - route=f"{subreddit}/top", - amount=amount, - params={"t": time} - ) - - if not posts: - embed.title = random.choice(ERROR_REPLIES) - embed.colour = Colour.red() - embed.description = ( - "Sorry! We couldn't find any posts from that subreddit. " - "If this problem persists, please let us know." - ) - - return embed - - for post in posts: - data = post["data"] - - text = data["selftext"] - if text: - text = textwrap.shorten(text, width=128, placeholder="...") - text += "\n" # Add newline to separate embed info - - ups = data["ups"] - comments = data["num_comments"] - author = data["author"] - - title = textwrap.shorten(data["title"], width=64, placeholder="...") - link = self.URL + data["permalink"] - - embed.description += ( - f"**[{title}]({link})**\n" - f"{text}" - f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" - ) - - embed.colour = Colour.blurple() - return embed - - @loop() - async def auto_poster_loop(self) -> None: - """Post the top 5 posts daily, and the top 5 posts weekly.""" - # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter - now = datetime.utcnow() - tomorrow = now + timedelta(days=1) - midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) - seconds_until = (midnight_tomorrow - now).total_seconds() - - await asyncio.sleep(seconds_until) - - await self.bot.wait_until_guild_available() - if not self.webhook: - await self.bot.fetch_webhook(Webhooks.reddit) - - if datetime.utcnow().weekday() == 0: - await self.top_weekly_posts() - # if it's a monday send the top weekly posts - - for subreddit in RedditConfig.subreddits: - top_posts = await self.get_top_posts(subreddit=subreddit, time="day") - username = sub_clyde(f"{subreddit} Top Daily Posts") - message = await self.webhook.send(username=username, embed=top_posts, wait=True) - - if message.channel.is_news(): - await message.publish() - - async def top_weekly_posts(self) -> None: - """Post a summary of the top posts.""" - for subreddit in RedditConfig.subreddits: - # Send and pin the new weekly posts. - top_posts = await self.get_top_posts(subreddit=subreddit, time="week") - username = sub_clyde(f"{subreddit} Top Weekly Posts") - message = await self.webhook.send(wait=True, username=username, embed=top_posts) - - if subreddit.lower() == "r/python": - if not self.channel: - log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") - return - - # Remove the oldest pins so that only 12 remain at most. - pins = await self.channel.pins() - - while len(pins) >= 12: - await pins[-1].unpin() - del pins[-1] - - await message.pin() - - if message.channel.is_news(): - await message.publish() - - @group(name="reddit", invoke_without_command=True) - async def reddit_group(self, ctx: Context) -> None: - """View the top posts from various subreddits.""" - await ctx.send_help(ctx.command) - - @reddit_group.command(name="top") - async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of all time from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="all") - - await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - - @reddit_group.command(name="daily") - async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of today from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="day") - - await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - - @reddit_group.command(name="weekly") - async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of this week from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="week") - - await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - - @with_role(*STAFF_ROLES) - @reddit_group.command(name="subreddits", aliases=("subs",)) - async def subreddits_command(self, ctx: Context) -> None: - """Send a paginated embed of all the subreddits we're relaying.""" - embed = Embed() - embed.title = "Relayed subreddits." - embed.colour = Colour.blurple() - - await LinePaginator.paginate( - RedditConfig.subreddits, - ctx, embed, - footer_text="Use the reddit commands along with these to view their posts.", - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Reddit cog.""" - if not RedditConfig.secret or not RedditConfig.client_id: - log.error("Credentials not provided, cog not loaded.") - return - bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py deleted file mode 100644 index 670493bcf..000000000 --- a/bot/cogs/reminders.py +++ /dev/null @@ -1,427 +0,0 @@ -import asyncio -import logging -import random -import textwrap -import typing as t -from datetime import datetime, timedelta -from operator import itemgetter - -import discord -from dateutil.parser import isoparse -from dateutil.relativedelta import relativedelta -from discord.ext.commands import Cog, Context, Greedy, group - -from bot.bot import Bot -from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES -from bot.converters import Duration -from bot.pagination import LinePaginator -from bot.utils.checks import without_role_check -from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -WHITELISTED_CHANNELS = Guild.reminder_whitelist -MAXIMUM_REMINDERS = 5 - -Mentionable = t.Union[discord.Member, discord.Role] - - -class Reminders(Cog): - """Provide in-channel reminder functionality.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_reminders()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - async def reschedule_reminders(self) -> None: - """Get all current reminders from the API and reschedule them.""" - await self.bot.wait_until_guild_available() - response = await self.bot.api_client.get( - 'bot/reminders', - params={'active': 'true'} - ) - - now = datetime.utcnow() - - for reminder in response: - is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) - if not is_valid: - continue - - remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) - - # If the reminder is already overdue ... - if remind_at < now: - late = relativedelta(now, remind_at) - await self.send_reminder(reminder, late) - else: - self.schedule_reminder(reminder) - - def ensure_valid_reminder( - self, - reminder: dict, - cancel_task: bool = True - ) -> t.Tuple[bool, discord.User, discord.TextChannel]: - """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" - user = self.bot.get_user(reminder['author']) - channel = self.bot.get_channel(reminder['channel_id']) - is_valid = True - if not user or not channel: - is_valid = False - log.info( - f"Reminder {reminder['id']} invalid: " - f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." - ) - asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) - - return is_valid, user, channel - - @staticmethod - async def _send_confirmation( - ctx: Context, - on_success: str, - reminder_id: str, - delivery_dt: t.Optional[datetime], - ) -> None: - """Send an embed confirming the reminder change was made successfully.""" - embed = discord.Embed() - embed.colour = discord.Colour.green() - embed.title = random.choice(POSITIVE_REPLIES) - embed.description = on_success - - footer_str = f"ID: {reminder_id}" - if delivery_dt: - # Reminder deletion will have a `None` `delivery_dt` - footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" - - embed.set_footer(text=footer_str) - - await ctx.send(embed=embed) - - @staticmethod - async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: - """ - Returns whether or not the list of mentions is allowed. - - Conditions: - - Role reminders are Mods+ - - Reminders for other users are Helpers+ - - If mentions aren't allowed, also return the type of mention(s) disallowed. - """ - if without_role_check(ctx, *STAFF_ROLES): - return False, "members/roles" - elif without_role_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, discord.Member) for mention in mentions), "roles" - else: - return True, "" - - @staticmethod - async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: - """ - Filter mentions to see if the user can mention, and sends a denial if not allowed. - - Returns whether or not the validation is successful. - """ - mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) - - if not mentions or mentions_allowed: - return True - else: - await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") - return False - - def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: - """Converts Role and Member ids to their corresponding objects if possible.""" - guild = self.bot.get_guild(Guild.id) - for mention_id in mention_ids: - if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): - yield mentionable - - def schedule_reminder(self, reminder: dict) -> None: - """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" - reminder_id = reminder["id"] - reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) - - async def _remind() -> None: - await self.send_reminder(reminder) - - log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder_id) - - self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) - - async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: - """Delete a reminder from the database, given its ID, and cancel the running task.""" - await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - - if cancel_task: - # Now we can remove it from the schedule list - self.scheduler.cancel(reminder_id) - - async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: - """ - Edits a reminder in the database given the ID and payload. - - Returns the edited reminder. - """ - # Send the request to update the reminder in the database - reminder = await self.bot.api_client.patch( - 'bot/reminders/' + str(reminder_id), - json=payload - ) - return reminder - - async def _reschedule_reminder(self, reminder: dict) -> None: - """Reschedule a reminder object.""" - log.trace(f"Cancelling old task #{reminder['id']}") - self.scheduler.cancel(reminder["id"]) - - log.trace(f"Scheduling new task #{reminder['id']}") - self.schedule_reminder(reminder) - - async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: - """Send the reminder.""" - is_valid, user, channel = self.ensure_valid_reminder(reminder) - if not is_valid: - return - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.set_author( - icon_url=Icons.remind_blurple, - name="It has arrived!" - ) - - embed.description = f"Here's your reminder: `{reminder['content']}`." - - if reminder.get("jump_url"): # keep backward compatibility - embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" - - if late: - embed.colour = discord.Colour.red() - embed.set_author( - icon_url=Icons.remind_red, - name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" - ) - - additional_mentions = ' '.join( - mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) - ) - - await channel.send( - content=f"{user.mention} {additional_mentions}", - embed=embed - ) - await self._delete_reminder(reminder["id"]) - - @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) - async def remind_group( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """Commands for managing your reminders.""" - await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) - - @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """ - Set yourself a simple reminder. - - Expiration is parsed per: http://strftime.org/ - """ - # If the user is not staff, we need to verify whether or not to make a reminder at all. - if without_role_check(ctx, *STAFF_ROLES): - - # If they don't have permission to set a reminder in this channel - if ctx.channel.id not in WHITELISTED_CHANNELS: - await send_denial(ctx, "Sorry, you can't do that here!") - return - - # Get their current active reminders - active_reminders = await self.bot.api_client.get( - 'bot/reminders', - params={ - 'author__id': str(ctx.author.id) - } - ) - - # Let's limit this, so we don't get 10 000 - # reminders from kip or something like that :P - if len(active_reminders) > MAXIMUM_REMINDERS: - await send_denial(ctx, "You have too many active reminders!") - return - - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - - # Now we can attempt to actually set the reminder. - reminder = await self.bot.api_client.post( - 'bot/reminders', - json={ - 'author': ctx.author.id, - 'channel_id': ctx.message.channel.id, - 'jump_url': ctx.message.jump_url, - 'content': content, - 'expiration': expiration.isoformat(), - 'mentions': mention_ids, - } - ) - - now = datetime.utcnow() - timedelta(seconds=1) - humanized_delta = humanize_delta(relativedelta(expiration, now)) - mention_string = ( - f"Your reminder will arrive in {humanized_delta} " - f"and will mention {len(mentions)} other(s)!" - ) - - # Confirm to the user that it worked. - await self._send_confirmation( - ctx, - on_success=mention_string, - reminder_id=reminder["id"], - delivery_dt=expiration, - ) - - self.schedule_reminder(reminder) - - @remind_group.command(name="list") - async def list_reminders(self, ctx: Context) -> None: - """View a paginated embed of all reminders for your user.""" - # Get all the user's reminders from the database. - data = await self.bot.api_client.get( - 'bot/reminders', - params={'author__id': str(ctx.author.id)} - ) - - now = datetime.utcnow() - - # Make a list of tuples so it can be sorted by time. - reminders = sorted( - ( - (rem['content'], rem['expiration'], rem['id'], rem['mentions']) - for rem in data - ), - key=itemgetter(1) - ) - - lines = [] - - for content, remind_at, id_, mentions in reminders: - # Parse and humanize the time, make it pretty :D - remind_datetime = isoparse(remind_at).replace(tzinfo=None) - time = humanize_delta(relativedelta(remind_datetime, now)) - - mentions = ", ".join( - # Both Role and User objects have the `name` attribute - mention.name for mention in self.get_mentionables(mentions) - ) - mention_string = f"\n**Mentions:** {mentions}" if mentions else "" - - text = textwrap.dedent(f""" - **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} - {content} - """).strip() - - lines.append(text) - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.title = f"Reminders for {ctx.author}" - - # Remind the user that they have no reminders :^) - if not lines: - embed.description = "No active reminders could be found." - await ctx.send(embed=embed) - return - - # Construct the embed and paginate it. - embed.colour = discord.Colour.blurple() - - await LinePaginator.paginate( - lines, - ctx, embed, - max_lines=3, - empty=True - ) - - @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) - async def edit_reminder_group(self, ctx: Context) -> None: - """Commands for modifying your current reminders.""" - await ctx.send_help(ctx.command) - - @edit_reminder_group.command(name="duration", aliases=("time",)) - async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: - """ - Edit one of your reminder's expiration. - - Expiration is parsed per: http://strftime.org/ - """ - await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) - - @edit_reminder_group.command(name="content", aliases=("reason",)) - async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: - """Edit one of your reminder's content.""" - await self.edit_reminder(ctx, id_, {"content": content}) - - @edit_reminder_group.command(name="mentions", aliases=("pings",)) - async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: - """Edit one of your reminder's mentions.""" - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) - - async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: - """Edits a reminder with the given payload, then sends a confirmation message.""" - reminder = await self._edit_reminder(id_, payload) - - # Parse the reminder expiration back into a datetime - expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) - - # Send a confirmation message to the channel - await self._send_confirmation( - ctx, - on_success="That reminder has been edited successfully!", - reminder_id=id_, - delivery_dt=expiration, - ) - await self._reschedule_reminder(reminder) - - @remind_group.command("delete", aliases=("remove", "cancel")) - async def delete_reminder(self, ctx: Context, id_: int) -> None: - """Delete one of your active reminders.""" - await self._delete_reminder(id_) - await self._send_confirmation( - ctx, - on_success="That reminder has been deleted successfully!", - reminder_id=id_, - delivery_dt=None, - ) - - -def setup(bot: Bot) -> None: - """Load the Reminders cog.""" - bot.add_cog(Reminders(bot)) diff --git a/bot/cogs/security.py b/bot/cogs/security.py deleted file mode 100644 index c680c5e27..000000000 --- a/bot/cogs/security.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging - -from discord.ext.commands import Cog, Context, NoPrivateMessage - -from bot.bot import Bot - -log = logging.getLogger(__name__) - - -class Security(Cog): - """Security-related helpers.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all - self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM - - def check_not_bot(self, ctx: Context) -> bool: - """Check if the context is a bot user.""" - return not ctx.author.bot - - def check_on_guild(self, ctx: Context) -> bool: - """Check if the context is in a guild.""" - if ctx.guild is None: - raise NoPrivateMessage("This command cannot be used in private messages.") - return True - - -def setup(bot: Bot) -> None: - """Load the Security cog.""" - bot.add_cog(Security(bot)) diff --git a/bot/cogs/site.py b/bot/cogs/site.py deleted file mode 100644 index ac29daa1d..000000000 --- a/bot/cogs/site.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import URLs -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" - - -class Site(Cog): - """Commands for linking to different parts of the site.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="site", aliases=("s",), invoke_without_command=True) - async def site_group(self, ctx: Context) -> None: - """Commands for getting info about our website.""" - await ctx.send_help(ctx.command) - - @site_group.command(name="home", aliases=("about",)) - async def site_main(self, ctx: Context) -> None: - """Info about the website itself.""" - url = f"{URLs.site_schema}{URLs.site}/" - - embed = Embed(title="Python Discord website") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - f"[Our official website]({url}) is an open-source community project " - "created with Python and Django. It contains information about the server " - "itself, lets you sign up for upcoming events, has its own wiki, contains " - "a list of valuable learning resources, and much more." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="resources") - async def site_resources(self, ctx: Context) -> None: - """Info about the site's Resources page.""" - learning_url = f"{PAGES_URL}/resources" - - embed = Embed(title="Resources") - embed.set_footer(text=f"{learning_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Resources page]({learning_url}) on our website contains a " - "list of hand-selected learning resources that we regularly recommend " - f"to both beginners and experts." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="tools") - async def site_tools(self, ctx: Context) -> None: - """Info about the site's Tools page.""" - tools_url = f"{PAGES_URL}/resources/tools" - - embed = Embed(title="Tools") - embed.set_footer(text=f"{tools_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Tools page]({tools_url}) on our website contains a " - f"couple of the most popular tools for programming in Python." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="help") - async def site_help(self, ctx: Context) -> None: - """Info about the site's Getting Help page.""" - url = f"{PAGES_URL}/resources/guides/asking-good-questions" - - embed = Embed(title="Asking Good Questions") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "Asking the right question about something that's new to you can sometimes be tricky. " - f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " - "It contains everything you need to get the very best help from our community." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="faq") - async def site_faq(self, ctx: Context) -> None: - """Info about the site's FAQ page.""" - url = f"{PAGES_URL}/frequently-asked-questions" - - embed = Embed(title="FAQ") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "As the largest Python community on Discord, we get hundreds of questions every day. " - "Many of these questions have been asked before. We've compiled a list of the most " - "frequently asked questions along with their answers, which can be found on " - f"our [FAQ page]({url})." - ) - - await ctx.send(embed=embed) - - @site_group.command(aliases=['r', 'rule'], name='rules') - async def site_rules(self, ctx: Context, *rules: int) -> None: - """Provides a link to all rules or, if specified, displays specific rule(s).""" - rules_embed = Embed(title='Rules', color=Colour.blurple()) - rules_embed.url = f"{PAGES_URL}/rules" - - if not rules: - # Rules were not submitted. Return the default description. - rules_embed.description = ( - "The rules and guidelines that apply to this community can be found on" - f" our [rules page]({PAGES_URL}/rules). We expect" - " all members of the community to have read and understood these." - ) - - await ctx.send(embed=rules_embed) - return - - full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) - invalid_indices = tuple( - pick - for pick in rules - if pick < 1 or pick > len(full_rules) - ) - - if invalid_indices: - indices = ', '.join(map(str, invalid_indices)) - await ctx.send(f":x: Invalid rule indices: {indices}") - return - - for rule in rules: - self.bot.stats.incr(f"rule_uses.{rule}") - - final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) - - await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) - - -def setup(bot: Bot) -> None: - """Load the Site cog.""" - bot.add_cog(Site(bot)) diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py deleted file mode 100644 index 52c8b6f88..000000000 --- a/bot/cogs/snekbox.py +++ /dev/null @@ -1,349 +0,0 @@ -import asyncio -import contextlib -import datetime -import logging -import re -import textwrap -from functools import partial -from signal import Signals -from typing import Optional, Tuple - -from discord import HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only - -from bot.bot import Bot -from bot.constants import Categories, Channels, Roles, URLs -from bot.decorators import in_whitelist -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") -FORMATTED_CODE_REGEX = re.compile( - r"^\s*" # any leading whitespace from the beginning of the string - r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block - r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) - r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all code inside the markup - r"\s*" # any more whitespace before the end of the code markup - r"(?P=delim)" # match the exact same delimiter from the start again - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive -) -RAW_CODE_REGEX = re.compile( - r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all the rest as code - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL # "." also matches newlines -) - -MAX_PASTE_LEN = 1000 - -# `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) - -SIGKILL = 9 - -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 - - -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_eval(self, code: str) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - - url = URLs.paste_service.format(key="documents") - try: - async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: - data = await resp.json() - - if "key" in data: - return URLs.paste_service.format(key=data["key"]) - except Exception: - # 400 (Bad Request) means there are too many characters - log.exception("Failed to upload full output to paste service!") - - @staticmethod - def prepare_input(code: str) -> str: - """Extract code from the Markdown, format it, and insert it into the code template.""" - match = FORMATTED_CODE_REGEX.fullmatch(code) - if match: - code, block, lang, delim = match.group("code", "block", "lang", "delim") - code = textwrap.dedent(code) - if block: - info = (f"'{lang}' highlighted" if lang else "plain") + " code block" - else: - info = f"{delim}-enclosed inline code" - log.trace(f"Extracted {info} for evaluation:\n{code}") - else: - code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) - log.trace( - f"Eval message contains unformatted or badly formatted code, " - f"stripping whitespace only:\n{code}" - ) - - return code - - @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: - """Return a user-friendly message and error corresponding to the process's return code.""" - stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" - error = "" - - if returncode is None: - msg = "Your eval job has failed" - error = stdout.strip() - elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" - elif returncode == 255: - msg = "Your eval job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - try: - name = Signals(returncode - 128).name - msg = f"{msg} ({name})" - except ValueError: - pass - - return msg, error - - @staticmethod - def get_status_emoji(results: dict) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - if not results["stdout"].strip(): # No output - return ":warning:" - elif results["returncode"] == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: - """ - Format the output and return a tuple of the formatted output and a URL to the full output. - - Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters - and upload the full output to a paste service. - """ - log.trace("Formatting output...") - - output = output.rstrip("\n") - original_output = output # To be uploaded to a pasting service if needed - paste_link = None - - if "<@" in output: - output = output.replace("<@", "<@\u200B") # Zero-width space - - if " 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] - output = output[:11] # Limiting to only 11 lines - output = "\n".join(output) - - if lines > 10: - truncated = True - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" - else: - output = f"{output}\n... (truncated - too many lines)" - elif len(output) >= 1000: - truncated = True - output = f"{output[:1000]}\n... (truncated - too long)" - - if truncated: - paste_link = await self.upload_output(original_output) - - output = output or "[No output]" - - return output, paste_link - - async def send_eval(self, ctx: Context, code: str) -> Message: - """ - Evaluate code, format it, and send the output to the corresponding channel. - - Return the bot response. - """ - async with ctx.typing(): - results = await self.post_eval(code) - msg, error = self.get_results_message(results) - - if error: - output, paste_link = error, None - else: - output, paste_link = await self.format_output(results["stdout"]) - - icon = self.get_status_emoji(results) - msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" - if paste_link: - msg = f"{msg}\nFull output: {paste_link}" - - # Collect stats of eval fails + successes - if icon == ":x:": - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - - filter_cog = self.bot.get_cog("Filtering") - filter_triggered = False - if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) - if filter_triggered: - response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") - else: - response = await ctx.send(msg) - self.bot.loop.create_task( - wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) - ) - - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") - return response - - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: - """ - Check if the eval session should continue. - - Return the new code to evaluate or None if the eval session should be terminated. - """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) - - with contextlib.suppress(NotFound): - try: - _, new_message = await self.bot.wait_for( - 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT - ) - await ctx.message.add_reaction(REEVAL_EMOJI) - await self.bot.wait_for( - 'reaction_add', - check=_predicate_emoji_reaction, - timeout=10 - ) - - code = await self.get_code(new_message) - await ctx.message.clear_reactions() - with contextlib.suppress(HTTPException): - await response.delete() - - except asyncio.TimeoutError: - await ctx.message.clear_reactions() - return None - - return code - - async def get_code(self, message: Message) -> Optional[str]: - """ - Return the code from `message` to be evaluated. - - If the message is an invocation of the eval command, return the first argument or None if it - doesn't exist. Otherwise, return the full content of the message. - """ - log.trace(f"Getting context for message {message.id}.") - new_ctx = await self.bot.get_context(message) - - if new_ctx.command is self.eval_command: - log.trace(f"Message {message.id} invokes eval command.") - split = message.content.split(maxsplit=1) - code = split[1] if len(split) > 1 else None - else: - log.trace(f"Message {message.id} does not invoke eval command.") - code = message.content - - return code - - @command(name="eval", aliases=("e",)) - @guild_only() - @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: - """ - Run Python code and get the results. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ - if ctx.author.id in self.jobs: - await ctx.send( - f"{ctx.author.mention} You've already got a job running - " - "please wait for it to finish!" - ) - return - - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - - if Roles.helpers in (role.id for role in ctx.author.roles): - self.bot.stats.incr("snekbox_usages.roles.helpers") - else: - self.bot.stats.incr("snekbox_usages.roles.developers") - - if ctx.channel.category_id == Categories.help_in_use: - self.bot.stats.incr("snekbox_usages.channels.help") - elif ctx.channel.id == Channels.bot_commands: - self.bot.stats.incr("snekbox_usages.channels.bot_commands") - else: - self.bot.stats.incr("snekbox_usages.channels.topical") - - log.info(f"Received code from {ctx.author} for evaluation:\n{code}") - - while True: - self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) - try: - response = await self.send_eval(ctx, code) - finally: - del self.jobs[ctx.author.id] - - code = await self.continue_eval(ctx, response) - if not code: - break - log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") - - -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: - """Return True if the edited message is the context message and the content was indeed modified.""" - return new_msg.id == ctx.message.id and old_msg.content != new_msg.content - - -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI - - -def setup(bot: Bot) -> None: - """Load the Snekbox cog.""" - bot.add_cog(Snekbox(bot)) diff --git a/bot/cogs/source.py b/bot/cogs/source.py deleted file mode 100644 index 205e0ba81..000000000 --- a/bot/cogs/source.py +++ /dev/null @@ -1,141 +0,0 @@ -import inspect -from pathlib import Path -from typing import Optional, Tuple, Union - -from discord import Embed -from discord.ext import commands - -from bot.bot import Bot -from bot.constants import URLs - -SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] - - -class SourceConverter(commands.Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - async def convert(self, ctx: commands.Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower().startswith("help"): - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - - raise commands.BadArgument( - f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." - ) - - -class BotSource(commands.Cog): - """Displays information about the bot's source code.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command(name="source", aliases=("src",)) - async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: - """Display information and a GitHub link to the source code of a command, tag, or cog.""" - if not source_item: - embed = Embed(title="Bot's GitHub Repository") - embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") - embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") - await ctx.send(embed=embed) - return - - embed = await self.build_embed(source_item) - await ctx.send(embed=embed) - - def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: - """ - Build GitHub link of source item, return this link, file location and first line number. - - Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). - """ - if isinstance(source_item, commands.Command): - if source_item.cog_name == "Alias": - cmd_name = source_item.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - src = cmd.callback.__code__ - filename = src.co_filename - else: - src = source_item.callback.__code__ - filename = src.co_filename - elif isinstance(source_item, str): - tags_cog = self.bot.get_cog("Tags") - filename = tags_cog._cache[source_item]["location"] - else: - src = type(source_item) - try: - filename = inspect.getsourcefile(src) - except TypeError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - if not isinstance(source_item, str): - try: - lines, first_line_no = inspect.getsourcelines(src) - except OSError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" - else: - first_line_no = None - lines_extension = "" - - # Handle tag file location differently than others to avoid errors in some cases - if not first_line_no: - file_location = Path(filename).relative_to("/bot/") - else: - file_location = Path(filename).relative_to(Path.cwd()).as_posix() - - url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" - - return url, file_location, first_line_no or None - - async def build_embed(self, source_object: SourceType) -> Optional[Embed]: - """Build embed based on source object.""" - url, location, first_line = self.get_source_link(source_object) - - if isinstance(source_object, commands.HelpCommand): - title = "Help Command" - description = source_object.__doc__.splitlines()[1] - elif isinstance(source_object, commands.Command): - if source_object.cog_name == "Alias": - cmd_name = source_object.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - description = cmd.short_doc - else: - description = source_object.short_doc - - title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, str): - title = f"Tag: {source_object}" - description = "" - else: - title = f"Cog: {source_object.qualified_name}" - description = source_object.description.splitlines()[0] - - embed = Embed(title=title, description=description) - embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") - line_text = f":{first_line}" if first_line else "" - embed.set_footer(text=f"{location}{line_text}") - - return embed - - -def setup(bot: Bot) -> None: - """Load the BotSource cog.""" - bot.add_cog(BotSource(bot)) diff --git a/bot/cogs/stats.py b/bot/cogs/stats.py deleted file mode 100644 index d42f55466..000000000 --- a/bot/cogs/stats.py +++ /dev/null @@ -1,129 +0,0 @@ -import string -from datetime import datetime - -from discord import Member, Message, Status -from discord.ext.commands import Cog, Context -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Categories, Channels, Guild, Stats as StatConf - - -CHANNEL_NAME_OVERRIDES = { - Channels.off_topic_0: "off_topic_0", - Channels.off_topic_1: "off_topic_1", - Channels.off_topic_2: "off_topic_2", - Channels.staff_lounge: "staff_lounge" -} - -ALLOWED_CHARS = string.ascii_letters + string.digits + "_" - - -class Stats(Cog): - """A cog which provides a way to hook onto Discord events and forward to stats.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.last_presence_update = None - self.update_guild_boost.start() - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Report message events in the server to statsd.""" - if message.guild is None: - return - - if message.guild.id != Guild.id: - return - - cat = getattr(message.channel, "category", None) - if cat is not None and cat.id == Categories.modmail: - if message.channel.id != Channels.incidents: - # Do not report modmail channels to stats, there are too many - # of them for interesting statistics to be drawn out of this. - return - - reformatted_name = message.channel.name.replace('-', '_') - - if CHANNEL_NAME_OVERRIDES.get(message.channel.id): - reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) - - reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) - - stat_name = f"channels.{reformatted_name}" - self.bot.stats.incr(stat_name) - - # Increment the total message count - self.bot.stats.incr("messages") - - @Cog.listener() - async def on_command_completion(self, ctx: Context) -> None: - """Report completed commands to statsd.""" - command_name = ctx.command.qualified_name.replace(" ", "_") - - self.bot.stats.incr(f"commands.{command_name}") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Update member count stat on member join.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_leave(self, member: Member) -> None: - """Update member count stat on member leave.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_update(self, _before: Member, after: Member) -> None: - """Update presence estimates on member update.""" - if after.guild.id != Guild.id: - return - - if self.last_presence_update: - if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: - return - - self.last_presence_update = datetime.now() - - online = 0 - idle = 0 - dnd = 0 - offline = 0 - - for member in after.guild.members: - if member.status is Status.online: - online += 1 - elif member.status is Status.dnd: - dnd += 1 - elif member.status is Status.idle: - idle += 1 - elif member.status is Status.offline: - offline += 1 - - self.bot.stats.gauge("guild.status.online", online) - self.bot.stats.gauge("guild.status.idle", idle) - self.bot.stats.gauge("guild.status.do_not_disturb", dnd) - self.bot.stats.gauge("guild.status.offline", offline) - - @loop(hours=1) - async def update_guild_boost(self) -> None: - """Post the server boost level and tier every hour.""" - await self.bot.wait_until_guild_available() - g = self.bot.get_guild(Guild.id) - self.bot.stats.gauge("boost.amount", g.premium_subscription_count) - self.bot.stats.gauge("boost.tier", g.premium_tier) - - def cog_unload(self) -> None: - """Stop the boost statistic task on unload of the Cog.""" - self.update_guild_boost.stop() - - -def setup(bot: Bot) -> None: - """Load the stats cog.""" - bot.add_cog(Stats(bot)) diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py deleted file mode 100644 index fe7df4e9b..000000000 --- a/bot/cogs/sync/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from bot.bot import Bot -from .cog import Sync - - -def setup(bot: Bot) -> None: - """Load the Sync cog.""" - bot.add_cog(Sync(bot)) diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py deleted file mode 100644 index 5ace957e7..000000000 --- a/bot/cogs/sync/cog.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from typing import Any, Dict - -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.sync import syncers - -log = logging.getLogger(__name__) - - -class Sync(Cog): - """Captures relevant events and sends them to the site.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.role_syncer = syncers.RoleSyncer(self.bot) - self.user_syncer = syncers.UserSyncer(self.bot) - - self.bot.loop.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(constants.Guild.id) - if guild is None: - return - - for syncer in (self.role_syncer, self.user_syncer): - await syncer.sync(guild) - - async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: - """Send a PATCH request to partially update a user in the database.""" - try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) - except ResponseCodeError as e: - if e.response.status != 404: - raise - if not ignore_404: - log.warning("Unable to update user, got 404. Assuming race condition from join event.") - - @Cog.listener() - async def on_guild_role_create(self, role: Role) -> None: - """Adds newly create role to the database table over the API.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.post( - 'bot/roles', - json={ - 'colour': role.colour.value, - 'id': role.id, - 'name': role.name, - 'permissions': role.permissions.value, - 'position': role.position, - } - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: Role) -> None: - """Deletes role from the database when it's deleted from the guild.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - @Cog.listener() - async def on_guild_role_update(self, before: Role, after: Role) -> None: - """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild.id != constants.Guild.id: - return - - was_updated = ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ) - - if was_updated: - await self.bot.api_client.put( - f'bot/roles/{after.id}', - json={ - 'colour': after.colour.value, - 'id': after.id, - 'name': after.name, - 'permissions': after.permissions.value, - 'position': after.position, - } - ) - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """ - Adds a new user or updates existing user to the database when a member joins the guild. - - If the joining member is a user that is already known to the database (i.e., a user that - previously left), it will update the user's information. If the user is not yet known by - the database, the user is added. - """ - if member.guild.id != constants.Guild.id: - return - - packed = { - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': True, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - - got_error = False - - try: - # First try an update of the user to set the `in_guild` field and other - # fields that may have changed since the last time we've seen them. - await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) - - except ResponseCodeError as e: - # If we didn't get 404, something else broke - propagate it up. - if e.response.status != 404: - raise - - got_error = True # yikes - - if got_error: - # If we got `404`, the user is new. Create them. - await self.bot.api_client.post('bot/users', json=packed) - - @Cog.listener() - async def on_member_remove(self, member: Member) -> None: - """Set the in_guild field to False when a member leaves the guild.""" - if member.guild.id != constants.Guild.id: - return - - await self.patch_user(member.id, json={"in_guild": False}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Update the roles of the member in the database if a change is detected.""" - if after.guild.id != constants.Guild.id: - return - - if before.roles != after.roles: - updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, json=updated_information) - - @Cog.listener() - async def on_user_update(self, before: User, after: User) -> None: - """Update the user information in the database if a relevant change is detected.""" - attrs = ("name", "discriminator") - if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): - updated_information = { - "name": after.name, - "discriminator": int(after.discriminator), - } - # A 404 likely means the user is in another guild. - await self.patch_user(after.id, json=updated_information, ignore_404=True) - - @commands.group(name='sync') - @commands.has_permissions(administrator=True) - async def sync_group(self, ctx: Context) -> None: - """Run synchronizations between the bot and site manually.""" - - @sync_group.command(name='roles') - @commands.has_permissions(administrator=True) - async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) - - @sync_group.command(name='users') - @commands.has_permissions(administrator=True) - async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py deleted file mode 100644 index f7ba811bc..000000000 --- a/bot/cogs/sync/syncers.py +++ /dev/null @@ -1,347 +0,0 @@ -import abc -import asyncio -import logging -import typing as t -from collections import namedtuple -from functools import partial - -import discord -from discord import Guild, HTTPException, Member, Message, Reaction, User -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# These objects are declared as namedtuples because tuples are hashable, -# something that we make use of when diffing site roles against guild roles. -_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) -_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - - -class Syncer(abc.ABC): - """Base class for synchronising the database with objects in the Discord cache.""" - - _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " - _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @property - @abc.abstractmethod - def name(self) -> str: - """The name of the syncer; used in output messages and logging.""" - raise NotImplementedError # pragma: no cover - - async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: - """ - Send a prompt to confirm or abort a sync using reactions and return the sent message. - - If a message is given, it is edited to display the prompt and reactions. Otherwise, a new - message is sent to the dev-core channel and mentions the core developers role. If the - channel cannot be retrieved, return None. - """ - log.trace(f"Sending {self.name} sync confirmation prompt.") - - msg_content = ( - f'Possible cache issue while syncing {self.name}s. ' - f'More than {constants.Sync.max_diff} {self.name}s were changed. ' - f'React to confirm or abort the sync.' - ) - - # Send to core developers if it's an automatic sync. - if not message: - log.trace("Message not provided for confirmation; creating a new one in dev-core.") - channel = self.bot.get_channel(constants.Channels.dev_core) - - if not channel: - log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") - try: - channel = await self.bot.fetch_channel(constants.Channels.dev_core) - except HTTPException: - log.exception( - f"Failed to fetch channel for sending sync confirmation prompt; " - f"aborting {self.name} sync." - ) - return None - - allowed_roles = [discord.Object(constants.Roles.core_developers)] - message = await channel.send( - f"{self._CORE_DEV_MENTION}{msg_content}", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - else: - await message.edit(content=msg_content) - - # Add the initial reactions. - log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") - for emoji in self._REACTION_EMOJIS: - await message.add_reaction(emoji) - - return message - - def _reaction_check( - self, - author: Member, - message: Message, - reaction: Reaction, - user: t.Union[Member, User] - ) -> bool: - """ - Return True if the `reaction` is a valid confirmation or abort reaction on `message`. - - If the `author` of the prompt is a bot, then a reaction by any core developer will be - considered valid. Otherwise, the author of the reaction (`user`) will have to be the - `author` of the prompt. - """ - # For automatic syncs, check for the core dev role instead of an exact author - has_role = any(constants.Roles.core_developers == role.id for role in user.roles) - return ( - reaction.message.id == message.id - and not user.bot - and (has_role if author.bot else user == author) - and str(reaction.emoji) in self._REACTION_EMOJIS - ) - - async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: - """ - Wait for a confirmation reaction by `author` on `message` and return True if confirmed. - - Uses the `_reaction_check` function to determine if a reaction is valid. - - If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. - To acknowledge the reaction (or lack thereof), `message` will be edited. - """ - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - reaction = None - try: - log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") - reaction, _ = await self.bot.wait_for( - 'reaction_add', - check=partial(self._reaction_check, author, message), - timeout=constants.Sync.confirm_timeout - ) - except asyncio.TimeoutError: - # reaction will remain none thus sync will be aborted in the finally block below. - log.debug(f"The {self.name} syncer confirmation prompt timed out.") - - if str(reaction) == constants.Emojis.check_mark: - log.trace(f"The {self.name} syncer was confirmed.") - await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') - return True - else: - log.info(f"The {self.name} syncer was aborted or timed out!") - await message.edit( - content=f':warning: {mention}{self.name} sync aborted or timed out!' - ) - return False - - @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference between the cache of `guild` and the database.""" - raise NotImplementedError # pragma: no cover - - @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: - """Perform the API calls for synchronisation.""" - raise NotImplementedError # pragma: no cover - - async def _get_confirmation_result( - self, - diff_size: int, - author: Member, - message: t.Optional[Message] = None - ) -> t.Tuple[bool, t.Optional[Message]]: - """ - Prompt for confirmation and return a tuple of the result and the prompt message. - - `diff_size` is the size of the diff of the sync. If it is greater than - `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the - sync and the `message` is an extant message to edit to display the prompt. - - If confirmed or no confirmation was needed, the result is True. The returned message will - either be the given `message` or a new one which was created when sending the prompt. - """ - log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") - if diff_size > constants.Sync.max_diff: - message = await self._send_prompt(message) - if not message: - return False, None # Couldn't get channel. - - confirmed = await self._wait_for_confirmation(author, message) - if not confirmed: - return False, message # Sync aborted. - - return True, message - - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If the differences between the cache and the database are greater than - `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core - channel. The confirmation can be optionally redirect to `ctx` instead. - """ - log.info(f"Starting {self.name} syncer.") - - message = None - author = self.bot.user - if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") - author = ctx.author - - diff = await self._get_diff(guild) - diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict - totals = {k: len(v) for k, v in diff_dict.items() if v is not None} - diff_size = sum(totals.values()) - - confirmed, message = await self._get_confirmation_result(diff_size, author, message) - if not confirmed: - return - - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - try: - await self._sync(diff) - except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") - - # Don't show response text because it's probably some really long HTML. - results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" - else: - results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" - - if message: - await message.edit(content=content) - - -class RoleSyncer(Syncer): - """Synchronise the database with roles in the cache.""" - - name = "role" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of roles between the cache of `guild` and the database.""" - log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_roles = {_Role(**role_dict) for role_dict in roles} - guild_roles = { - _Role( - id=role.id, - name=role.name, - colour=role.colour.value, - permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in db_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # DB guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - db_roles - roles_to_create - roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} - - return _Diff(roles_to_create, roles_to_update, roles_to_delete) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the role cache of `guild`.""" - log.trace("Syncing created roles...") - for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) - - log.trace("Syncing updated roles...") - for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) - - log.trace("Syncing deleted roles...") - for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - -class UserSyncer(Syncer): - """Synchronise the database with users in the cache.""" - - name = "user" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of users between the cache of `guild` and the database.""" - log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } - - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - if guild_user is not None: - if db_user != guild_user: - users_to_update.add(guild_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return _Diff(users_to_create, users_to_update, None) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the user cache of `guild`.""" - log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py deleted file mode 100644 index 3d76c5c08..000000000 --- a/bot/cogs/tags.py +++ /dev/null @@ -1,277 +0,0 @@ -import logging -import re -import time -from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot import constants -from bot.bot import Bot -from bot.converters import TagNameConverter -from bot.pagination import LinePaginator -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -TEST_CHANNELS = ( - constants.Channels.bot_commands, - constants.Channels.helpers -) - -REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) -FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." - - -class Tags(Cog): - """Save new tags and fetch existing tags.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.tag_cooldowns = {} - self._cache = self.get_tags() - - @staticmethod - def get_tags() -> dict: - """Get all tags.""" - cache = {} - - base_path = Path("bot", "resources", "tags") - for file in base_path.glob("**/*"): - if file.is_file(): - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text(encoding="utf8"), - }, - "restricted_to": "developers", - "location": f"/bot/{file}" - } - - # Convert to a list to allow negative indexing. - parents = list(file.relative_to(base_path).parents) - if len(parents) > 1: - # -1 would be '.' hence -2 is used as the index. - tag["restricted_to"] = parents[-2].name - - cache[tag_title] = tag - - return cache - - @staticmethod - def check_accessibility(user: Member, tag: dict) -> bool: - """Check if user can access a tag.""" - return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] - - @staticmethod - def _fuzzy_search(search: str, target: str) -> float: - """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" - current, index = 0, 0 - _search = REGEX_NON_ALPHABET.sub('', search.lower()) - _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) - _target = next(_targets) - try: - while True: - while index < len(_target) and _search[current] == _target[index]: - current += 1 - index += 1 - index, _target = 0, next(_targets) - except (StopIteration, IndexError): - pass - return current / len(_search) * 100 - - def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: - """Return a list of suggested tags.""" - scores: Dict[str, int] = { - tag_title: Tags._fuzzy_search(tag_name, tag['title']) - for tag_title, tag in self._cache.items() - } - - thresholds = thresholds or [100, 90, 80, 70, 60] - - for threshold in thresholds: - suggestions = [ - self._cache[tag_title] - for tag_title, matching_score in scores.items() - if matching_score >= threshold - ] - if suggestions: - return suggestions - - return [] - - def _get_tag(self, tag_name: str) -> list: - """Get a specific tag.""" - found = [self._cache.get(tag_name.lower(), None)] - if not found[0]: - return self._get_suggestions(tag_name) - return found - - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: - """ - Search for tags via contents. - - `predicate` will be the built-in any, all, or a custom callable. Must return a bool. - """ - keywords_processed: List[str] = [] - for keyword in keywords.split(','): - keyword_sanitized = keyword.strip().casefold() - if not keyword_sanitized: - # this happens when there are leading / trailing / consecutive comma. - continue - keywords_processed.append(keyword_sanitized) - - if not keywords_processed: - # after sanitizing, we can end up with an empty list, for example when keywords is ',' - # in that case, we simply want to search for such keywords directly instead. - keywords_processed = [keywords] - - matching_tags = [] - for tag in self._cache.values(): - matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) - if self.check_accessibility(user, tag) and check(matches): - matching_tags.append(tag) - - return matching_tags - - async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: - """Send the result of matching tags to user.""" - if not matching_tags: - pass - elif len(matching_tags) == 1: - await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) - else: - is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 - embed = Embed( - title=f"Here are the tags containing the given keyword{'s' * is_plural}:", - description='\n'.join(tag['title'] for tag in matching_tags[:10]) - ) - await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in matching_tags), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) - async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Show all known tags, a single tag, or run a subcommand.""" - await ctx.invoke(self.get_command, tag_name=tag_name) - - @tags_group.group(name='search', invoke_without_command=True) - async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Only search for tags that has ALL the keywords. - """ - matching_tags = self._get_tags_via_content(all, keywords, ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @search_tag_content.command(name='any') - async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Search for tags that has ANY of the keywords. - """ - matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @tags_group.command(name='get', aliases=('show', 'g')) - async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Get a specified tag, or a list of all tags if no tag is specified.""" - - def _command_on_cooldown(tag_name: str) -> bool: - """ - Check if the command is currently on cooldown, on a per-tag, per-channel basis. - - The cooldown duration is set in constants.py. - """ - now = time.time() - - cooldown_conditions = ( - tag_name - and tag_name in self.tag_cooldowns - and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags - and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id - ) - - if cooldown_conditions: - return True - return False - - if _command_on_cooldown(tag_name): - time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] - time_left = constants.Cooldowns.tags - time_elapsed - log.info( - f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds." - ) - return - - if tag_name is not None: - temp_founds = self._get_tag(tag_name) - - founds = [] - - for found_tag in temp_founds: - if self.check_accessibility(ctx.author, found_tag): - founds.append(found_tag) - - if len(founds) == 1: - tag = founds[0] - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } - - self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") - - await wait_for_deletion( - await ctx.send(embed=Embed.from_dict(tag['embed'])), - [ctx.author.id], - client=self.bot - ) - elif founds and len(tag_name) >= 3: - await wait_for_deletion( - await ctx.send( - embed=Embed( - title='Did you mean ...', - description='\n'.join(tag['title'] for tag in founds[:10]) - ) - ), - [ctx.author.id], - client=self.bot - ) - - else: - tags = self._cache.values() - if not tags: - await ctx.send(embed=Embed( - description="**There are no tags in the database!**", - colour=Colour.red() - )) - else: - embed: Embed = Embed(title="**Current tags**") - await LinePaginator.paginate( - sorted( - f"**»** {tag['title']}" for tag in tags - if self.check_accessibility(ctx.author, tag) - ), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Tags cog.""" - bot.add_cog(Tags(bot)) diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py deleted file mode 100644 index ef979f222..000000000 --- a/bot/cogs/token_remover.py +++ /dev/null @@ -1,182 +0,0 @@ -import base64 -import binascii -import logging -import re -import typing as t - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot import utils -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import Channels, Colours, Event, Icons - -log = logging.getLogger(__name__) - -LOG_MESSAGE = ( - "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " - "token was `{user_id}.{timestamp}.{hmac}`" -) -DELETION_MESSAGE_TEMPLATE = ( - "Hey {mention}! I noticed you posted a seemingly valid Discord API " - "token in your message and have removed your message. " - "This means that your token has been **compromised**. " - "Please change your token **immediately** at: " - "\n\n" - "Feel free to re-post it with the token removed. " - "If you believe this was a mistake, please let us know!" -) -DISCORD_EPOCH = 1_420_070_400 -TOKEN_EPOCH = 1_293_840_000 - -# Three parts delimited by dots: user ID, creation timestamp, HMAC. -# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. -# Each part only matches base64 URL-safe characters. -# Padding has never been observed, but the padding character '=' is matched just in case. -TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) - - -class Token(t.NamedTuple): - """A Discord Bot token.""" - - user_id: str - timestamp: str - hmac: str - - -class TokenRemover(Cog): - """Scans messages for potential discord.py bot tokens and removes them.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Check each message for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - found_token = self.find_token_in_message(msg) - if found_token: - await self.take_action(msg, found_token) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Check each edit for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - await self.on_message(after) - - async def take_action(self, msg: Message, found_token: Token) -> None: - """Remove the `msg` containing the `found_token` and send a mod log message.""" - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") - return - - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - log_message = self.format_log_message(msg, found_token) - log.debug(log_message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=log_message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) - - self.bot.stats.incr("tokens.removed_tokens") - - @staticmethod - def format_log_message(msg: Message, token: Token) -> str: - """Return the log message to send for `token` being censored in `msg`.""" - return LOG_MESSAGE.format( - author=msg.author, - author_id=msg.author.id, - channel=msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac='x' * len(token.hmac), - ) - - @classmethod - def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: - """Return a seemingly valid token found in `msg` or `None` if no token is found.""" - # Use finditer rather than search to guard against method calls prematurely returning the - # token check (e.g. `message.channel.send` also matches our token pattern) - for match in TOKEN_RE.finditer(msg.content): - token = Token(*match.groups()) - if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): - # Short-circuit on first match - return token - - # No matching substring - return - - @staticmethod - def is_valid_user_id(b64_content: str) -> bool: - """ - Check potential token to see if it contains a valid Discord user ID. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - string = decoded_bytes.decode('utf-8') - - # isdigit on its own would match a lot of other Unicode characters, hence the isascii. - return string.isascii() and string.isdigit() - except (binascii.Error, ValueError): - return False - - @staticmethod - def is_valid_timestamp(b64_content: str) -> bool: - """ - Return True if `b64_content` decodes to a valid timestamp. - - If the timestamp is greater than the Discord epoch, it's probably valid. - See: https://i.imgur.com/7WdehGn.png - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: - log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") - return False - - # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound - # is not checked. - if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: - return True - else: - log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") - return False - - -def setup(bot: Bot) -> None: - """Load the TokenRemover cog.""" - bot.add_cog(TokenRemover(bot)) diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py deleted file mode 100644 index d96abbd5a..000000000 --- a/bot/cogs/utils.py +++ /dev/null @@ -1,265 +0,0 @@ -import difflib -import logging -import re -import unicodedata -from email.parser import HeaderParser -from io import StringIO -from typing import Tuple, Union - -from discord import Colour, Embed, utils -from discord.ext.commands import BadArgument, Cog, Context, clean_content, command - -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils import messages - -log = logging.getLogger(__name__) - -ZEN_OF_PYTHON = """\ -Beautiful is better than ugly. -Explicit is better than implicit. -Simple is better than complex. -Complex is better than complicated. -Flat is better than nested. -Sparse is better than dense. -Readability counts. -Special cases aren't special enough to break the rules. -Although practicality beats purity. -Errors should never pass silently. -Unless explicitly silenced. -In the face of ambiguity, refuse the temptation to guess. -There should be one-- and preferably only one --obvious way to do it. -Although that way may not be obvious at first unless you're Dutch. -Now is better than never. -Although never is often better than *right* now. -If the implementation is hard to explain, it's a bad idea. -If the implementation is easy to explain, it may be a good idea. -Namespaces are one honking great idea -- let's do more of those! -""" - -ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - - -class Utils(Cog): - """A selection of utilities which don't have a clear category.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.base_pep_url = "http://www.python.org/dev/peps/pep-" - self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" - - @command(name='pep', aliases=('get_pep', 'p')) - async def pep_command(self, ctx: Context, pep_number: str) -> None: - """Fetches information about a PEP and sends it to the channel.""" - if pep_number.isdigit(): - pep_number = int(pep_number) - else: - await ctx.send_help(ctx.command) - return - - # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. - if pep_number == 0: - return await self.send_pep_zero(ctx) - - possible_extensions = ['.txt', '.rst'] - found_pep = False - for extension in possible_extensions: - # Attempt to fetch the PEP - pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" - log.trace(f"Requesting PEP {pep_number} with {pep_url}") - response = await self.bot.http_session.get(pep_url) - - if response.status == 200: - log.trace("PEP found") - found_pep = True - - pep_content = await response.text() - - # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 - pep_header = HeaderParser().parse(StringIO(pep_content)) - - # Assemble the embed - pep_embed = Embed( - title=f"**PEP {pep_number} - {pep_header['Title']}**", - description=f"[Link]({self.base_pep_url}{pep_number:04})", - ) - - pep_embed.set_thumbnail(url=ICON_URL) - - # Add the interesting information - fields_to_check = ("Status", "Python-Version", "Created", "Type") - for field in fields_to_check: - # Check for a PEP metadata field that is present but has an empty value - # embed field values can't contain an empty string - if pep_header.get(field, ""): - pep_embed.add_field(name=field, value=pep_header[field]) - - elif response.status != 404: - # any response except 200 and 404 is expected - found_pep = True # actually not, but it's easier to display this way - log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " - f"{response.status}.\n{response.text}") - - error_message = "Unexpected HTTP error during PEP search. Please let us know." - pep_embed = Embed(title="Unexpected error", description=error_message) - pep_embed.colour = Colour.red() - break - - if not found_pep: - log.trace("PEP was not found") - not_found = f"PEP {pep_number} does not exist." - pep_embed = Embed(title="PEP not found", description=not_found) - pep_embed.colour = Colour.red() - - await ctx.message.channel.send(embed=pep_embed) - - @command() - @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) - async def charinfo(self, ctx: Context, *, characters: str) -> None: - """Shows you information on up to 50 unicode characters.""" - match = re.match(r"<(a?):(\w+):(\d+)>", characters) - if match: - return await messages.send_denial( - ctx, - "**Non-Character Detected**\n" - "Only unicode characters can be processed, but a custom Discord emoji " - "was found. Please remove it and try again." - ) - - if len(characters) > 50: - return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") - - def get_info(char: str) -> Tuple[str, str]: - digit = f"{ord(char):x}" - if len(digit) <= 4: - u_code = f"\\u{digit:>04}" - else: - u_code = f"\\U{digit:>08}" - url = f"https://www.compart.com/en/unicode/U+{digit:>04}" - name = f"[{unicodedata.name(char, '')}]({url})" - info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" - return info, u_code - - char_list, raw_list = zip(*(get_info(c) for c in characters)) - embed = Embed().set_author(name="Character Info") - - if len(characters) > 1: - # Maximum length possible is 502 out of 1024, so there's no need to truncate. - embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) - - await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) - - @command() - async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: - """ - Show the Zen of Python. - - Without any arguments, the full Zen will be produced. - If an integer is provided, the line with that index will be produced. - If a string is provided, the line which matches best will be produced. - """ - embed = Embed( - colour=Colour.blurple(), - title="The Zen of Python", - description=ZEN_OF_PYTHON - ) - - if search_value is None: - embed.title += ", by Tim Peters" - await ctx.send(embed=embed) - return - - zen_lines = ZEN_OF_PYTHON.splitlines() - - # handle if it's an index int - if isinstance(search_value, int): - upper_bound = len(zen_lines) - 1 - lower_bound = -1 * upper_bound - if not (lower_bound <= search_value <= upper_bound): - raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") - - embed.title += f" (line {search_value % len(zen_lines)}):" - embed.description = zen_lines[search_value] - await ctx.send(embed=embed) - return - - # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead - # exact word. - for i, line in enumerate(zen_lines): - for word in line.split(): - if word.lower() == search_value.lower(): - embed.title += f" (line {i}):" - embed.description = line - await ctx.send(embed=embed) - return - - # handle if it's a search string and not exact word - matcher = difflib.SequenceMatcher(None, search_value.lower()) - - best_match = "" - match_index = 0 - best_ratio = 0 - - for index, line in enumerate(zen_lines): - matcher.set_seq2(line.lower()) - - # the match ratio needs to be adjusted because, naturally, - # longer lines will have worse ratios than shorter lines when - # fuzzy searching for keywords. this seems to work okay. - adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() - - if adjusted_ratio > best_ratio: - best_ratio = adjusted_ratio - best_match = line - match_index = index - - if not best_match: - raise BadArgument("I didn't get a match! Please try again with a different search term.") - - embed.title += f" (line {match_index}):" - embed.description = best_match - await ctx.send(embed=embed) - - @command(aliases=("poll",)) - @with_role(*MODERATION_ROLES) - async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: - """ - Build a quick voting poll with matching reactions with the provided options. - - A maximum of 20 options can be provided, as Discord supports a max of 20 - reactions on a single message. - """ - if len(title) > 256: - raise BadArgument("The title cannot be longer than 256 characters.") - if len(options) < 2: - raise BadArgument("Please provide at least 2 options.") - if len(options) > 20: - raise BadArgument("I can only handle 20 options!") - - codepoint_start = 127462 # represents "regional_indicator_a" unicode value - options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} - embed = Embed(title=title, description="\n".join(options.values())) - message = await ctx.send(embed=embed) - for reaction in options: - await message.add_reaction(reaction) - - async def send_pep_zero(self, ctx: Context) -> None: - """Send information about PEP 0.""" - pep_embed = Embed( - title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", - description="[Link](https://www.python.org/dev/peps/)" - ) - pep_embed.set_thumbnail(url=ICON_URL) - pep_embed.add_field(name="Status", value="Active") - pep_embed.add_field(name="Created", value="13-Jul-2000") - pep_embed.add_field(name="Type", value="Informational") - - await ctx.send(embed=pep_embed) - - -def setup(bot: Bot) -> None: - """Load the Utils cog.""" - bot.add_cog(Utils(bot)) diff --git a/bot/cogs/utils/__init__.py b/bot/cogs/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/cogs/utils/bot.py b/bot/cogs/utils/bot.py new file mode 100644 index 000000000..71ed54f60 --- /dev/null +++ b/bot/cogs/utils/bot.py @@ -0,0 +1,385 @@ +import ast +import logging +import re +import time +from typing import Optional, Tuple + +from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord.ext.commands import Cog, Context, command, group + +from bot.bot import Bot +from bot.cogs.filters.token_remover import TokenRemover +from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs +from bot.decorators import with_role +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +RE_MARKDOWN = re.compile(r'([*_~`|>])') + + +class BotCog(Cog, name="Bot"): + """Bot information commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + # Stores allowed channels plus epoch time since last call. + self.channel_cooldowns = { + Channels.python_discussion: 0, + } + + # These channels will also work, but will not be subject to cooldown + self.channel_whitelist = ( + Channels.bot_commands, + ) + + # Stores improperly formatted Python codeblock message ids and the corresponding bot message + self.codeblock_message_ids = {} + + @group(invoke_without_command=True, name="bot", hidden=True) + @with_role(Roles.verified) + async def botinfo_group(self, ctx: Context) -> None: + """Bot informational commands.""" + await ctx.send_help(ctx.command) + + @botinfo_group.command(name='about', aliases=('info',), hidden=True) + @with_role(Roles.verified) + async def about_command(self, ctx: Context) -> None: + """Get information about the bot.""" + embed = Embed( + description="A utility bot designed just for the Python server! Try `!help` for more info.", + url="https://github.com/python-discord/bot" + ) + + embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=URLs.bot_avatar + ) + + await ctx.send(embed=embed) + + @command(name='echo', aliases=('print',)) + @with_role(*MODERATION_ROLES) + async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Repeat the given message in either a specified channel or the current channel.""" + if channel is None: + await ctx.send(text) + else: + await channel.send(text) + + @command(name='embed') + @with_role(*MODERATION_ROLES) + async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Send the input within an embed to either a specified channel or the current channel.""" + embed = Embed(description=text) + + if channel is None: + await ctx.send(embed=embed) + else: + await channel.send(embed=embed) + + def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: + """ + Strip msg in order to find Python code. + + Tries to strip out Python code out of msg and returns the stripped block or + None if the block is a valid Python codeblock. + """ + if msg.count("\n") >= 3: + # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. + if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: + log.trace( + "Someone wrote a message that was already a " + "valid Python syntax highlighted code block. No action taken." + ) + return None + + else: + # Stripping backticks from every line of the message. + log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") + content = "" + for line in msg.splitlines(keepends=True): + content += line.strip("`") + + content = content.strip() + + # Remove "Python" or "Py" from start of the message if it exists. + log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") + pycode = False + if content.lower().startswith("python"): + content = content[6:] + pycode = True + elif content.lower().startswith("py"): + content = content[2:] + pycode = True + + if pycode: + content = content.splitlines(keepends=True) + + # Check if there might be code in the first line, and preserve it. + first_line = content[0] + if " " in content[0]: + first_space = first_line.index(" ") + content[0] = first_line[first_space:] + content = "".join(content) + + # If there's no code we can just get rid of the first line. + else: + content = "".join(content[1:]) + + # Strip it again to remove any leading whitespace. This is neccessary + # if the first line of the message looked like ```python + old = content.strip() + + # Strips REPL code out of the message if there is any. + content, repl_code = self.repl_stripping(old) + if old != content: + return (content, old), repl_code + + # Try to apply indentation fixes to the code. + content = self.fix_indentation(content) + + # Check if the code contains backticks, if it does ignore the message. + if "`" in content: + log.trace("Detected ` inside the code, won't reply") + return None + else: + log.trace(f"Returning message.\n\n{content}\n\n") + return (content,), repl_code + + def fix_indentation(self, msg: str) -> str: + """Attempts to fix badly indented code.""" + def unindent(code: str, skip_spaces: int = 0) -> str: + """Unindents all code down to the number of spaces given in skip_spaces.""" + final = "" + current = code[0] + leading_spaces = 0 + + # Get numbers of spaces before code in the first line. + while current == " ": + current = code[leading_spaces + 1] + leading_spaces += 1 + leading_spaces -= skip_spaces + + # If there are any, remove that number of spaces from every line. + if leading_spaces > 0: + for line in code.splitlines(keepends=True): + line = line[leading_spaces:] + final += line + return final + else: + return code + + # Apply fix for "all lines are overindented" case. + msg = unindent(msg) + + # If the first line does not end with a colon, we can be + # certain the next line will be on the same indentation level. + # + # If it does end with a colon, we will need to indent all successive + # lines one additional level. + first_line = msg.splitlines()[0] + code = "".join(msg.splitlines(keepends=True)[1:]) + if not first_line.endswith(":"): + msg = f"{first_line}\n{unindent(code)}" + else: + msg = f"{first_line}\n{unindent(code, 4)}" + return msg + + def repl_stripping(self, msg: str) -> Tuple[str, bool]: + """ + Strip msg in order to extract Python code out of REPL output. + + Tries to strip out REPL Python code out of msg and returns the stripped msg. + + Returns True for the boolean if REPL code was found in the input msg. + """ + final = "" + for line in msg.splitlines(keepends=True): + if line.startswith(">>>") or line.startswith("..."): + final += line[4:] + log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") + if not final: + log.trace(f"Found no REPL code in \n\n{msg}\n\n") + return msg, False + else: + log.trace(f"Found REPL code in \n\n{msg}\n\n") + return final.rstrip(), True + + def has_bad_ticks(self, msg: Message) -> bool: + """Check to see if msg contains ticks that aren't '`'.""" + not_backticks = [ + "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", + "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", + "\u3003\u3003\u3003" + ] + + return msg.content[:3] in not_backticks + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Detect poorly formatted Python code in new messages. + + If poorly formatted code is detected, send the user a helpful message explaining how to do + properly formatted Python syntax highlighting codeblocks. + """ + is_help_channel = ( + getattr(msg.channel, "category", None) + and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) + ) + parse_codeblock = ( + ( + is_help_channel + or msg.channel.id in self.channel_cooldowns + or msg.channel.id in self.channel_whitelist + ) + and not msg.author.bot + and len(msg.content.splitlines()) > 3 + and not TokenRemover.find_token_in_message(msg) + ) + + if parse_codeblock: # no token in the msg + on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 + if not on_cooldown or DEBUG_MODE: + try: + if self.has_bad_ticks(msg): + ticks = msg.content[:3] + content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) + if content is None: + return + + content, repl_code = content + + if len(content) == 2: + content = content[1] + else: + content = content[0] + + space_left = 204 + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto = ( + "It looks like you are trying to paste code into this channel.\n\n" + "You seem to be using the wrong symbols to indicate where the codeblock should start. " + f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" + "**Here is an example of how it should look:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + else: + howto = "" + content = self.codeblock_stripping(msg.content, False) + if content is None: + return + + content, repl_code = content + # Attempts to parse the message into an AST node. + # Invalid Python code will raise a SyntaxError. + tree = ast.parse(content[0]) + + # Multiple lines of single words could be interpreted as expressions. + # This check is to avoid all nodes being parsed as expressions. + # (e.g. words over multiple lines) + if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: + # Shorten the code to 10 lines and/or 204 characters. + space_left = 204 + if content and repl_code: + content = content[1] + else: + content = content[0] + + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto += ( + "It looks like you're trying to paste code into this channel.\n\n" + "Discord has support for Markdown, which allows you to post code with full " + "syntax highlighting. Please use these whenever you paste code, as this " + "helps improve the legibility and makes it easier for us to help you.\n\n" + f"**To do this, use the following method:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + log.debug(f"{msg.author} posted something that needed to be put inside python code " + "blocks. Sending the user some instructions.") + else: + log.trace("The code consists only of expressions, not sending instructions") + + if howto != "": + # Increase amount of codeblock correction in stats + self.bot.stats.incr("codeblock_corrections") + howto_embed = Embed(description=howto) + bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) + self.codeblock_message_ids[msg.id] = bot_message.id + + self.bot.loop.create_task( + wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) + ) + else: + return + + if msg.channel.id not in self.channel_whitelist: + self.channel_cooldowns[msg.channel.id] = time.time() + + except SyntaxError: + log.trace( + f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " + "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " + f"The message that was posted was:\n\n{msg.content}\n\n" + ) + + @Cog.listener() + async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: + """Check to see if an edited message (previously called out) still contains poorly formatted code.""" + if ( + # Checks to see if the message was called out by the bot + payload.message_id not in self.codeblock_message_ids + # Makes sure that there is content in the message + or payload.data.get("content") is None + # Makes sure there's a channel id in the message payload + or payload.data.get("channel_id") is None + ): + return + + # Retrieve channel and message objects for use later + channel = self.bot.get_channel(int(payload.data.get("channel_id"))) + user_message = await channel.fetch_message(payload.message_id) + + # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None + has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) + + # If the message is fixed, delete the bot message and the entry from the id dictionary + if has_fixed_codeblock is None: + bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) + await bot_message.delete() + del self.codeblock_message_ids[payload.message_id] + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + + +def setup(bot: Bot) -> None: + """Load the Bot cog.""" + bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/utils/clean.py b/bot/cogs/utils/clean.py new file mode 100644 index 000000000..f436e531a --- /dev/null +++ b/bot/cogs/utils/clean.py @@ -0,0 +1,272 @@ +import logging +import random +import re +from typing import Iterable, Optional + +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext import commands +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import ( + Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES +) +from bot.decorators import with_role + +log = logging.getLogger(__name__) + + +class Clean(Cog): + """ + A cog that allows messages to be deleted in bulk, while applying various filters. + + You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a + specific regular expression. + + The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be + used to view the messages in the Discord dark theme style. + """ + + def __init__(self, bot: Bot): + self.bot = bot + self.cleaning = False + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def _clean_messages( + self, + amount: int, + ctx: Context, + channels: Iterable[TextChannel], + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + until_message: Optional[Message] = None, + ) -> None: + """A helper function that does the actual message cleaning.""" + def predicate_bots_only(message: Message) -> bool: + """Return True if the message was sent by a bot.""" + return message.author.bot + + def predicate_specific_user(message: Message) -> bool: + """Return True if the message was sent by the user provided in the _clean_messages call.""" + return message.author == user + + def predicate_regex(message: Message) -> bool: + """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" + content = [message.content] + + # Add the content for all embed attributes + for embed in message.embeds: + content.append(embed.title) + content.append(embed.description) + content.append(embed.footer.text) + content.append(embed.author.name) + for field in embed.fields: + content.append(field.name) + content.append(field.value) + + # Get rid of empty attributes and turn it into a string + content = [attr for attr in content if attr] + content = "\n".join(content) + + # Now let's see if there's a regex match + if not content: + return False + else: + return bool(re.search(regex.lower(), content.lower())) + + # Is this an acceptable amount of messages to clean? + if amount > CleanMessages.message_limit: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description=f"You cannot clean more than {CleanMessages.message_limit} messages." + ) + await ctx.send(embed=embed) + return + + # Are we already performing a clean? + if self.cleaning: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description="Please wait for the currently ongoing clean operation to complete." + ) + await ctx.send(embed=embed) + return + + # Set up the correct predicate + if bots_only: + predicate = predicate_bots_only # Delete messages from bots + elif user: + predicate = predicate_specific_user # Delete messages from specific user + elif regex: + predicate = predicate_regex # Delete messages that match regex + else: + predicate = None # Delete all messages + + # Default to using the invoking context's channel + if not channels: + channels = [ctx.channel] + + # Delete the invocation first + self.mod_log.ignore(Event.message_delete, ctx.message.id) + await ctx.message.delete() + + messages = [] + message_ids = [] + self.cleaning = True + + # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. + for channel in channels: + async for message in channel.history(limit=amount): + + # If at any point the cancel command is invoked, we should stop. + if not self.cleaning: + return + + # If we are looking for specific message. + if until_message: + + # we could use ID's here however in case if the message we are looking for gets deleted, + # we won't have a way to figure that out thus checking for datetime should be more reliable + if message.created_at < until_message.created_at: + # means we have found the message until which we were supposed to be deleting. + break + + # Since we will be using `delete_messages` method of a TextChannel and we need message objects to + # use it as well as to send logs we will start appending messages here instead adding them from + # purge. + messages.append(message) + + # If the message passes predicate, let's save it. + if predicate is None or predicate(message): + message_ids.append(message.id) + + self.cleaning = False + + # Now let's delete the actual messages with purge. + self.mod_log.ignore(Event.message_delete, *message_ids) + for channel in channels: + if until_message: + for i in range(0, len(messages), 100): + # while purge automatically handles the amount of messages + # delete_messages only allows for up to 100 messages at once + # thus we need to paginate the amount to always be <= 100 + await channel.delete_messages(messages[i:i + 100]) + else: + messages += await channel.purge(limit=amount, check=predicate) + + # Reverse the list to restore chronological order + if messages: + messages = reversed(messages) + log_url = await self.mod_log.upload_log(messages, ctx.author.id) + else: + # Can't build an embed, nothing to clean! + embed = Embed( + color=Colour(Colours.soft_red), + description="No matching messages could be found." + ) + await ctx.send(embed=embed, delete_after=10) + return + + # Build the embed and send it + target_channels = ", ".join(channel.mention for channel in channels) + + message = ( + f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" + f"A log of the deleted messages can be found [here]({log_url})." + ) + + await self.mod_log.send_log_message( + icon_url=Icons.message_bulk_delete, + colour=Colour(Colours.soft_red), + title="Bulk message delete", + text=message, + channel_id=Channels.mod_log, + ) + + @group(invoke_without_command=True, name="clean", aliases=["purge"]) + @with_role(*MODERATION_ROLES) + async def clean_group(self, ctx: Context) -> None: + """Commands for cleaning messages in channels.""" + await ctx.send_help(ctx.command) + + @clean_group.command(name="user", aliases=["users"]) + @with_role(*MODERATION_ROLES) + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, user=user, channels=channels) + + @clean_group.command(name="all", aliases=["everything"]) + @with_role(*MODERATION_ROLES) + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, channels=channels) + + @clean_group.command(name="bots", aliases=["bot"]) + @with_role(*MODERATION_ROLES) + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, bots_only=True, channels=channels) + + @clean_group.command(name="regex", aliases=["word", "expression"]) + @with_role(*MODERATION_ROLES) + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, regex=regex, channels=channels) + + @clean_group.command(name="message", aliases=["messages"]) + @with_role(*MODERATION_ROLES) + async def clean_message(self, ctx: Context, message: Message) -> None: + """Delete all messages until certain message, stop cleaning after hitting the `message`.""" + await self._clean_messages( + CleanMessages.message_limit, + ctx, + channels=[message.channel], + until_message=message + ) + + @clean_group.command(name="stop", aliases=["cancel", "abort"]) + @with_role(*MODERATION_ROLES) + async def clean_cancel(self, ctx: Context) -> None: + """If there is an ongoing cleaning process, attempt to immediately cancel it.""" + self.cleaning = False + + embed = Embed( + color=Colour.blurple(), + description="Clean interrupted." + ) + await ctx.send(embed=embed, delete_after=10) + + +def setup(bot: Bot) -> None: + """Load the Clean cog.""" + bot.add_cog(Clean(bot)) diff --git a/bot/cogs/utils/eval.py b/bot/cogs/utils/eval.py new file mode 100644 index 000000000..eb8bfb1cf --- /dev/null +++ b/bot/cogs/utils/eval.py @@ -0,0 +1,202 @@ +import contextlib +import inspect +import logging +import pprint +import re +import textwrap +import traceback +from io import StringIO +from typing import Any, Optional, Tuple + +import discord +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role +from bot.interpreter import Interpreter + +log = logging.getLogger(__name__) + + +class CodeEval(Cog): + """Owner and admin feature that evaluates code and returns the result to the channel.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.env = {} + self.ln = 0 + self.stdout = StringIO() + + self.interpreter = Interpreter(bot) + + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: + """Format the eval output into a string & attempt to format it into an Embed.""" + self._ = out + + res = "" + + # Erase temp input we made + if inp.startswith("_ = "): + inp = inp[4:] + + # Get all non-empty lines + lines = [line for line in inp.split("\n") if line.strip()] + if len(lines) != 1: + lines += [""] + + # Create the input dialog + for i, line in enumerate(lines): + if i == 0: + # Start dialog + start = f"In [{self.ln}]: " + + else: + # Indent the 3 dots correctly; + # Normally, it's something like + # In [X]: + # ...: + # + # But if it's + # In [XX]: + # ...: + # + # You can see it doesn't look right. + # This code simply indents the dots + # far enough to align them. + # we first `str()` the line number + # then we get the length + # and use `str.rjust()` + # to indent it. + start = "...: ".rjust(len(str(self.ln)) + 7) + + if i == len(lines) - 2: + if line.startswith("return"): + line = line[6:].strip() + + # Combine everything + res += (start + line + "\n") + + self.stdout.seek(0) + text = self.stdout.read() + self.stdout.close() + self.stdout = StringIO() + + if text: + res += (text + "\n") + + if out is None: + # No output, return the input statement + return (res, None) + + res += f"Out[{self.ln}]: " + + if isinstance(out, discord.Embed): + # We made an embed? Send that as embed + res += "" + res = (res, out) + + else: + if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): + # Leave out the traceback message + out = "\n" + "\n".join(out.split("\n")[1:]) + + if isinstance(out, str): + pretty = out + else: + pretty = pprint.pformat(out, compact=True, width=60) + + if pretty != str(out): + # We're using the pretty version, start on the next line + res += "\n" + + if pretty.count("\n") > 20: + # Text too long, shorten + li = pretty.split("\n") + + pretty = ("\n".join(li[:3]) # First 3 lines + + "\n ...\n" # Ellipsis to indicate removed lines + + "\n".join(li[-3:])) # last 3 lines + + # Add the output + res += pretty + res = (res, None) + + return res # Return (text, embed) + + async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: + """Eval the input code string & send an embed to the invoking context.""" + self.ln += 1 + + if code.startswith("exit"): + self.ln = 0 + self.env = {} + return await ctx.send("```Reset history!```") + + env = { + "message": ctx.message, + "author": ctx.message.author, + "channel": ctx.channel, + "guild": ctx.guild, + "ctx": ctx, + "self": self, + "bot": self.bot, + "inspect": inspect, + "discord": discord, + "contextlib": contextlib + } + + self.env.update(env) + + # Ignore this code, it works + code_ = """ +async def func(): # (None,) -> Any + try: + with contextlib.redirect_stdout(self.stdout): +{0} + if '_' in locals(): + if inspect.isawaitable(_): + _ = await _ + return _ + finally: + self.env.update(locals()) +""".format(textwrap.indent(code, ' ')) + + try: + exec(code_, self.env) # noqa: B102,S102 + func = self.env['func'] + res = await func() + + except Exception: + res = traceback.format_exc() + + out, embed = self._format(code, res) + await ctx.send(f"```py\n{out}```", embed=embed) + + @group(name='internal', aliases=('int',)) + @with_role(Roles.owners, Roles.admins) + async def internal_group(self, ctx: Context) -> None: + """Internal commands. Top secret!""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @internal_group.command(name='eval', aliases=('e',)) + @with_role(Roles.admins, Roles.owners) + async def eval(self, ctx: Context, *, code: str) -> None: + """Run eval in a REPL-like format.""" + code = code.strip("`") + if re.match('py(thon)?\n', code): + code = "\n".join(code.split("\n")[1:]) + + if not re.search( # Check if it's an expression + r"^(return|import|for|while|def|class|" + r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( + code.split("\n")) == 1: + code = "_ = " + code + + await self._eval(ctx, code) + + +def setup(bot: Bot) -> None: + """Load the CodeEval cog.""" + bot.add_cog(CodeEval(bot)) diff --git a/bot/cogs/utils/extensions.py b/bot/cogs/utils/extensions.py new file mode 100644 index 000000000..365f198ff --- /dev/null +++ b/bot/cogs/utils/extensions.py @@ -0,0 +1,236 @@ +import functools +import logging +import typing as t +from enum import Enum +from pkgutil import iter_modules + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group + +from bot.bot import Bot +from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + +UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} +EXTENSIONS = frozenset( + ext.name + for ext in iter_modules(("bot/cogs",), "bot.cogs.") + if ext.name[-1] != "_" +) + + +class Action(Enum): + """Represents an action to perform on an extension.""" + + # Need to be partial otherwise they are considered to be function definitions. + LOAD = functools.partial(Bot.load_extension) + UNLOAD = functools.partial(Bot.unload_extension) + RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if "." not in argument: + argument = f"bot.cogs.{argument}" + + if argument in EXTENSIONS: + return argument + else: + raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): + """Extension management commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) + async def extensions_group(self, ctx: Context) -> None: + """Load, unload, reload, and list loaded extensions.""" + await ctx.send_help(ctx.command) + + @extensions_group.command(name="load", aliases=("l",)) + async def load_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Load extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "*" in extensions or "**" in extensions: + extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + + msg = self.batch_manage(Action.LOAD, *extensions) + await ctx.send(msg) + + @extensions_group.command(name="unload", aliases=("ul",)) + async def unload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Unload currently loaded extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + + if blacklisted: + msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" + else: + if "*" in extensions or "**" in extensions: + extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + + msg = self.batch_manage(Action.UNLOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="reload", aliases=("r",)) + async def reload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Reload extensions given their fully qualified or unqualified names. + + If an extension fails to be reloaded, it will be rolled-back to the prior working state. + + If '\*' is given as the name, all currently loaded extensions will be reloaded. + If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "**" in extensions: + extensions = EXTENSIONS + elif "*" in extensions: + extensions = set(self.bot.extensions.keys()) | set(extensions) + extensions.remove("*") + + msg = self.batch_manage(Action.RELOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="list", aliases=("all",)) + async def list_command(self, ctx: Context) -> None: + """ + Get a list of all extensions, including their loaded status. + + Grey indicates that the extension is unloaded. + Green indicates that the extension is currently loaded. + """ + embed = Embed() + lines = [] + + embed.colour = Colour.blurple() + embed.set_author( + name="Extensions List", + url=URLs.github_bot_repo, + icon_url=URLs.bot_avatar + ) + + for ext in sorted(list(EXTENSIONS)): + if ext in self.bot.extensions: + status = Emojis.status_online + else: + status = Emojis.status_offline + + ext = ext.rsplit(".", 1)[1] + lines.append(f"{status} {ext}") + + log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") + await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) + + def batch_manage(self, action: Action, *extensions: str) -> str: + """ + Apply an action to multiple extensions and return a message with the results. + + If only one extension is given, it is deferred to `manage()`. + """ + if len(extensions) == 1: + msg, _ = self.manage(action, extensions[0]) + return msg + + verb = action.name.lower() + failures = {} + + for extension in extensions: + _, error = self.manage(action, extension) + if error: + failures[extension] = error + + emoji = ":x:" if failures else ":ok_hand:" + msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + + if failures: + failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) + msg += f"\nFailures:```{failures}```" + + log.debug(f"Batch {verb}ed extensions.") + + return msg + + def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: + """Apply an action to an extension and return the status message and any error message.""" + verb = action.name.lower() + error_msg = None + + try: + action.value(self.bot, ext) + except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): + if action is Action.RELOAD: + # When reloading, just load the extension if it was not loaded. + return self.manage(Action.LOAD, ext) + + msg = f":x: Extension `{ext}` is already {verb}ed." + log.debug(msg[4:]) + except Exception as e: + if hasattr(e, "original"): + e = e.original + + log.exception(f"Extension '{ext}' failed to {verb}.") + + error_msg = f"{e.__class__.__name__}: {e}" + msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" + else: + msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." + log.debug(msg[10:]) + + return msg, error_msg + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators and core developers to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Handle BadArgument errors locally to prevent the help command from showing.""" + if isinstance(error, commands.BadArgument): + await ctx.send(str(error)) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Extensions cog.""" + bot.add_cog(Extensions(bot)) diff --git a/bot/cogs/utils/jams.py b/bot/cogs/utils/jams.py new file mode 100644 index 000000000..b3102db2f --- /dev/null +++ b/bot/cogs/utils/jams.py @@ -0,0 +1,150 @@ +import logging +import typing as t + +from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role +from discord.ext import commands +from more_itertools import unique_everseen + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role + +log = logging.getLogger(__name__) + +MAX_CHANNELS = 50 +CATEGORY_NAME = "Code Jam" + + +class CodeJams(commands.Cog): + """Manages the code-jam related parts of our server.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command() + @with_role(Roles.admins) + async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: + """ + Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. + + The first user passed will always be the team leader. + """ + # Ignore duplicate members + members = list(unique_everseen(members)) + + # We had a little issue during Code Jam 4 here, the greedy converter did it's job + # and ignored anything which wasn't a valid argument which left us with teams of + # two members or at some times even 1 member. This fixes that by checking that there + # are always 3 members in the members list. + if len(members) < 3: + await ctx.send( + ":no_entry_sign: One of your arguments was invalid\n" + f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" + " members" + ) + return + + team_channel = await self.create_channels(ctx.guild, team_name, members) + await self.add_roles(ctx.guild, members) + + await ctx.send( + f":ok_hand: Team created: {team_channel}\n" + f"**Team Leader:** {members[0].mention}\n" + f"**Team Members:** {' '.join(member.mention for member in members[1:])}" + ) + + async def get_category(self, guild: Guild) -> CategoryChannel: + """ + Return a code jam category. + + If all categories are full or none exist, create a new category. + """ + for category in guild.categories: + # Need 2 available spaces: one for the text channel and one for voice. + if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: + return category + + return await self.create_category(guild) + + @staticmethod + async def create_category(guild: Guild) -> CategoryChannel: + """Create a new code jam category and return it.""" + log.info("Creating a new code jam category.") + + category_overwrites = { + guild.default_role: PermissionOverwrite(read_messages=False), + guild.me: PermissionOverwrite(read_messages=True) + } + + return await guild.create_category_channel( + CATEGORY_NAME, + overwrites=category_overwrites, + reason="It's code jam time!" + ) + + @staticmethod + def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: + """Get code jam team channels permission overwrites.""" + # First member is always the team leader + team_channel_overwrites = { + members[0]: PermissionOverwrite( + manage_messages=True, + read_messages=True, + manage_webhooks=True, + connect=True + ), + guild.default_role: PermissionOverwrite(read_messages=False, connect=False), + guild.get_role(Roles.verified): PermissionOverwrite( + read_messages=False, + connect=False + ) + } + + # Rest of members should just have read_messages + for member in members[1:]: + team_channel_overwrites[member] = PermissionOverwrite( + read_messages=True, + connect=True + ) + + return team_channel_overwrites + + async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: + """Create team text and voice channels. Return the mention for the text channel.""" + # Get permission overwrites and category + team_channel_overwrites = self.get_overwrites(members, guild) + code_jam_category = await self.get_category(guild) + + # Create a text channel for the team + team_channel = await guild.create_text_channel( + team_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + # Create a voice channel for the team + team_voice_name = " ".join(team_name.split("-")).title() + + await guild.create_voice_channel( + team_voice_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + return team_channel.mention + + @staticmethod + async def add_roles(guild: Guild, members: t.List[Member]) -> None: + """Assign team leader and jammer roles.""" + # Assign team leader role + await members[0].add_roles(guild.get_role(Roles.team_leaders)) + + # Assign rest of roles + jammer_role = guild.get_role(Roles.jammers) + for member in members: + await member.add_roles(jammer_role) + + +def setup(bot: Bot) -> None: + """Load the CodeJams cog.""" + bot.add_cog(CodeJams(bot)) diff --git a/bot/cogs/utils/reminders.py b/bot/cogs/utils/reminders.py new file mode 100644 index 000000000..670493bcf --- /dev/null +++ b/bot/cogs/utils/reminders.py @@ -0,0 +1,427 @@ +import asyncio +import logging +import random +import textwrap +import typing as t +from datetime import datetime, timedelta +from operator import itemgetter + +import discord +from dateutil.parser import isoparse +from dateutil.relativedelta import relativedelta +from discord.ext.commands import Cog, Context, Greedy, group + +from bot.bot import Bot +from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES +from bot.converters import Duration +from bot.pagination import LinePaginator +from bot.utils.checks import without_role_check +from bot.utils.messages import send_denial +from bot.utils.scheduling import Scheduler +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +WHITELISTED_CHANNELS = Guild.reminder_whitelist +MAXIMUM_REMINDERS = 5 + +Mentionable = t.Union[discord.Member, discord.Role] + + +class Reminders(Cog): + """Provide in-channel reminder functionality.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_reminders()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + async def reschedule_reminders(self) -> None: + """Get all current reminders from the API and reschedule them.""" + await self.bot.wait_until_guild_available() + response = await self.bot.api_client.get( + 'bot/reminders', + params={'active': 'true'} + ) + + now = datetime.utcnow() + + for reminder in response: + is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + if not is_valid: + continue + + remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) + + # If the reminder is already overdue ... + if remind_at < now: + late = relativedelta(now, remind_at) + await self.send_reminder(reminder, late) + else: + self.schedule_reminder(reminder) + + def ensure_valid_reminder( + self, + reminder: dict, + cancel_task: bool = True + ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" + user = self.bot.get_user(reminder['author']) + channel = self.bot.get_channel(reminder['channel_id']) + is_valid = True + if not user or not channel: + is_valid = False + log.info( + f"Reminder {reminder['id']} invalid: " + f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." + ) + asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + + return is_valid, user, channel + + @staticmethod + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: str, + delivery_dt: t.Optional[datetime], + ) -> None: + """Send an embed confirming the reminder change was made successfully.""" + embed = discord.Embed() + embed.colour = discord.Colour.green() + embed.title = random.choice(POSITIVE_REPLIES) + embed.description = on_success + + footer_str = f"ID: {reminder_id}" + if delivery_dt: + # Reminder deletion will have a `None` `delivery_dt` + footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + + embed.set_footer(text=footer_str) + + await ctx.send(embed=embed) + + @staticmethod + async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: + """ + Returns whether or not the list of mentions is allowed. + + Conditions: + - Role reminders are Mods+ + - Reminders for other users are Helpers+ + + If mentions aren't allowed, also return the type of mention(s) disallowed. + """ + if without_role_check(ctx, *STAFF_ROLES): + return False, "members/roles" + elif without_role_check(ctx, *MODERATION_ROLES): + return all(isinstance(mention, discord.Member) for mention in mentions), "roles" + else: + return True, "" + + @staticmethod + async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: + """ + Filter mentions to see if the user can mention, and sends a denial if not allowed. + + Returns whether or not the validation is successful. + """ + mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) + + if not mentions or mentions_allowed: + return True + else: + await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") + return False + + def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: + """Converts Role and Member ids to their corresponding objects if possible.""" + guild = self.bot.get_guild(Guild.id) + for mention_id in mention_ids: + if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): + yield mentionable + + def schedule_reminder(self, reminder: dict) -> None: + """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" + reminder_id = reminder["id"] + reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) + + async def _remind() -> None: + await self.send_reminder(reminder) + + log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") + await self._delete_reminder(reminder_id) + + self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) + + async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: + """Delete a reminder from the database, given its ID, and cancel the running task.""" + await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) + + if cancel_task: + # Now we can remove it from the schedule list + self.scheduler.cancel(reminder_id) + + async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: + """ + Edits a reminder in the database given the ID and payload. + + Returns the edited reminder. + """ + # Send the request to update the reminder in the database + reminder = await self.bot.api_client.patch( + 'bot/reminders/' + str(reminder_id), + json=payload + ) + return reminder + + async def _reschedule_reminder(self, reminder: dict) -> None: + """Reschedule a reminder object.""" + log.trace(f"Cancelling old task #{reminder['id']}") + self.scheduler.cancel(reminder["id"]) + + log.trace(f"Scheduling new task #{reminder['id']}") + self.schedule_reminder(reminder) + + async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: + """Send the reminder.""" + is_valid, user, channel = self.ensure_valid_reminder(reminder) + if not is_valid: + return + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.set_author( + icon_url=Icons.remind_blurple, + name="It has arrived!" + ) + + embed.description = f"Here's your reminder: `{reminder['content']}`." + + if reminder.get("jump_url"): # keep backward compatibility + embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" + + if late: + embed.colour = discord.Colour.red() + embed.set_author( + icon_url=Icons.remind_red, + name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" + ) + + additional_mentions = ' '.join( + mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) + ) + + await channel.send( + content=f"{user.mention} {additional_mentions}", + embed=embed + ) + await self._delete_reminder(reminder["id"]) + + @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) + async def remind_group( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """Commands for managing your reminders.""" + await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) + + @remind_group.command(name="new", aliases=("add", "create")) + async def new_reminder( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """ + Set yourself a simple reminder. + + Expiration is parsed per: http://strftime.org/ + """ + # If the user is not staff, we need to verify whether or not to make a reminder at all. + if without_role_check(ctx, *STAFF_ROLES): + + # If they don't have permission to set a reminder in this channel + if ctx.channel.id not in WHITELISTED_CHANNELS: + await send_denial(ctx, "Sorry, you can't do that here!") + return + + # Get their current active reminders + active_reminders = await self.bot.api_client.get( + 'bot/reminders', + params={ + 'author__id': str(ctx.author.id) + } + ) + + # Let's limit this, so we don't get 10 000 + # reminders from kip or something like that :P + if len(active_reminders) > MAXIMUM_REMINDERS: + await send_denial(ctx, "You have too many active reminders!") + return + + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + + # Now we can attempt to actually set the reminder. + reminder = await self.bot.api_client.post( + 'bot/reminders', + json={ + 'author': ctx.author.id, + 'channel_id': ctx.message.channel.id, + 'jump_url': ctx.message.jump_url, + 'content': content, + 'expiration': expiration.isoformat(), + 'mentions': mention_ids, + } + ) + + now = datetime.utcnow() - timedelta(seconds=1) + humanized_delta = humanize_delta(relativedelta(expiration, now)) + mention_string = ( + f"Your reminder will arrive in {humanized_delta} " + f"and will mention {len(mentions)} other(s)!" + ) + + # Confirm to the user that it worked. + await self._send_confirmation( + ctx, + on_success=mention_string, + reminder_id=reminder["id"], + delivery_dt=expiration, + ) + + self.schedule_reminder(reminder) + + @remind_group.command(name="list") + async def list_reminders(self, ctx: Context) -> None: + """View a paginated embed of all reminders for your user.""" + # Get all the user's reminders from the database. + data = await self.bot.api_client.get( + 'bot/reminders', + params={'author__id': str(ctx.author.id)} + ) + + now = datetime.utcnow() + + # Make a list of tuples so it can be sorted by time. + reminders = sorted( + ( + (rem['content'], rem['expiration'], rem['id'], rem['mentions']) + for rem in data + ), + key=itemgetter(1) + ) + + lines = [] + + for content, remind_at, id_, mentions in reminders: + # Parse and humanize the time, make it pretty :D + remind_datetime = isoparse(remind_at).replace(tzinfo=None) + time = humanize_delta(relativedelta(remind_datetime, now)) + + mentions = ", ".join( + # Both Role and User objects have the `name` attribute + mention.name for mention in self.get_mentionables(mentions) + ) + mention_string = f"\n**Mentions:** {mentions}" if mentions else "" + + text = textwrap.dedent(f""" + **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} + {content} + """).strip() + + lines.append(text) + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.title = f"Reminders for {ctx.author}" + + # Remind the user that they have no reminders :^) + if not lines: + embed.description = "No active reminders could be found." + await ctx.send(embed=embed) + return + + # Construct the embed and paginate it. + embed.colour = discord.Colour.blurple() + + await LinePaginator.paginate( + lines, + ctx, embed, + max_lines=3, + empty=True + ) + + @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) + async def edit_reminder_group(self, ctx: Context) -> None: + """Commands for modifying your current reminders.""" + await ctx.send_help(ctx.command) + + @edit_reminder_group.command(name="duration", aliases=("time",)) + async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: + """ + Edit one of your reminder's expiration. + + Expiration is parsed per: http://strftime.org/ + """ + await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) + + @edit_reminder_group.command(name="content", aliases=("reason",)) + async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: + """Edit one of your reminder's content.""" + await self.edit_reminder(ctx, id_, {"content": content}) + + @edit_reminder_group.command(name="mentions", aliases=("pings",)) + async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: + """Edit one of your reminder's mentions.""" + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) + + async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: + """Edits a reminder with the given payload, then sends a confirmation message.""" + reminder = await self._edit_reminder(id_, payload) + + # Parse the reminder expiration back into a datetime + expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) + + # Send a confirmation message to the channel + await self._send_confirmation( + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, + ) + await self._reschedule_reminder(reminder) + + @remind_group.command("delete", aliases=("remove", "cancel")) + async def delete_reminder(self, ctx: Context, id_: int) -> None: + """Delete one of your active reminders.""" + await self._delete_reminder(id_) + await self._send_confirmation( + ctx, + on_success="That reminder has been deleted successfully!", + reminder_id=id_, + delivery_dt=None, + ) + + +def setup(bot: Bot) -> None: + """Load the Reminders cog.""" + bot.add_cog(Reminders(bot)) diff --git a/bot/cogs/utils/snekbox.py b/bot/cogs/utils/snekbox.py new file mode 100644 index 000000000..52c8b6f88 --- /dev/null +++ b/bot/cogs/utils/snekbox.py @@ -0,0 +1,349 @@ +import asyncio +import contextlib +import datetime +import logging +import re +import textwrap +from functools import partial +from signal import Signals +from typing import Optional, Tuple + +from discord import HTTPException, Message, NotFound, Reaction, User +from discord.ext.commands import Cog, Context, command, guild_only + +from bot.bot import Bot +from bot.constants import Categories, Channels, Roles, URLs +from bot.decorators import in_whitelist +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") +FORMATTED_CODE_REGEX = re.compile( + r"^\s*" # any leading whitespace from the beginning of the string + r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)" # match the exact same delimiter from the start again + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive +) +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL # "." also matches newlines +) + +MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) +EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) +EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) + +SIGKILL = 9 + +REEVAL_EMOJI = '\U0001f501' # :repeat: +REEVAL_TIMEOUT = 30 + + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_eval(self, code: str) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the eval output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + + url = URLs.paste_service.format(key="documents") + try: + async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: + data = await resp.json() + + if "key" in data: + return URLs.paste_service.format(key=data["key"]) + except Exception: + # 400 (Bad Request) means there are too many characters + log.exception("Failed to upload full output to paste service!") + + @staticmethod + def prepare_input(code: str) -> str: + """Extract code from the Markdown, format it, and insert it into the code template.""" + match = FORMATTED_CODE_REGEX.fullmatch(code) + if match: + code, block, lang, delim = match.group("code", "block", "lang", "delim") + code = textwrap.dedent(code) + if block: + info = (f"'{lang}' highlighted" if lang else "plain") + " code block" + else: + info = f"{delim}-enclosed inline code" + log.trace(f"Extracted {info} for evaluation:\n{code}") + else: + code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) + log.trace( + f"Eval message contains unformatted or badly formatted code, " + f"stripping whitespace only:\n{code}" + ) + + return code + + @staticmethod + def get_results_message(results: dict) -> Tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + stdout, returncode = results["stdout"], results["returncode"] + msg = f"Your eval job has completed with return code {returncode}" + error = "" + + if returncode is None: + msg = "Your eval job has failed" + error = stdout.strip() + elif returncode == 128 + SIGKILL: + msg = "Your eval job timed out or ran out of memory" + elif returncode == 255: + msg = "Your eval job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + try: + name = Signals(returncode - 128).name + msg = f"{msg} ({name})" + except ValueError: + pass + + return msg, error + + @staticmethod + def get_status_emoji(results: dict) -> str: + """Return an emoji corresponding to the status code or lack of output in result.""" + if not results["stdout"].strip(): # No output + return ":warning:" + elif results["returncode"] == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + + async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + """ + Format the output and return a tuple of the formatted output and a URL to the full output. + + Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters + and upload the full output to a paste service. + """ + log.trace("Formatting output...") + + output = output.rstrip("\n") + original_output = output # To be uploaded to a pasting service if needed + paste_link = None + + if "<@" in output: + output = output.replace("<@", "<@\u200B") # Zero-width space + + if " 0: + output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] + output = output[:11] # Limiting to only 11 lines + output = "\n".join(output) + + if lines > 10: + truncated = True + if len(output) >= 1000: + output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + else: + output = f"{output}\n... (truncated - too many lines)" + elif len(output) >= 1000: + truncated = True + output = f"{output[:1000]}\n... (truncated - too long)" + + if truncated: + paste_link = await self.upload_output(original_output) + + output = output or "[No output]" + + return output, paste_link + + async def send_eval(self, ctx: Context, code: str) -> Message: + """ + Evaluate code, format it, and send the output to the corresponding channel. + + Return the bot response. + """ + async with ctx.typing(): + results = await self.post_eval(code) + msg, error = self.get_results_message(results) + + if error: + output, paste_link = error, None + else: + output, paste_link = await self.format_output(results["stdout"]) + + icon = self.get_status_emoji(results) + msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + # Collect stats of eval fails + successes + if icon == ":x:": + self.bot.stats.incr("snekbox.python.fail") + else: + self.bot.stats.incr("snekbox.python.success") + + filter_cog = self.bot.get_cog("Filtering") + filter_triggered = False + if filter_cog: + filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + if filter_triggered: + response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + else: + response = await ctx.send(msg) + self.bot.loop.create_task( + wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) + ) + + log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + return response + + async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + """ + Check if the eval session should continue. + + Return the new code to evaluate or None if the eval session should be terminated. + """ + _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + + with contextlib.suppress(NotFound): + try: + _, new_message = await self.bot.wait_for( + 'message_edit', + check=_predicate_eval_message_edit, + timeout=REEVAL_TIMEOUT + ) + await ctx.message.add_reaction(REEVAL_EMOJI) + await self.bot.wait_for( + 'reaction_add', + check=_predicate_emoji_reaction, + timeout=10 + ) + + code = await self.get_code(new_message) + await ctx.message.clear_reactions() + with contextlib.suppress(HTTPException): + await response.delete() + + except asyncio.TimeoutError: + await ctx.message.clear_reactions() + return None + + return code + + async def get_code(self, message: Message) -> Optional[str]: + """ + Return the code from `message` to be evaluated. + + If the message is an invocation of the eval command, return the first argument or None if it + doesn't exist. Otherwise, return the full content of the message. + """ + log.trace(f"Getting context for message {message.id}.") + new_ctx = await self.bot.get_context(message) + + if new_ctx.command is self.eval_command: + log.trace(f"Message {message.id} invokes eval command.") + split = message.content.split(maxsplit=1) + code = split[1] if len(split) > 1 else None + else: + log.trace(f"Message {message.id} does not invoke eval command.") + code = message.content + + return code + + @command(name="eval", aliases=("e",)) + @guild_only() + @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) + async def eval_command(self, ctx: Context, *, code: str = None) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + if ctx.author.id in self.jobs: + await ctx.send( + f"{ctx.author.mention} You've already got a job running - " + "please wait for it to finish!" + ) + return + + if not code: # None or empty string + await ctx.send_help(ctx.command) + return + + if Roles.helpers in (role.id for role in ctx.author.roles): + self.bot.stats.incr("snekbox_usages.roles.helpers") + else: + self.bot.stats.incr("snekbox_usages.roles.developers") + + if ctx.channel.category_id == Categories.help_in_use: + self.bot.stats.incr("snekbox_usages.channels.help") + elif ctx.channel.id == Channels.bot_commands: + self.bot.stats.incr("snekbox_usages.channels.bot_commands") + else: + self.bot.stats.incr("snekbox_usages.channels.topical") + + log.info(f"Received code from {ctx.author} for evaluation:\n{code}") + + while True: + self.jobs[ctx.author.id] = datetime.datetime.now() + code = self.prepare_input(code) + try: + response = await self.send_eval(ctx, code) + finally: + del self.jobs[ctx.author.id] + + code = await self.continue_eval(ctx, response) + if not code: + break + log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + + +def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: + """Return True if the edited message is the context message and the content was indeed modified.""" + return new_msg.id == ctx.message.id and old_msg.content != new_msg.content + + +def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI + + +def setup(bot: Bot) -> None: + """Load the Snekbox cog.""" + bot.add_cog(Snekbox(bot)) diff --git a/bot/cogs/utils/utils.py b/bot/cogs/utils/utils.py new file mode 100644 index 000000000..d96abbd5a --- /dev/null +++ b/bot/cogs/utils/utils.py @@ -0,0 +1,265 @@ +import difflib +import logging +import re +import unicodedata +from email.parser import HeaderParser +from io import StringIO +from typing import Tuple, Union + +from discord import Colour, Embed, utils +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils import messages + +log = logging.getLogger(__name__) + +ZEN_OF_PYTHON = """\ +Beautiful is better than ugly. +Explicit is better than implicit. +Simple is better than complex. +Complex is better than complicated. +Flat is better than nested. +Sparse is better than dense. +Readability counts. +Special cases aren't special enough to break the rules. +Although practicality beats purity. +Errors should never pass silently. +Unless explicitly silenced. +In the face of ambiguity, refuse the temptation to guess. +There should be one-- and preferably only one --obvious way to do it. +Although that way may not be obvious at first unless you're Dutch. +Now is better than never. +Although never is often better than *right* now. +If the implementation is hard to explain, it's a bad idea. +If the implementation is easy to explain, it may be a good idea. +Namespaces are one honking great idea -- let's do more of those! +""" + +ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + + +class Utils(Cog): + """A selection of utilities which don't have a clear category.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.base_pep_url = "http://www.python.org/dev/peps/pep-" + self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" + + @command(name='pep', aliases=('get_pep', 'p')) + async def pep_command(self, ctx: Context, pep_number: str) -> None: + """Fetches information about a PEP and sends it to the channel.""" + if pep_number.isdigit(): + pep_number = int(pep_number) + else: + await ctx.send_help(ctx.command) + return + + # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. + if pep_number == 0: + return await self.send_pep_zero(ctx) + + possible_extensions = ['.txt', '.rst'] + found_pep = False + for extension in possible_extensions: + # Attempt to fetch the PEP + pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" + log.trace(f"Requesting PEP {pep_number} with {pep_url}") + response = await self.bot.http_session.get(pep_url) + + if response.status == 200: + log.trace("PEP found") + found_pep = True + + pep_content = await response.text() + + # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 + pep_header = HeaderParser().parse(StringIO(pep_content)) + + # Assemble the embed + pep_embed = Embed( + title=f"**PEP {pep_number} - {pep_header['Title']}**", + description=f"[Link]({self.base_pep_url}{pep_number:04})", + ) + + pep_embed.set_thumbnail(url=ICON_URL) + + # Add the interesting information + fields_to_check = ("Status", "Python-Version", "Created", "Type") + for field in fields_to_check: + # Check for a PEP metadata field that is present but has an empty value + # embed field values can't contain an empty string + if pep_header.get(field, ""): + pep_embed.add_field(name=field, value=pep_header[field]) + + elif response.status != 404: + # any response except 200 and 404 is expected + found_pep = True # actually not, but it's easier to display this way + log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " + f"{response.status}.\n{response.text}") + + error_message = "Unexpected HTTP error during PEP search. Please let us know." + pep_embed = Embed(title="Unexpected error", description=error_message) + pep_embed.colour = Colour.red() + break + + if not found_pep: + log.trace("PEP was not found") + not_found = f"PEP {pep_number} does not exist." + pep_embed = Embed(title="PEP not found", description=not_found) + pep_embed.colour = Colour.red() + + await ctx.message.channel.send(embed=pep_embed) + + @command() + @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) + async def charinfo(self, ctx: Context, *, characters: str) -> None: + """Shows you information on up to 50 unicode characters.""" + match = re.match(r"<(a?):(\w+):(\d+)>", characters) + if match: + return await messages.send_denial( + ctx, + "**Non-Character Detected**\n" + "Only unicode characters can be processed, but a custom Discord emoji " + "was found. Please remove it and try again." + ) + + if len(characters) > 50: + return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") + + def get_info(char: str) -> Tuple[str, str]: + digit = f"{ord(char):x}" + if len(digit) <= 4: + u_code = f"\\u{digit:>04}" + else: + u_code = f"\\U{digit:>08}" + url = f"https://www.compart.com/en/unicode/U+{digit:>04}" + name = f"[{unicodedata.name(char, '')}]({url})" + info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" + return info, u_code + + char_list, raw_list = zip(*(get_info(c) for c in characters)) + embed = Embed().set_author(name="Character Info") + + if len(characters) > 1: + # Maximum length possible is 502 out of 1024, so there's no need to truncate. + embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) + + await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) + + @command() + async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: + """ + Show the Zen of Python. + + Without any arguments, the full Zen will be produced. + If an integer is provided, the line with that index will be produced. + If a string is provided, the line which matches best will be produced. + """ + embed = Embed( + colour=Colour.blurple(), + title="The Zen of Python", + description=ZEN_OF_PYTHON + ) + + if search_value is None: + embed.title += ", by Tim Peters" + await ctx.send(embed=embed) + return + + zen_lines = ZEN_OF_PYTHON.splitlines() + + # handle if it's an index int + if isinstance(search_value, int): + upper_bound = len(zen_lines) - 1 + lower_bound = -1 * upper_bound + if not (lower_bound <= search_value <= upper_bound): + raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") + + embed.title += f" (line {search_value % len(zen_lines)}):" + embed.description = zen_lines[search_value] + await ctx.send(embed=embed) + return + + # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead + # exact word. + for i, line in enumerate(zen_lines): + for word in line.split(): + if word.lower() == search_value.lower(): + embed.title += f" (line {i}):" + embed.description = line + await ctx.send(embed=embed) + return + + # handle if it's a search string and not exact word + matcher = difflib.SequenceMatcher(None, search_value.lower()) + + best_match = "" + match_index = 0 + best_ratio = 0 + + for index, line in enumerate(zen_lines): + matcher.set_seq2(line.lower()) + + # the match ratio needs to be adjusted because, naturally, + # longer lines will have worse ratios than shorter lines when + # fuzzy searching for keywords. this seems to work okay. + adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() + + if adjusted_ratio > best_ratio: + best_ratio = adjusted_ratio + best_match = line + match_index = index + + if not best_match: + raise BadArgument("I didn't get a match! Please try again with a different search term.") + + embed.title += f" (line {match_index}):" + embed.description = best_match + await ctx.send(embed=embed) + + @command(aliases=("poll",)) + @with_role(*MODERATION_ROLES) + async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: + """ + Build a quick voting poll with matching reactions with the provided options. + + A maximum of 20 options can be provided, as Discord supports a max of 20 + reactions on a single message. + """ + if len(title) > 256: + raise BadArgument("The title cannot be longer than 256 characters.") + if len(options) < 2: + raise BadArgument("Please provide at least 2 options.") + if len(options) > 20: + raise BadArgument("I can only handle 20 options!") + + codepoint_start = 127462 # represents "regional_indicator_a" unicode value + options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} + embed = Embed(title=title, description="\n".join(options.values())) + message = await ctx.send(embed=embed) + for reaction in options: + await message.add_reaction(reaction) + + async def send_pep_zero(self, ctx: Context) -> None: + """Send information about PEP 0.""" + pep_embed = Embed( + title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", + description="[Link](https://www.python.org/dev/peps/)" + ) + pep_embed.set_thumbnail(url=ICON_URL) + pep_embed.add_field(name="Status", value="Active") + pep_embed.add_field(name="Created", value="13-Jul-2000") + pep_embed.add_field(name="Type", value="Informational") + + await ctx.send(embed=pep_embed) + + +def setup(bot: Bot) -> None: + """Load the Utils cog.""" + bot.add_cog(Utils(bot)) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py deleted file mode 100644 index ae156cf70..000000000 --- a/bot/cogs/verification.py +++ /dev/null @@ -1,191 +0,0 @@ -import logging -from contextlib import suppress - -from discord import Colour, Forbidden, Message, NotFound, Object -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.decorators import in_whitelist, without_role -from bot.utils.checks import InWhitelistCheckFailure, without_role_check - -log = logging.getLogger(__name__) - -WELCOME_MESSAGE = f""" -Hello! Welcome to the server, and thanks for verifying yourself! - -For your records, these are the documents you accepted: - -`1)` Our rules, here: -`2)` Our privacy policy, here: - you can find information on how to have \ -your information removed here as well. - -Feel free to review them at any point! - -Additionally, if you'd like to receive notifications for the announcements \ -we post in <#{constants.Channels.announcements}> -from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ -to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. - -If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ -<#{constants.Channels.bot_commands}>. -""" - -BOT_MESSAGE_DELETE_DELAY = 10 - - -class Verification(Cog): - """User verification and role self-management.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Check new message event for messages to the checkpoint channel & process.""" - if message.channel.id != constants.Channels.verification: - return # Only listen for #checkpoint messages - - if message.author.bot: - # They're a bot, delete their message after the delay. - await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) - return - - # if a user mentions a role or guild member - # alert the mods in mod-alerts channel - if message.mentions or message.role_mentions: - log.debug( - f"{message.author} mentioned one or more users " - f"and/or roles in {message.channel.name}" - ) - - embed_text = ( - f"{message.author.mention} sent a message in " - f"{message.channel.mention} that contained user and/or role mentions." - f"\n\n**Original message:**\n>>> {message.content}" - ) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=constants.Icons.filtering, - colour=Colour(constants.Colours.soft_red), - title=f"User/Role mentioned in {message.channel.name}", - text=embed_text, - thumbnail=message.author.avatar_url_as(static_format="png"), - channel_id=constants.Channels.mod_alerts, - ) - - ctx: Context = await self.bot.get_context(message) - if ctx.command is not None and ctx.command.name == "accept": - return - - if any(r.id == constants.Roles.verified for r in ctx.author.roles): - log.info( - f"{ctx.author} posted '{ctx.message.content}' " - "in the verification channel, but is already verified." - ) - return - - log.debug( - f"{ctx.author} posted '{ctx.message.content}' in the verification " - "channel. We are providing instructions how to verify." - ) - await ctx.send( - f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " - f"and gain access to the rest of the server.", - delete_after=20 - ) - - log.trace(f"Deleting the message posted by {ctx.author}") - with suppress(NotFound): - await ctx.message.delete() - - @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) - @without_role(constants.Roles.verified) - @in_whitelist(channels=(constants.Channels.verification,)) - async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Accept our rules and gain access to the rest of the server.""" - log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") - await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") - try: - await ctx.author.send(WELCOME_MESSAGE) - except Forbidden: - log.info(f"Sending welcome message failed for {ctx.author}.") - finally: - log.trace(f"Deleting accept message by {ctx.author}.") - with suppress(NotFound): - self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) - await ctx.message.delete() - - @command(name='subscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Subscribe to announcement notifications by assigning yourself the role.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if has_role: - await ctx.send(f"{ctx.author.mention} You're already subscribed!") - return - - log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") - await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", - ) - - @command(name='unsubscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Unsubscribe from announcement notifications by removing the role from yourself.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if not has_role: - await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") - return - - log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") - await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." - ) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InWhitelistCheckFailure.""" - if isinstance(error, InWhitelistCheckFailure): - error.handled = True - - @staticmethod - def bot_check(ctx: Context) -> bool: - """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): - return ctx.command.name == "accept" - else: - return True - - -def setup(bot: Bot) -> None: - """Load the Verification cog.""" - bot.add_cog(Verification(bot)) diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py deleted file mode 100644 index 69d118df6..000000000 --- a/bot/cogs/watchchannels/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from bot.bot import Bot -from .bigbrother import BigBrother -from .talentpool import TalentPool - - -def setup(bot: Bot) -> None: - """Load the BigBrother and TalentPool cogs.""" - bot.add_cog(BigBrother(bot)) - bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py deleted file mode 100644 index 4d27a6333..000000000 --- a/bot/cogs/watchchannels/bigbrother.py +++ /dev/null @@ -1,165 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.utils import post_infraction -from bot.constants import Channels, MODERATION_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from .watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class BigBrother(WatchChannel, Cog, name="Big Brother"): - """Monitors users by relaying their messages to a watch channel to assist with moderation.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.big_brother_logs, - webhook_id=Webhooks.big_brother, - api_endpoint='bot/infractions', - api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, - logger=log - ) - - @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def bigbrother_group(self, ctx: Context) -> None: - """Monitors users by relaying their messages to the Big Brother watch channel.""" - await ctx.send_help(ctx.command) - - @bigbrother_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored by Big Brother. - - The optional kwarg `oldest_first` can be used to order the list by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @bigbrother_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows Big Brother monitored users ordered by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @bigbrother_group.command(name='watch', aliases=('w',)) - @with_role(*MODERATION_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#big-brother` channel. - - A `reason` for adding the user to Big Brother is required and will be displayed - in the header when relaying messages of this user to the watchchannel. - """ - await self.apply_watch(ctx, user, reason) - - @bigbrother_group.command(name='unwatch', aliases=('uw',)) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Stop relaying messages by the given `user`.""" - await self.apply_unwatch(ctx, user, reason) - - async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: - """ - Add `user` to watched users and apply a watch infraction with `reason`. - - A message indicating the result of the operation is sent to `ctx`. - The message will include `user`'s previous watch infraction history, if it exists. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched.") - return - - response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) - - if response is not None: - self.watched_users[user.id] = response - msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - 'type': 'watch', - 'ordering': '-inserted_at' - } - ) - - if len(history) > 1: - total = f"({len(history) // 2} previous infractions in total)" - end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") - start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - else: - msg = ":x: Failed to post the infraction: response was empty." - - await ctx.send(msg) - - async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: - """ - Remove `user` from watched users and mark their infraction as inactive with `reason`. - - If `send_message` is True, a message indicating the result of the operation is sent to - `ctx`. - """ - active_watches = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - if active_watches: - log.trace("Active watches for user found. Attempting to remove.") - [infraction] = active_watches - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{infraction['id']}", - json={'active': False} - ) - - await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) - - self._remove_user(user.id) - - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"Perma-banned user {user} was unwatched.") - return - log.trace("User is not banned. Sending message to channel") - message = f":white_check_mark: Messages sent by {user} will no longer be relayed." - - else: - log.trace("No active watches found for user.") - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"{user} was not on the watch list; no removal necessary.") - return - log.trace("User is not perma banned. Send the error message.") - message = ":x: The specified user is currently not being watched." - - await ctx.send(message) diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py deleted file mode 100644 index 89256e92e..000000000 --- a/bot/cogs/watchchannels/talentpool.py +++ /dev/null @@ -1,264 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord import Color, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils import time -from .watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class TalentPool(WatchChannel, Cog, name="Talentpool"): - """Relays messages of helper candidates to a watch channel to observe them.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.talent_pool, - webhook_id=Webhooks.talent_pool, - api_endpoint='bot/nominations', - api_default_params={'active': 'true', 'ordering': '-inserted_at'}, - logger=log, - ) - - @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_group(self, ctx: Context) -> None: - """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" - await ctx.send_help(ctx.command) - - @nomination_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored in the talent pool. - - The optional kwarg `oldest_first` can be used to order the list by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @nomination_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows talent pool monitored users ordered by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) - @with_role(*STAFF_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#talent-pool` channel. - - A `reason` for adding the user to the talent pool is required and will be displayed - in the header when relaying messages of this user to the channel. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): - await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update the user cache; can't add {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched in the talent pool") - return - - # Manual request with `raise_for_status` as False because we want the actual response - session = self.bot.api_client.session - url = self.bot.api_client._url_for(self.api_endpoint) - kwargs = { - 'json': { - 'actor': ctx.author.id, - 'reason': reason, - 'user': user.id - }, - 'raise_for_status': False, - } - async with session.post(url, **kwargs) as resp: - response_data = await resp.json() - - if resp.status == 400 and response_data.get('user', False): - await ctx.send(":x: The specified user can't be found in the database tables") - return - else: - resp.raise_for_status() - - self.watched_users[user.id] = response_data - msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - "ordering": "-inserted_at" - } - ) - - if history: - total = f"({len(history)} previous nominations in total)" - start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" - end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - - await ctx.send(msg) - - @nomination_group.command(name='history', aliases=('info', 'search')) - @with_role(*MODERATION_ROLES) - async def history_command(self, ctx: Context, user: FetchedMember) -> None: - """Shows the specified user's nomination history.""" - result = await self.bot.api_client.get( - self.api_endpoint, - params={ - 'user__id': str(user.id), - 'ordering': "-active,-inserted_at" - } - ) - if not result: - await ctx.send(":warning: This user has never been nominated") - return - - embed = Embed( - title=f"Nominations for {user.display_name} `({user.id})`", - color=Color.blue() - ) - lines = [self._nomination_to_string(nomination) for nomination in result] - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - @nomination_group.command(name='unwatch', aliases=('end', )) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Ends the active nomination of the specified user with the given reason. - - Providing a `reason` is required. - """ - active_nomination = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - - if not active_nomination: - await ctx.send(":x: The specified user does not have an active nomination") - return - - [nomination] = active_nomination - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination['id']}", - json={'end_reason': reason, 'active': False} - ) - await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") - self._remove_user(user.id) - - @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_edit_group(self, ctx: Context) -> None: - """Commands to edit nominations.""" - await ctx.send_help(ctx.command) - - @nomination_edit_group.command(name='reason') - @with_role(*MODERATION_ROLES) - async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: - """ - Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. - - If the nomination is active, the reason for nominating the user will be edited; - If the nomination is no longer active, the reason for ending the nomination will be edited instead. - """ - try: - nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") - except ResponseCodeError as e: - if e.response.status == 404: - self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") - await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") - return - else: - raise - - field = "reason" if nomination["active"] else "end_reason" - - self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination_id}", - json={field: reason} - ) - - await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") - - def _nomination_to_string(self, nomination_object: dict) -> str: - """Creates a string representation of a nomination.""" - guild = self.bot.get_guild(Guild.id) - - actor_id = nomination_object["actor"] - actor = guild.get_member(actor_id) - - active = nomination_object["active"] - log.debug(active) - log.debug(type(nomination_object["inserted_at"])) - - start_date = time.format_infraction(nomination_object["inserted_at"]) - if active: - lines = textwrap.dedent( - f""" - =============== - Status: **Active** - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - else: - end_date = time.format_infraction(nomination_object["ended_at"]) - lines = textwrap.dedent( - f""" - =============== - Status: Inactive - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - - End date: {end_date} - Unwatch reason: {nomination_object["end_reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - - return lines.strip() diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py deleted file mode 100644 index 044077350..000000000 --- a/bot/cogs/watchchannels/watchchannel.py +++ /dev/null @@ -1,348 +0,0 @@ -import asyncio -import logging -import re -import textwrap -from abc import abstractmethod -from collections import defaultdict, deque -from dataclasses import dataclass -from typing import Optional - -import dateutil.parser -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - -URL_RE = re.compile(r"(https?://[^\s]+)") - - -@dataclass -class MessageHistory: - """Represents a watch channel's message history.""" - - last_author: Optional[int] = None - last_channel: Optional[int] = None - message_count: int = 0 - - -class WatchChannel(metaclass=CogABCMeta): - """ABC with functionality for relaying users' messages to a certain channel.""" - - @abstractmethod - def __init__( - self, - bot: Bot, - destination: int, - webhook_id: int, - api_endpoint: str, - api_default_params: dict, - logger: logging.Logger - ) -> None: - self.bot = bot - - self.destination = destination # E.g., Channels.big_brother_logs - self.webhook_id = webhook_id # E.g., Webhooks.big_brother - self.api_endpoint = api_endpoint # E.g., 'bot/infractions' - self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} - self.log = logger # Logger of the child cog for a correct name in the logs - - self._consume_task = None - self.watched_users = defaultdict(dict) - self.message_queue = defaultdict(lambda: defaultdict(deque)) - self.consumption_queue = {} - self.retries = 5 - self.retry_delay = 10 - self.channel = None - self.webhook = None - self.message_history = MessageHistory() - - self._start = self.bot.loop.create_task(self.start_watchchannel()) - - @property - def modlog(self) -> ModLog: - """Provides access to the ModLog cog for alert purposes.""" - return self.bot.get_cog("ModLog") - - @property - def consuming_messages(self) -> bool: - """Checks if a consumption task is currently running.""" - if self._consume_task is None: - return False - - if self._consume_task.done(): - exc = self._consume_task.exception() - if exc: - self.log.exception( - "The message queue consume task has failed with:", - exc_info=exc - ) - return False - - return True - - async def start_watchchannel(self) -> None: - """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_guild_available() - - try: - self.channel = await self.bot.fetch_channel(self.destination) - except HTTPException: - self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - if self.channel is None or self.webhook is None: - self.log.error("Failed to start the watch channel; unloading the cog.") - - message = textwrap.dedent( - f""" - An error occurred while loading the text channel or webhook. - - TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} - Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} - - The Cog has been unloaded. - """ - ) - - await self.modlog.send_log_message( - title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", - text=message, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - if not await self.fetch_user_cache(): - await self.modlog.send_log_message( - title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", - text="Could not retrieve the list of watched users from the API and messages will not be relayed.", - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - async def fetch_user_cache(self) -> bool: - """ - Fetches watched users from the API and updates the watched user cache accordingly. - - This function returns `True` if the update succeeded. - """ - try: - data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except ResponseCodeError as err: - self.log.exception("Failed to fetch the watched users from the API", exc_info=err) - return False - - self.watched_users = defaultdict(dict) - - for entry in data: - user_id = entry.pop('user') - self.watched_users[user_id] = entry - - return True - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Queues up messages sent by watched users.""" - if msg.author.id in self.watched_users: - if not self.consuming_messages: - self._consume_task = self.bot.loop.create_task(self.consume_messages()) - - self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") - self.message_queue[msg.author.id][msg.channel.id].append(msg) - - async def consume_messages(self, delay_consumption: bool = True) -> None: - """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) - - self.log.trace("Started consuming the message queue") - - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() - - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() - - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) - - self.consumption_queue.clear() - - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") - - async def webhook_send( - self, - content: Optional[str] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - embed: Optional[Embed] = None, - ) -> None: - """Sends a message to the webhook with the specified kwargs.""" - username = messages.sub_clyde(username) - try: - await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send a message to the webhook", - exc_info=exc - ) - - async def relay_message(self, msg: Message) -> None: - """Relays the message to the relevant watch channel.""" - limit = BigBrotherConfig.header_message_limit - - if ( - msg.author.id != self.message_history.last_author - or msg.channel.id != self.message_history.last_channel - or self.message_history.message_count >= limit - ): - self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) - - await self.send_header(msg) - - cleaned_content = msg.clean_content - - if cleaned_content: - # Put all non-media URLs in a code block to prevent embeds - media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} - for url in URL_RE.findall(cleaned_content): - if url not in media_urls: - cleaned_content = cleaned_content.replace(url, f"`{url}`") - await self.webhook_send( - cleaned_content, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - - if msg.attachments: - try: - await messages.send_attachments(msg, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.webhook_send( - embed=e, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send an attachment to the webhook", - exc_info=exc - ) - - self.message_history.message_count += 1 - - async def send_header(self, msg: Message) -> None: - """Sends a header embed with information about the relayed messages to the watch channel.""" - user_id = msg.author.id - - guild = self.bot.get_guild(GuildConfig.id) - actor = guild.get_member(self.watched_users[user_id]['actor']) - actor = actor.display_name if actor else self.watched_users[user_id]['actor'] - - inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - - reason = self.watched_users[user_id]['reason'] - - if isinstance(msg.channel, DMChannel): - # If a watched user DMs the bot there won't be a channel name or jump URL - # This could technically include a GroupChannel but bot's can't be in those - message_jump = "via DM" - else: - message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" - - footer = f"Added {time_delta} by {actor} | Reason: {reason}" - embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) - - await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) - - async def list_watched_users( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Gives an overview of the watched user list for this channel. - - The optional kwarg `oldest_first` orders the list by oldest entry. - - The optional kwarg `update_cache` specifies whether the cache should - be refreshed by polling the API. - """ - if update_cache: - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") - update_cache = False - - lines = [] - for user_id, user_data in self.watched_users.items(): - inserted_at = user_data['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - lines.append(f"• <@{user_id}> (added {time_delta})") - - if oldest_first: - lines.reverse() - - lines = lines or ("There's nothing here yet.",) - - embed = Embed( - title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", - color=Color.blue() - ) - await LinePaginator.paginate(lines, ctx, embed, empty=False) - - @staticmethod - def _get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) - time_delta = time_since(date_time, precision="minutes", max_units=1) - - return time_delta - - def _remove_user(self, user_id: int) -> None: - """Removes a user from a watch channel.""" - self.watched_users.pop(user_id, None) - self.message_queue.pop(user_id, None) - self.consumption_queue.pop(user_id, None) - - def cog_unload(self) -> None: - """Takes care of unloading the cog and canceling the consumption task.""" - self.log.trace("Unloading the cog") - if self._consume_task and not self._consume_task.done(): - self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) diff --git a/bot/cogs/webhook_remover.py b/bot/cogs/webhook_remover.py deleted file mode 100644 index 5812da87c..000000000 --- a/bot/cogs/webhook_remover.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import re - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Event, Icons - -WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) - -ALERT_MESSAGE_TEMPLATE = ( - "{user}, looks like you posted a Discord webhook URL. Therefore, your " - "message has been removed. Your webhook may have been **compromised** so " - "please re-create the webhook **immediately**. If you believe this was " - "mistake, please let us know." -) - -log = logging.getLogger(__name__) - - -class WebhookRemover(Cog): - """Scan messages to detect Discord webhooks links.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get current instance of `ModLog`.""" - return self.bot.get_cog("ModLog") - - async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: - """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" - # Don't log this, due internal delete, not by user. Will make different entry. - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") - return - - await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) - - message = ( - f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " - f"to #{msg.channel}. Webhook URL was `{redacted_url}`" - ) - log.debug(message) - - # Send entry to moderation alerts. - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Discord webhook URL removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts - ) - - self.bot.stats.incr("tokens.removed_webhooks") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Check if a Discord webhook URL is in `message`.""" - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - matches = WEBHOOK_URL_RE.search(msg.content) - if matches: - await self.delete_and_respond(msg, matches[1] + "xxx") - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """Check if a Discord webhook URL is in the edited message `after`.""" - await self.on_message(after) - - -def setup(bot: Bot) -> None: - """Load `WebhookRemover` cog.""" - bot.add_cog(WebhookRemover(bot)) diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py deleted file mode 100644 index e6cae3bb8..000000000 --- a/bot/cogs/wolfram.py +++ /dev/null @@ -1,280 +0,0 @@ -import logging -from io import BytesIO -from typing import Callable, List, Optional, Tuple -from urllib import parse - -import discord -from dateutil.relativedelta import relativedelta -from discord import Embed -from discord.ext import commands -from discord.ext.commands import BucketType, Cog, Context, check, group - -from bot.bot import Bot -from bot.constants import Colours, STAFF_ROLES, Wolfram -from bot.pagination import ImagePaginator -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -APPID = Wolfram.key -DEFAULT_OUTPUT_FORMAT = "JSON" -QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" -WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" - -MAX_PODS = 20 - -# Allows for 10 wolfram calls pr user pr day -usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) - -# Allows for max api requests / days in month per day for the entire guild (Temporary) -guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) - - -async def send_embed( - ctx: Context, - message_txt: str, - colour: int = Colours.soft_red, - footer: str = None, - img_url: str = None, - f: discord.File = None -) -> None: - """Generate & send a response embed with Wolfram as the author.""" - embed = Embed(colour=colour) - embed.description = message_txt - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - if footer: - embed.set_footer(text=footer) - - if img_url: - embed.set_image(url=img_url) - - await ctx.send(embed=embed, file=f) - - -def custom_cooldown(*ignore: List[int]) -> Callable: - """ - Implement per-user and per-guild cooldowns for requests to the Wolfram API. - - A list of roles may be provided to ignore the per-user cooldown - """ - async def predicate(ctx: Context) -> bool: - if ctx.invoked_with == 'help': - # if the invoked command is help we don't want to increase the ratelimits since it's not actually - # invoking the command/making a request, so instead just check if the user/guild are on cooldown. - guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown - if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored - return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 - return guild_cooldown - - user_bucket = usercd.get_bucket(ctx.message) - - if all(role.id not in ignore for role in ctx.author.roles): - user_rate = user_bucket.update_rate_limit() - - if user_rate: - # Can't use api; cause: member limit - delta = relativedelta(seconds=int(user_rate)) - cooldown = humanize_delta(delta) - message = ( - "You've used up your limit for Wolfram|Alpha requests.\n" - f"Cooldown: {cooldown}" - ) - await send_embed(ctx, message) - return False - - guild_bucket = guildcd.get_bucket(ctx.message) - guild_rate = guild_bucket.update_rate_limit() - - # Repr has a token attribute to read requests left - log.debug(guild_bucket) - - if guild_rate: - # Can't use api; cause: guild limit - message = ( - "The max limit of requests for the server has been reached for today.\n" - f"Cooldown: {int(guild_rate)}" - ) - await send_embed(ctx, message) - return False - - return True - return check(predicate) - - -async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: - """Get the Wolfram API pod pages for the provided query.""" - async with ctx.channel.typing(): - url_str = parse.urlencode({ - "input": query, - "appid": APPID, - "output": DEFAULT_OUTPUT_FORMAT, - "format": "image,plaintext" - }) - request_url = QUERY.format(request="query", data=url_str) - - async with bot.http_session.get(request_url) as response: - json = await response.json(content_type='text/plain') - - result = json["queryresult"] - - if result["error"]: - # API key not set up correctly - if result["error"]["msg"] == "Invalid appid": - message = "Wolfram API key is invalid or missing." - log.warning( - "API key seems to be missing, or invalid when " - f"processing a wolfram request: {url_str}, Response: {json}" - ) - await send_embed(ctx, message) - return - - message = "Something went wrong internally with your request, please notify staff!" - log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") - await send_embed(ctx, message) - return - - if not result["success"]: - message = f"I couldn't find anything for {query}." - await send_embed(ctx, message) - return - - if not result["numpods"]: - message = "Could not find any results." - await send_embed(ctx, message) - return - - pods = result["pods"] - pages = [] - for pod in pods[:MAX_PODS]: - subs = pod.get("subpods") - - for sub in subs: - title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") - img = sub["img"]["src"] - pages.append((title, img)) - return pages - - -class Wolfram(Cog): - """Commands for interacting with the Wolfram|Alpha API.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_command(self, ctx: Context, *, query: str) -> None: - """Requests all answers on a single image, sends an image of all related pods.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="simple", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - image_bytes = await response.read() - - f = discord.File(BytesIO(image_bytes), filename="image.png") - image_url = "attachment://image.png" - - if status == 501: - message = "Failed to get response" - footer = "" - color = Colours.soft_red - elif status == 400: - message = "No input found" - footer = "" - color = Colours.soft_red - elif status == 403: - message = "Wolfram API key is invalid or missing." - footer = "" - color = Colours.soft_red - else: - message = "" - footer = "View original for a bigger picture." - color = Colours.soft_orange - - # Sends a "blank" embed if no request is received, unsure how to fix - await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) - - @wolfram_command.command(name="page", aliases=("pa", "p")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - embed = Embed() - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - embed.colour = Colours.soft_orange - - await ImagePaginator.paginate(pages, ctx, embed) - - @wolfram_command.command(name="cut", aliases=("c",)) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - if len(pages) >= 2: - page = pages[1] - else: - page = pages[0] - - await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) - - @wolfram_command.command(name="short", aliases=("sh", "s")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: - """Requests an answer to a simple question.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="result", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - response_text = await response.text() - - if status == 501: - message = "Failed to get response" - color = Colours.soft_red - elif status == 400: - message = "No input found" - color = Colours.soft_red - elif response_text == "Error 1: Invalid appid": - message = "Wolfram API key is invalid or missing." - color = Colours.soft_red - else: - message = response_text - color = Colours.soft_orange - - await send_embed(ctx, message, color) - - -def setup(bot: Bot) -> None: - """Load the Wolfram cog.""" - bot.add_cog(Wolfram(bot)) diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index da4e92ccc..df38090fb 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -2,7 +2,7 @@ import textwrap import unittest from unittest.mock import AsyncMock, Mock, patch -from bot.cogs.moderation.infractions import Infractions +from bot.cogs.moderation.infraction.infractions import Infractions from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index 70aea2bab..84d036405 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -6,7 +6,7 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs.sync.syncers import Syncer, _Diff +from bot.cogs.backend.sync import Syncer, _Diff from tests import helpers diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 120bc991d..ea7d090ba 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -5,8 +5,8 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs import sync -from bot.cogs.sync.syncers import Syncer +from bot.cogs.backend import sync +from bot.cogs.backend.sync import Syncer from tests import helpers from tests.base import CommandTestCase diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 79eee98f4..888c49ca8 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -3,7 +3,7 @@ from unittest import mock import discord -from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from bot.cogs.backend.sync import RoleSyncer, _Diff, _Role from tests import helpers diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 002a947ad..71f4b134c 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from bot.cogs.backend.sync import UserSyncer, _Diff, _User from tests import helpers diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index ecb7abf00..b00211f47 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, Mock from discord import NotFound -from bot.cogs import antimalware +from bot.cogs.filters import antimalware from bot.constants import Channels, STAFF_ROLES from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py index ce5472c71..8a3d8d02e 100644 --- a/tests/bot/cogs/test_antispam.py +++ b/tests/bot/cogs/test_antispam.py @@ -1,6 +1,6 @@ import unittest -from bot.cogs import antispam +from bot.cogs.filters import antispam class AntispamConfigurationValidationTests(unittest.TestCase): diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 79c0e0ad3..305a2bad9 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -6,7 +6,7 @@ import unittest.mock import discord from bot import constants -from bot.cogs import information +from bot.cogs.info import information from bot.utils.checks import InWhitelistCheckFailure from tests import helpers diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index 9d1a62f7e..82679f69c 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock from discord.ext.commands import NoPrivateMessage -from bot.cogs import security +from bot.cogs.filters import security from tests.helpers import MockBot, MockContext diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 343e37db9..c7bac3ab3 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -6,8 +6,8 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, pat from discord.ext import commands from bot import constants -from bot.cogs import snekbox -from bot.cogs.snekbox import Snekbox +from bot.cogs.utils import snekbox +from bot.cogs.utils.snekbox import Snekbox from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3349caa73..e33f3af38 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -6,9 +6,9 @@ from unittest.mock import MagicMock from discord import Colour, NotFound from bot import constants -from bot.cogs import token_remover +from bot.cogs.filters import token_remover +from bot.cogs.filters.token_remover import Token, TokenRemover from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import Token, TokenRemover from tests.helpers import MockBot, MockMessage, autospec -- cgit v1.2.3 From b224d46d68699ece3382cd333df7ede9e9a62e02 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 12 Aug 2020 14:31:56 -0700 Subject: Restructure tests and fix broken tests The cog tests structure should mirror the structure of the cogs folder. Fix some import/patch paths which broke due to the restructure. --- tests/bot/cogs/backend/__init__.py | 0 tests/bot/cogs/backend/sync/__init__.py | 0 tests/bot/cogs/backend/sync/test_base.py | 404 ++++++++++++++ tests/bot/cogs/backend/sync/test_cog.py | 415 +++++++++++++++ tests/bot/cogs/backend/sync/test_roles.py | 157 ++++++ tests/bot/cogs/backend/sync/test_users.py | 158 ++++++ tests/bot/cogs/backend/test_logging.py | 32 ++ tests/bot/cogs/filters/__init__.py | 0 tests/bot/cogs/filters/test_antimalware.py | 165 ++++++ tests/bot/cogs/filters/test_antispam.py | 35 ++ tests/bot/cogs/filters/test_security.py | 54 ++ tests/bot/cogs/filters/test_token_remover.py | 310 +++++++++++ tests/bot/cogs/info/__init__.py | 0 tests/bot/cogs/info/test_information.py | 584 +++++++++++++++++++++ tests/bot/cogs/moderation/infraction/__init__.py | 0 .../cogs/moderation/infraction/test_infractions.py | 55 ++ tests/bot/cogs/moderation/test_incidents.py | 4 +- tests/bot/cogs/moderation/test_infractions.py | 55 -- tests/bot/cogs/moderation/test_slowmode.py | 111 ++++ tests/bot/cogs/sync/__init__.py | 0 tests/bot/cogs/sync/test_base.py | 404 -------------- tests/bot/cogs/sync/test_cog.py | 415 --------------- tests/bot/cogs/sync/test_roles.py | 157 ------ tests/bot/cogs/sync/test_users.py | 158 ------ tests/bot/cogs/test_antimalware.py | 165 ------ tests/bot/cogs/test_antispam.py | 35 -- tests/bot/cogs/test_information.py | 584 --------------------- tests/bot/cogs/test_jams.py | 173 ------ tests/bot/cogs/test_logging.py | 32 -- tests/bot/cogs/test_security.py | 54 -- tests/bot/cogs/test_slowmode.py | 111 ---- tests/bot/cogs/test_snekbox.py | 409 --------------- tests/bot/cogs/test_token_remover.py | 310 ----------- tests/bot/cogs/utils/__init__.py | 0 tests/bot/cogs/utils/test_jams.py | 173 ++++++ tests/bot/cogs/utils/test_snekbox.py | 409 +++++++++++++++ 36 files changed, 3064 insertions(+), 3064 deletions(-) create mode 100644 tests/bot/cogs/backend/__init__.py create mode 100644 tests/bot/cogs/backend/sync/__init__.py create mode 100644 tests/bot/cogs/backend/sync/test_base.py create mode 100644 tests/bot/cogs/backend/sync/test_cog.py create mode 100644 tests/bot/cogs/backend/sync/test_roles.py create mode 100644 tests/bot/cogs/backend/sync/test_users.py create mode 100644 tests/bot/cogs/backend/test_logging.py create mode 100644 tests/bot/cogs/filters/__init__.py create mode 100644 tests/bot/cogs/filters/test_antimalware.py create mode 100644 tests/bot/cogs/filters/test_antispam.py create mode 100644 tests/bot/cogs/filters/test_security.py create mode 100644 tests/bot/cogs/filters/test_token_remover.py create mode 100644 tests/bot/cogs/info/__init__.py create mode 100644 tests/bot/cogs/info/test_information.py create mode 100644 tests/bot/cogs/moderation/infraction/__init__.py create mode 100644 tests/bot/cogs/moderation/infraction/test_infractions.py delete mode 100644 tests/bot/cogs/moderation/test_infractions.py create mode 100644 tests/bot/cogs/moderation/test_slowmode.py delete mode 100644 tests/bot/cogs/sync/__init__.py delete mode 100644 tests/bot/cogs/sync/test_base.py delete mode 100644 tests/bot/cogs/sync/test_cog.py delete mode 100644 tests/bot/cogs/sync/test_roles.py delete mode 100644 tests/bot/cogs/sync/test_users.py delete mode 100644 tests/bot/cogs/test_antimalware.py delete mode 100644 tests/bot/cogs/test_antispam.py delete mode 100644 tests/bot/cogs/test_information.py delete mode 100644 tests/bot/cogs/test_jams.py delete mode 100644 tests/bot/cogs/test_logging.py delete mode 100644 tests/bot/cogs/test_security.py delete mode 100644 tests/bot/cogs/test_slowmode.py delete mode 100644 tests/bot/cogs/test_snekbox.py delete mode 100644 tests/bot/cogs/test_token_remover.py create mode 100644 tests/bot/cogs/utils/__init__.py create mode 100644 tests/bot/cogs/utils/test_jams.py create mode 100644 tests/bot/cogs/utils/test_snekbox.py (limited to 'tests') diff --git a/tests/bot/cogs/backend/__init__.py b/tests/bot/cogs/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/backend/sync/__init__.py b/tests/bot/cogs/backend/sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/backend/sync/test_base.py b/tests/bot/cogs/backend/sync/test_base.py new file mode 100644 index 000000000..0d0a8299d --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_base.py @@ -0,0 +1,404 @@ +import asyncio +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.backend.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): + """Syncer subclass with mocks for abstract methods for testing purposes.""" + + name = "test" + _get_diff = mock.AsyncMock() + _sync = mock.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): + """Tests for the syncer base class.""" + + def setUp(self): + self.bot = helpers.MockBot() + + def test_instantiation_fails_without_abstract_methods(self): + """The class must have abstract methods implemented.""" + with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): + """Tests for sending the sync confirmation prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + + def mock_get_channel(self): + """Fixture to return a mock channel and message for when `get_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + mock_channel.send.return_value = mock_message + self.bot.get_channel.return_value = mock_channel + + return mock_channel, mock_message + + def mock_fetch_channel(self): + """Fixture to return a mock channel and message for when `fetch_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + self.bot.get_channel.return_value = None + mock_channel.send.return_value = mock_message + self.bot.fetch_channel.return_value = mock_channel + + return mock_channel, mock_message + + async def test_send_prompt_edits_and_returns_message(self): + """The given message should be edited to display the prompt and then should be returned.""" + msg = helpers.MockMessage() + ret_val = await self.syncer._send_prompt(msg) + + msg.edit.assert_called_once() + self.assertIn("content", msg.edit.call_args[1]) + self.assertEqual(ret_val, msg) + + async def test_send_prompt_gets_dev_core_channel(self): + """The dev-core channel should be retrieved if an extant message isn't given.""" + subtests = ( + (self.bot.get_channel, self.mock_get_channel), + (self.bot.fetch_channel, self.mock_fetch_channel), + ) + + for method, mock_ in subtests: + with self.subTest(method=method, msg=mock_.__name__): + mock_() + await self.syncer._send_prompt() + + method.assert_called_once_with(constants.Channels.dev_core) + + async def test_send_prompt_returns_none_if_channel_fetch_fails(self): + """None should be returned if there's an HTTPException when fetching the channel.""" + self.bot.get_channel.return_value = None + self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + + ret_val = await self.syncer._send_prompt() + + self.assertIsNone(ret_val) + + async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): + """A new message mentioning core devs should be sent and returned if message isn't given.""" + for mock_ in (self.mock_get_channel, self.mock_fetch_channel): + with self.subTest(msg=mock_.__name__): + mock_channel, mock_message = mock_() + ret_val = await self.syncer._send_prompt() + + mock_channel.send.assert_called_once() + self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) + self.assertEqual(ret_val, mock_message) + + async def test_send_prompt_adds_reactions(self): + """The message should have reactions for confirmation added.""" + extant_message = helpers.MockMessage() + subtests = ( + (extant_message, lambda: (None, extant_message)), + (None, self.mock_get_channel), + (None, self.mock_fetch_channel), + ) + + for message_arg, mock_ in subtests: + subtest_msg = "Extant message" if mock_.__name__ == "" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, asyncio.TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): # pragma: no cover + with self.subTest(size=size): + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover + with self.subTest(msg=msg): + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/cogs/backend/sync/test_cog.py b/tests/bot/cogs/backend/sync/test_cog.py new file mode 100644 index 000000000..199747051 --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_cog.py @@ -0,0 +1,415 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.backend import sync +from bot.cogs.backend.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + self.role_syncer_patcher = mock.patch( + "bot.cogs.backend.sync.syncers.RoleSyncer", + autospec=Syncer, + spec_set=True + ) + self.user_syncer_patcher = mock.patch( + "bot.cogs.backend.sync.syncers.UserSyncer", + autospec=Syncer, + spec_set=True + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = sync.Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task = mock.MagicMock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + sync.Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + + self.guild_id_patcher = mock.patch("bot.cogs.backend.sync.cog.constants.Guild.id", 5) + self.guild_id = self.guild_id_patcher.start() + + self.guild = helpers.MockGuild(id=self.guild_id) + self.other_guild = helpers.MockGuild(id=0) + + def tearDown(self): + self.guild_id_patcher.stop() + + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data, guild=self.guild) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + async def test_sync_cog_on_guild_role_create_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_create(role) + self.bot.api_client.post.assert_not_awaited() + + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99, guild=self.guild) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_delete(role) + self.bot.api_client.delete.assert_not_awaited() + + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data, guild=self.guild) + after_role = helpers.MockRole(**after_role_data, guild=self.guild) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + async def test_sync_cog_on_guild_role_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_update(role, role) + self.bot.api_client.put.assert_not_awaited() + + async def test_sync_cog_on_member_remove(self): + """Member should be patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember(guild=self.guild) + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + json={"in_guild": False} + ) + + async def test_sync_cog_on_member_remove_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_remove(member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles, guild=self.guild) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, json=data) + + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + async def test_sync_cog_on_member_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_update(member, member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args.args[0], after_user.id) + self.assertIn("json", call_args.kwargs) + + self.assertIn("ignore_404", call_args.kwargs) + self.assertTrue(call_args.kwargs["ignore_404"]) + + json = call_args.kwargs["json"] + self.assertIn(api_field, json) + self.assertEqual(json[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + guild=self.guild, + ) + + data = { + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_join(member) + self.bot.api_client.post.assert_not_awaited() + self.bot.api_client.put.assert_not_awaited() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/backend/sync/test_roles.py b/tests/bot/cogs/backend/sync/test_roles.py new file mode 100644 index 000000000..cc2e51c7f --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_roles.py @@ -0,0 +1,157 @@ +import unittest +from unittest import mock + +import discord + +from bot.cogs.backend.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/backend/sync/test_users.py b/tests/bot/cogs/backend/sync/test_users.py new file mode 100644 index 000000000..490ea9e06 --- /dev/null +++ b/tests/bot/cogs/backend/sync/test_users.py @@ -0,0 +1,158 @@ +import unittest +from unittest import mock + +from bot.cogs.backend.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers + + +def fake_user(**kwargs): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_sets_in_guild_false_for_leaving_users(self): + """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/backend/test_logging.py b/tests/bot/cogs/backend/test_logging.py new file mode 100644 index 000000000..c867773e2 --- /dev/null +++ b/tests/bot/cogs/backend/test_logging.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import patch + +from bot import constants +from bot.cogs.backend.logging import Logging +from tests.helpers import MockBot, MockTextChannel + + +class LoggingTests(unittest.IsolatedAsyncioTestCase): + """Test cases for connected login.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Logging(self.bot) + self.dev_log = MockTextChannel(id=1234, name="dev-log") + + @patch("bot.cogs.backend.logging.DEBUG_MODE", False) + async def test_debug_mode_false(self): + """Should send connected message to dev-log.""" + self.bot.get_channel.return_value = self.dev_log + + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) + self.dev_log.send.assert_awaited_once() + + @patch("bot.cogs.backend.logging.DEBUG_MODE", True) + async def test_debug_mode_true(self): + """Should not send anything to dev-log.""" + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_not_called() diff --git a/tests/bot/cogs/filters/__init__.py b/tests/bot/cogs/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/filters/test_antimalware.py b/tests/bot/cogs/filters/test_antimalware.py new file mode 100644 index 000000000..b00211f47 --- /dev/null +++ b/tests/bot/cogs/filters/test_antimalware.py @@ -0,0 +1,165 @@ +import unittest +from unittest.mock import AsyncMock, Mock + +from discord import NotFound + +from bot.cogs.filters import antimalware +from bot.constants import Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + + +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.bot.filter_list_cache = { + "FILE_FORMAT.True": { + ".first": {}, + ".second": {}, + ".third": {}, + } + } + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] + + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename="python.first") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.guild = None + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_called_once() + + async def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extension_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + + async def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + self.message.delete.assert_called_once() + + async def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (self.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_antispam.py b/tests/bot/cogs/filters/test_antispam.py new file mode 100644 index 000000000..8a3d8d02e --- /dev/null +++ b/tests/bot/cogs/filters/test_antispam.py @@ -0,0 +1,35 @@ +import unittest + +from bot.cogs.filters import antispam + + +class AntispamConfigurationValidationTests(unittest.TestCase): + """Tests validation of the antispam cog configuration.""" + + def test_default_antispam_config_is_valid(self): + """The default antispam configuration is valid.""" + validation_errors = antispam.validate_config() + self.assertEqual(validation_errors, {}) + + def test_unknown_rule_returns_error(self): + """Configuring an unknown rule returns an error.""" + self.assertEqual( + antispam.validate_config({'invalid-rule': {}}), + {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} + ) + + def test_missing_keys_returns_error(self): + """Not configuring required keys returns an error.""" + keys = (('interval', 'max'), ('max', 'interval')) + for configured_key, unconfigured_key in keys: + with self.subTest( + configured_key=configured_key, + unconfigured_key=unconfigured_key + ): + config = {'burst': {configured_key: 10}} + error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" + + self.assertEqual( + antispam.validate_config(config), + {'burst': error} + ) diff --git a/tests/bot/cogs/filters/test_security.py b/tests/bot/cogs/filters/test_security.py new file mode 100644 index 000000000..82679f69c --- /dev/null +++ b/tests/bot/cogs/filters/test_security.py @@ -0,0 +1,54 @@ +import unittest +from unittest.mock import MagicMock + +from discord.ext.commands import NoPrivateMessage + +from bot.cogs.filters import security +from tests.helpers import MockBot, MockContext + + +class SecurityCogTests(unittest.TestCase): + """Tests the `Security` cog.""" + + def setUp(self): + """Attach an instance of the cog to the class for tests.""" + self.bot = MockBot() + self.cog = security.Security(self.bot) + self.ctx = MockContext() + + def test_check_additions(self): + """The cog should add its checks after initialization.""" + self.bot.check.assert_any_call(self.cog.check_on_guild) + self.bot.check.assert_any_call(self.cog.check_not_bot) + + def test_check_not_bot_returns_false_for_humans(self): + """The bot check should return `True` when invoked with human authors.""" + self.ctx.author.bot = False + self.assertTrue(self.cog.check_not_bot(self.ctx)) + + def test_check_not_bot_returns_true_for_robots(self): + """The bot check should return `False` when invoked with robotic authors.""" + self.ctx.author.bot = True + self.assertFalse(self.cog.check_not_bot(self.ctx)) + + def test_check_on_guild_raises_when_outside_of_guild(self): + """When invoked outside of a guild, `check_on_guild` should cause an error.""" + self.ctx.guild = None + + with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): + self.cog.check_on_guild(self.ctx) + + def test_check_on_guild_returns_true_inside_of_guild(self): + """When invoked inside of a guild, `check_on_guild` should return `True`.""" + self.ctx.guild = "lemon's lemonade stand" + self.assertTrue(self.cog.check_on_guild(self.ctx)) + + +class SecurityCogLoadTests(unittest.TestCase): + """Tests loading the `Security` cog.""" + + def test_security_cog_load(self): + """Setup of the extension should call add_cog.""" + bot = MagicMock() + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_token_remover.py b/tests/bot/cogs/filters/test_token_remover.py new file mode 100644 index 000000000..5c527ed94 --- /dev/null +++ b/tests/bot/cogs/filters/test_token_remover.py @@ -0,0 +1,310 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock + +from discord import Colour, NotFound + +from bot import constants +from bot.cogs.filters import token_remover +from bot.cogs.filters.token_remover import Token, TokenRemover +from bot.cogs.moderation import ModLog +from tests.helpers import MockBot, MockMessage, autospec + + +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): + """Tests the `TokenRemover` cog.""" + + def setUp(self): + """Adds the cog, a bot, and a message to the instance for usage in tests.""" + self.bot = MockBot() + self.cog = TokenRemover(bot=self.bot) + + self.msg = MockMessage(id=555, content="hello world") + self.msg.channel.mention = "#lemonade-stand" + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + self.msg.author.avatar_url_as.return_value = "picture-lemon.png" + + def test_is_valid_user_id_valid(self): + """Should consider user IDs valid if they decode entirely to ASCII digits.""" + ids = ( + "NDcyMjY1OTQzMDYyNDEzMzMy", + "NDc1MDczNjI5Mzk5NTQ3OTA0", + "NDY3MjIzMjMwNjUwNzc3NjQx", + ) + + for user_id in ids: + with self.subTest(user_id=user_id): + result = TokenRemover.is_valid_user_id(user_id) + self.assertTrue(result) + + def test_is_valid_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for user_id, msg in ids: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_user_id(user_id) + self.assertFalse(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for timestamp, msg in timestamps: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertFalse(result) + + def test_mod_log_property(self): + """The `mod_log` property should ask the bot to return the `ModLog` cog.""" + self.bot.get_cog.return_value = 'lemon' + self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) + self.bot.get_cog.assert_called_once_with('ModLog') + + async def test_on_message_edit_uses_on_message(self): + """The edit listener should delegate handling of the message to the normal listener.""" + self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + + await self.cog.on_message_edit(MockMessage(), self.msg) + self.cog.on_message.assert_awaited_once_with(self.msg) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_takes_action(self, find_token_in_message, take_action): + """Should take action if a valid token is found when a message is sent.""" + cog = TokenRemover(self.bot) + found_token = "foobar" + find_token_in_message.return_value = found_token + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_awaited_once_with(cog, self.msg, found_token) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): + """Shouldn't take action if a valid token isn't found when a message is sent.""" + cog = TokenRemover(self.bot) + find_token_in_message.return_value = False + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_not_awaited() + + @autospec(TokenRemover, "find_token_in_message") + async def test_on_message_ignores_dms_bots(self, find_token_in_message): + """Shouldn't parse a message if it is a DM or authored by a bot.""" + cog = TokenRemover(self.bot) + dm_msg = MockMessage(guild=None) + bot_msg = MockMessage(author=MagicMock(bot=True)) + + for msg in (dm_msg, bot_msg): + await cog.on_message(msg) + find_token_in_message.assert_not_called() + + @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") + def test_find_token_no_matches(self, token_re): + """None should be returned if the regex matches no tokens in a message.""" + token_re.finditer.return_value = () + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.filters.token_remover", "Token") + @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") + def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """The first match with a valid user ID and timestamp should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertEqual(tokens[1], return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.filters.token_remover", "Token") + @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") + def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """None should be returned if no matches have valid user IDs or timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + is_valid_id.return_value = False + is_valid_timestamp.return_value = False + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.findall(token) + self.assertEqual(len(results), 0) + + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, these tokens have been invalidated. + tokens = ( + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" + + results = token_remover.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual((token_1, token_2), results) + + @autospec("bot.cogs.filters.token_remover", "LOG_MESSAGE") + def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + log_message.format.return_value = "Howdy" + + return_value = TokenRemover.format_log_message(self.msg, token) + + self.assertEqual(return_value, log_message.format.return_value) + log_message.format.assert_called_once_with( + author=self.msg.author, + author_id=self.msg.author.id, + channel=self.msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac="x" * len(token.hmac), + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + @autospec("bot.cogs.filters.token_remover", "log") + @autospec(TokenRemover, "format_log_message") + async def test_take_action(self, format_log_message, logger, mod_log_property): + """Should delete the message and send a mod log.""" + cog = TokenRemover(self.bot) + mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) + token = mock.create_autospec(Token, spec_set=True, instance=True) + log_msg = "testing123" + + mod_log_property.return_value = mod_log + format_log_message.return_value = log_msg + + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + + format_log_message.assert_called_once_with(self.msg, token) + logger.debug.assert_called_with(log_msg) + self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + + mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) + mod_log.send_log_message.assert_called_once_with( + icon_url=constants.Icons.token_removed, + colour=Colour(constants.Colours.soft_red), + title="Token removed!", + text=log_msg, + thumbnail=self.msg.author.avatar_url_as.return_value, + channel_id=constants.Channels.mod_alerts + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + async def test_take_action_delete_failure(self, mod_log_property): + """Shouldn't send any messages if the token message can't be deleted.""" + cog = TokenRemover(self.bot) + mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) + self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) + + token = mock.create_autospec(Token, spec_set=True, instance=True) + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_not_awaited() + + +class TokenRemoverExtensionTests(unittest.TestCase): + """Tests for the token_remover extension.""" + + @autospec("bot.cogs.filters.token_remover", "TokenRemover") + def test_extension_setup(self, cog): + """The TokenRemover cog should be added.""" + bot = MockBot() + token_remover.setup(bot) + + cog.assert_called_once_with(bot) + bot.add_cog.assert_called_once() + self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/cogs/info/__init__.py b/tests/bot/cogs/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/info/test_information.py b/tests/bot/cogs/info/test_information.py new file mode 100644 index 000000000..895a8328e --- /dev/null +++ b/tests/bot/cogs/info/test_information.py @@ -0,0 +1,584 @@ +import asyncio +import textwrap +import unittest +import unittest.mock + +import discord + +from bot import constants +from bot.cogs.info import information +from bot.utils.checks import InWhitelistCheckFailure +from tests import helpers + +COG_PATH = "bot.cogs.info.information.Information" + + +class InformationCogTests(unittest.TestCase): + """Tests the Information cog.""" + + @classmethod + def setUpClass(cls): + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = helpers.MockBot() + + self.cog = information.Information(self.bot) + + self.ctx = helpers.MockContext() + self.ctx.author.roles.append(self.moderator_role) + + def test_roles_command_command(self): + """Test if the `role_info` command correctly returns the `moderator_role`.""" + self.ctx.guild.roles.append(self.moderator_role) + + self.cog.roles_info.can_run = unittest.mock.AsyncMock() + self.cog.roles_info.can_run.return_value = True + + coroutine = self.cog.roles_info.callback(self.cog, self.ctx) + + self.assertIsNone(asyncio.run(coroutine)) + self.ctx.send.assert_called_once() + + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + + self.assertEqual(embed.title, "Role information (Total 1 role)") + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") + + def test_role_info_command(self): + """Tests the `role info` command.""" + dummy_role = helpers.MockRole( + name="Dummy", + id=112233445566778899, + colour=discord.Colour.blurple(), + position=10, + members=[self.ctx.author], + permissions=discord.Permissions(0) + ) + + admin_role = helpers.MockRole( + name="Admins", + id=998877665544332211, + colour=discord.Colour.red(), + position=3, + members=[self.ctx.author], + permissions=discord.Permissions(0), + ) + + self.ctx.guild.roles.append([dummy_role, admin_role]) + + self.cog.role_info.can_run = unittest.mock.AsyncMock() + self.cog.role_info.can_run.return_value = True + + coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) + + self.assertIsNone(asyncio.run(coroutine)) + + self.assertEqual(self.ctx.send.call_count, 2) + + (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list + + dummy_embed = dummy_kwargs["embed"] + admin_embed = admin_kwargs["embed"] + + self.assertEqual(dummy_embed.title, "Dummy info") + self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) + + self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) + self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") + self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") + self.assertEqual(dummy_embed.fields[3].value, "1") + self.assertEqual(dummy_embed.fields[4].value, "10") + self.assertEqual(dummy_embed.fields[5].value, "0") + + self.assertEqual(admin_embed.title, "Admins info") + self.assertEqual(admin_embed.colour, discord.Colour.red()) + + @unittest.mock.patch('bot.cogs.info.information.time_since') + def test_server_info_command(self, time_since_patch): + time_since_patch.return_value = '2 days ago' + + self.ctx.guild = helpers.MockGuild( + features=('lemons', 'apples'), + region="The Moon", + roles=[self.moderator_role], + channels=[ + discord.TextChannel( + state={}, + guild=self.ctx.guild, + data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} + ), + discord.CategoryChannel( + state={}, + guild=self.ctx.guild, + data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} + ), + discord.VoiceChannel( + state={}, + guild=self.ctx.guild, + data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} + ) + ], + members=[ + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), + ], + member_count=1_234, + icon_url='a-lemon.jpg', + ) + + coroutine = self.cog.server_info.callback(self.cog, self.ctx) + self.assertIsNone(asyncio.run(coroutine)) + + time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual( + embed.description, + textwrap.dedent( + f""" + **Server information** + Created: {time_since_patch.return_value} + Voice region: {self.ctx.guild.region} + Features: {', '.join(self.ctx.guild.features)} + + **Channel counts** + Category channels: 1 + Text channels: 1 + Voice channels: 1 + Staff channels: 0 + + **Member counts** + Members: {self.ctx.guild.member_count:,} + Staff members: 0 + Roles: {len(self.ctx.guild.roles)} + + **Member statuses** + {constants.Emojis.status_online} 2 + {constants.Emojis.status_idle} 1 + {constants.Emojis.status_dnd} 4 + {constants.Emojis.status_offline} 3 + """ + ) + ) + self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') + + +class UserInfractionHelperMethodTests(unittest.TestCase): + """Tests for the helper methods of the `!user` command.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + self.member = helpers.MockMember(id=1234) + + def test_user_command_helper_method_get_requests(self): + """The helper methods should form the correct get requests.""" + test_values = ( + { + "helper_method": self.cog.basic_user_infraction_counts, + "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.expanded_user_infraction_counts, + "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.user_nomination_counts, + "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), + }, + ) + + for test_value in test_values: + helper_method = test_value["helper_method"] + endpoint, params = test_value["expected_args"] + + with self.subTest(method=helper_method, endpoint=endpoint, params=params): + asyncio.run(helper_method(self.member)) + self.bot.api_client.get.assert_called_once_with(endpoint, params=params) + self.bot.api_client.get.reset_mock() + + def _method_subtests(self, method, test_values, default_header): + """Helper method that runs the subtests for the different helper methods.""" + for test_value in test_values: + api_response = test_value["api response"] + expected_lines = test_value["expected_lines"] + + with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): + self.bot.api_client.get.return_value = api_response + + expected_output = "\n".join(default_header + expected_lines) + actual_output = asyncio.run(method(self.member)) + + self.assertEqual(expected_output, actual_output) + + def test_basic_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list both the total and active number of non-hidden infractions.""" + test_values = ( + # No infractions means zero counts + { + "api response": [], + "expected_lines": ["Total: 0", "Active: 0"], + }, + # Simple, single-infraction dictionaries + { + "api response": [{"type": "ban", "active": True}], + "expected_lines": ["Total: 1", "Active: 1"], + }, + { + "api response": [{"type": "ban", "active": False}], + "expected_lines": ["Total: 1", "Active: 0"], + }, + # Multiple infractions with various `active` status + { + "api response": [ + {"type": "ban", "active": True}, + {"type": "kick", "active": False}, + {"type": "ban", "active": True}, + {"type": "ban", "active": False}, + ], + "expected_lines": ["Total: 4", "Active: 2"], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) + + def test_expanded_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list the total and active number of all infractions split by infraction type.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never received an infraction."], + }, + # Shows non-hidden inactive infraction as expected + { + "api response": [{"type": "kick", "active": False, "hidden": False}], + "expected_lines": ["Kicks: 1"], + }, + # Shows non-hidden active infraction as expected + { + "api response": [{"type": "mute", "active": True, "hidden": False}], + "expected_lines": ["Mutes: 1 (1 active)"], + }, + # Shows hidden inactive infraction as expected + { + "api response": [{"type": "superstar", "active": False, "hidden": True}], + "expected_lines": ["Superstars: 1"], + }, + # Shows hidden active infraction as expected + { + "api response": [{"type": "ban", "active": True, "hidden": True}], + "expected_lines": ["Bans: 1 (1 active)"], + }, + # Correctly displays tally of multiple infractions of mixed properties in alphabetical order + { + "api response": [ + {"type": "kick", "active": False, "hidden": True}, + {"type": "ban", "active": True, "hidden": True}, + {"type": "superstar", "active": True, "hidden": True}, + {"type": "mute", "active": True, "hidden": True}, + {"type": "ban", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + {"type": "note", "active": False, "hidden": True}, + {"type": "warn", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + ], + "expected_lines": [ + "Bans: 2 (1 active)", + "Kicks: 1", + "Mutes: 1 (1 active)", + "Notes: 3", + "Superstars: 1 (1 active)", + "Warns: 1", + ], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) + + def test_user_nomination_counts_returns_correct_strings(self): + """The method should list the number of active and historical nominations for the user.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never been nominated."], + }, + { + "api response": [{'active': True}], + "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], + }, + { + "api response": [{'active': True}, {'active': False}], + "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], + }, + { + "api response": [{'active': False}], + "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], + }, + { + "api response": [{'active': False}, {'active': False}], + "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], + }, + + ) + + header = ["**Nominations**"] + + self._method_subtests(self.cog.user_nomination_counts, test_values, header) + + +@unittest.mock.patch("bot.cogs.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) +@unittest.mock.patch("bot.cogs.info.information.constants.MODERATION_CHANNELS", new=[50]) +class UserEmbedTests(unittest.TestCase): + """Tests for the creation of the `!user` embed.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): + """The embed should use the string representation of the user if they don't have a nick.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = None + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Mr. Hemlock") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_nick_in_title_if_available(self): + """The embed should use the nick if it's available.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = "Cat lover" + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_ignores_everyone_role(self): + """Created `!user` embeds should not contain mention of the @everyone-role.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + admins_role = helpers.MockRole(name='Admins') + admins_role.colour = 100 + + # A `MockMember` has the @Everyone role by default; we add the Admins to that. + user = helpers.MockMember(roles=[admins_role], top_role=admins_role) + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertIn("&Admins", embed.description) + self.assertNotIn("&Everyone", embed.description) + + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): + """The embed should contain expanded infractions and nomination info in mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "expanded infractions info" + nomination_counts.return_value = "nomination info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + nomination_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + expanded infractions info + + nomination info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): + """The embed should contain only basic infraction data outside of mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "basic infractions info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + basic infractions info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): + """The embed should be created with the colour of the top role, if a top role is available.""" + ctx = helpers.MockContext() + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): + """The embed should be created with a blurple colour if the user has no assigned roles.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour.blurple()) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): + """The embed thumbnail should be set to the user's avatar in `png` format.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + user.avatar_url_as.return_value = "avatar url" + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + user.avatar_url_as.assert_called_once_with(static_format="png") + self.assertEqual(embed.thumbnail.url, "avatar url") + + +@unittest.mock.patch("bot.cogs.info.information.constants") +class UserCommandTests(unittest.TestCase): + """Tests for the `!user` command.""" + + def setUp(self): + """Set up steps executed before each test is run.""" + self.bot = helpers.MockBot() + self.cog = information.Information(self.bot) + + self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) + self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) + self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) + + self.author = helpers.MockMember(id=1, name="syntaxaire") + self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) + self.target = helpers.MockMember(id=3, name="__fluzz__") + + def test_regular_member_cannot_target_another_member(self, constants): + """A regular user should not be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.author) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") + + def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): + """A regular user should not be able to use this command outside of bot-commands.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) + + msg = "Sorry, but you may only use this command within <#50>." + with self.assertRaises(InWhitelistCheckFailure, msg=msg): + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): + """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): + """A user should target itself with `!user` when a `user` argument was not provided.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): + """Staff members should be able to bypass the bot-commands channel restriction.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.moderator) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") + def test_moderators_can_target_another_member(self, create_embed, constants): + """A moderator should be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + create_embed.assert_called_once_with(ctx, self.target) + ctx.send.assert_called_once() diff --git a/tests/bot/cogs/moderation/infraction/__init__.py b/tests/bot/cogs/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/moderation/infraction/test_infractions.py b/tests/bot/cogs/moderation/infraction/test_infractions.py new file mode 100644 index 000000000..a79042557 --- /dev/null +++ b/tests/bot/cogs/moderation/infraction/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infraction.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason truncation.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Infractions(self.bot) + self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) + self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) + self.guild = MockGuild(id=4567) + self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + + @patch("bot.cogs.moderation.infraction.utils.get_active_infraction") + @patch("bot.cogs.moderation.infraction.utils.post_infraction") + async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): + """Should truncate reason for `ctx.guild.ban`.""" + get_active_mock.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.bot.get_cog.return_value = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.ctx.guild.ban = Mock() + + await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) + self.ctx.guild.ban.assert_called_once_with( + self.target, + reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), + delete_message_days=0 + ) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value + ) + + @patch("bot.cogs.moderation.infraction.utils.post_infraction") + async def test_apply_kick_reason_truncation(self, post_infraction_mock): + """Should truncate reason for `Member.kick`.""" + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.target.kick = Mock() + + await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) + self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + ) diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 435a1cd51..5e4d90251 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, call, patch import aiohttp import discord -from bot.cogs.moderation import Incidents, incidents +from bot.cogs.moderation import incidents from bot.constants import Colours from tests.helpers import ( MockAsyncWebhook, @@ -290,7 +290,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): Note that this will not schedule `crawl_incidents` in the background, as everything is being mocked. The `crawl_task` attribute will end up being None. """ - self.cog_instance = Incidents(MockBot()) + self.cog_instance = incidents.Incidents(MockBot()) @patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py deleted file mode 100644 index df38090fb..000000000 --- a/tests/bot/cogs/moderation/test_infractions.py +++ /dev/null @@ -1,55 +0,0 @@ -import textwrap -import unittest -from unittest.mock import AsyncMock, Mock, patch - -from bot.cogs.moderation.infraction.infractions import Infractions -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole - - -class TruncationTests(unittest.IsolatedAsyncioTestCase): - """Tests for ban and kick command reason truncation.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Infractions(self.bot) - self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) - self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) - self.guild = MockGuild(id=4567) - self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - - @patch("bot.cogs.moderation.utils.get_active_infraction") - @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): - """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = None - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.bot.get_cog.return_value = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.ctx.guild.ban = Mock() - - await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) - self.ctx.guild.ban.assert_called_once_with( - self.target, - reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), - delete_message_days=0 - ) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value - ) - - @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_kick_reason_truncation(self, post_infraction_mock): - """Should truncate reason for `Member.kick`.""" - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.target.kick = Mock() - - await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) - self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value - ) diff --git a/tests/bot/cogs/moderation/test_slowmode.py b/tests/bot/cogs/moderation/test_slowmode.py new file mode 100644 index 000000000..f442814c8 --- /dev/null +++ b/tests/bot/cogs/moderation/test_slowmode.py @@ -0,0 +1,111 @@ +import unittest +from unittest import mock + +from dateutil.relativedelta import relativedelta + +from bot.cogs.moderation.slowmode import Slowmode +from bot.constants import Emojis +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SlowmodeTests(unittest.IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Slowmode(self.bot) + self.ctx = MockContext() + + async def test_get_slowmode_no_channel(self) -> None: + """Get slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) + + await self.cog.get_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") + + async def test_get_slowmode_with_channel(self) -> None: + """Get slowmode with a given channel.""" + text_channel = MockTextChannel(name='python-language', slowmode_delay=2) + + await self.cog.get_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + + async def test_set_slowmode_no_channel(self) -> None: + """Set slowmode without a given channel.""" + test_cases = ( + ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), + ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), + ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + self.ctx.channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + + if edited: + self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + self.ctx.channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_set_slowmode_with_channel(self) -> None: + """Set slowmode with a given channel.""" + test_cases = ( + ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), + ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), + ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + text_channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + + if edited: + text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + text_channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_reset_slowmode_no_channel(self) -> None: + """Reset slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) + + await self.cog.reset_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' + ) + + async def test_reset_slowmode_with_channel(self) -> None: + """Reset slowmode with a given channel.""" + text_channel = MockTextChannel(name='meta', slowmode_delay=1) + + await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' + ) + + @mock.patch("bot.cogs.moderation.slowmode.with_role_check") + @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/sync/__init__.py b/tests/bot/cogs/sync/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py deleted file mode 100644 index 84d036405..000000000 --- a/tests/bot/cogs/sync/test_base.py +++ /dev/null @@ -1,404 +0,0 @@ -import asyncio -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend.sync import Syncer, _Diff -from tests import helpers - - -class TestSyncer(Syncer): - """Syncer subclass with mocks for abstract methods for testing purposes.""" - - name = "test" - _get_diff = mock.AsyncMock() - _sync = mock.AsyncMock() - - -class SyncerBaseTests(unittest.TestCase): - """Tests for the syncer base class.""" - - def setUp(self): - self.bot = helpers.MockBot() - - def test_instantiation_fails_without_abstract_methods(self): - """The class must have abstract methods implemented.""" - with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): - Syncer(self.bot) - - -class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): - """Tests for sending the sync confirmation prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - - def mock_get_channel(self): - """Fixture to return a mock channel and message for when `get_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - mock_channel.send.return_value = mock_message - self.bot.get_channel.return_value = mock_channel - - return mock_channel, mock_message - - def mock_fetch_channel(self): - """Fixture to return a mock channel and message for when `fetch_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - self.bot.get_channel.return_value = None - mock_channel.send.return_value = mock_message - self.bot.fetch_channel.return_value = mock_channel - - return mock_channel, mock_message - - async def test_send_prompt_edits_and_returns_message(self): - """The given message should be edited to display the prompt and then should be returned.""" - msg = helpers.MockMessage() - ret_val = await self.syncer._send_prompt(msg) - - msg.edit.assert_called_once() - self.assertIn("content", msg.edit.call_args[1]) - self.assertEqual(ret_val, msg) - - async def test_send_prompt_gets_dev_core_channel(self): - """The dev-core channel should be retrieved if an extant message isn't given.""" - subtests = ( - (self.bot.get_channel, self.mock_get_channel), - (self.bot.fetch_channel, self.mock_fetch_channel), - ) - - for method, mock_ in subtests: - with self.subTest(method=method, msg=mock_.__name__): - mock_() - await self.syncer._send_prompt() - - method.assert_called_once_with(constants.Channels.dev_core) - - async def test_send_prompt_returns_none_if_channel_fetch_fails(self): - """None should be returned if there's an HTTPException when fetching the channel.""" - self.bot.get_channel.return_value = None - self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") - - ret_val = await self.syncer._send_prompt() - - self.assertIsNone(ret_val) - - async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): - """A new message mentioning core devs should be sent and returned if message isn't given.""" - for mock_ in (self.mock_get_channel, self.mock_fetch_channel): - with self.subTest(msg=mock_.__name__): - mock_channel, mock_message = mock_() - ret_val = await self.syncer._send_prompt() - - mock_channel.send.assert_called_once() - self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) - self.assertEqual(ret_val, mock_message) - - async def test_send_prompt_adds_reactions(self): - """The message should have reactions for confirmation added.""" - extant_message = helpers.MockMessage() - subtests = ( - (extant_message, lambda: (None, extant_message)), - (None, self.mock_get_channel), - (None, self.mock_fetch_channel), - ) - - for message_arg, mock_ in subtests: - subtest_msg = "Extant message" if mock_.__name__ == "" else mock_.__name__ - - with self.subTest(msg=subtest_msg): - _, mock_message = mock_() - await self.syncer._send_prompt(message_arg) - - calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] - mock_message.add_reaction.assert_has_calls(calls) - - -class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): - """Tests for waiting for a sync confirmation reaction on the prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) - - @staticmethod - def get_message_reaction(emoji): - """Fixture to return a mock message an reaction from the given `emoji`.""" - message = helpers.MockMessage() - reaction = helpers.MockReaction(emoji=emoji, message=message) - - return message, reaction - - def test_reaction_check_for_valid_emoji_and_authors(self): - """Should return True if authors are identical or are a bot and a core dev, respectively.""" - user_subtests = ( - ( - helpers.MockMember(id=77), - helpers.MockMember(id=77), - "identical users", - ), - ( - helpers.MockMember(id=77, bot=True), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "bot author and core-dev reactor", - ), - ) - - for emoji in self.syncer._REACTION_EMOJIS: - for author, user, msg in user_subtests: - with self.subTest(author=author, user=user, emoji=emoji, msg=msg): - message, reaction = self.get_message_reaction(emoji) - ret_val = self.syncer._reaction_check(author, message, reaction, user) - - self.assertTrue(ret_val) - - def test_reaction_check_for_invalid_reactions(self): - """Should return False for invalid reaction events.""" - valid_emoji = self.syncer._REACTION_EMOJIS[0] - subtests = ( - ( - helpers.MockMember(id=77), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "users are not identical", - ), - ( - helpers.MockMember(id=77, bot=True), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43), - "reactor lacks the core-dev role", - ), - ( - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - "reactor is a bot", - ), - ( - helpers.MockMember(id=77), - helpers.MockMessage(id=95), - helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), - helpers.MockMember(id=77), - "messages are not identical", - ), - ( - helpers.MockMember(id=77), - *self.get_message_reaction("InVaLiD"), - helpers.MockMember(id=77), - "emoji is invalid", - ), - ) - - for *args, msg in subtests: - kwargs = dict(zip(("author", "message", "reaction", "user"), args)) - with self.subTest(**kwargs, msg=msg): - ret_val = self.syncer._reaction_check(*args) - self.assertFalse(ret_val) - - async def test_wait_for_confirmation(self): - """The message should always be edited and only return True if the emoji is a check mark.""" - subtests = ( - (constants.Emojis.check_mark, True, None), - ("InVaLiD", False, None), - (None, False, asyncio.TimeoutError), - ) - - for emoji, ret_val, side_effect in subtests: - for bot in (True, False): - with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): - # Set up mocks - message = helpers.MockMessage() - member = helpers.MockMember(bot=bot) - - self.bot.wait_for.reset_mock() - self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) - self.bot.wait_for.side_effect = side_effect - - # Call the function - actual_return = await self.syncer._wait_for_confirmation(member, message) - - # Perform assertions - self.bot.wait_for.assert_called_once() - self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) - - message.edit.assert_called_once() - kwargs = message.edit.call_args[1] - self.assertIn("content", kwargs) - - # Core devs should only be mentioned if the author is a bot. - if bot: - self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - else: - self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - - self.assertIs(actual_return, ret_val) - - -class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for main function orchestrating the sync.""" - - def setUp(self): - self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) - self.syncer = TestSyncer(self.bot) - - async def test_sync_respects_confirmation_result(self): - """The sync should abort if confirmation fails and continue if confirmed.""" - mock_message = helpers.MockMessage() - subtests = ( - (True, mock_message), - (False, None), - ) - - for confirmed, message in subtests: - with self.subTest(confirmed=confirmed): - self.syncer._sync.reset_mock() - self.syncer._get_diff.reset_mock() - - diff = _Diff({1, 2, 3}, {4, 5}, None) - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(confirmed, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - - if confirmed: - self.syncer._sync.assert_called_once_with(diff) - else: - self.syncer._sync.assert_not_called() - - async def test_sync_diff_size(self): - """The diff size should be correctly calculated.""" - subtests = ( - (6, _Diff({1, 2}, {3, 4}, {5, 6})), - (5, _Diff({1, 2, 3}, None, {4, 5})), - (0, _Diff(None, None, None)), - (0, _Diff(set(), set(), set())), - ) - - for size, diff in subtests: - with self.subTest(size=size, diff=diff): - self.syncer._get_diff.reset_mock() - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) - - async def test_sync_message_edited(self): - """The message should be edited if one was sent, even if the sync has an API error.""" - subtests = ( - (None, None, False), - (helpers.MockMessage(), None, True), - (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), - ) - - for message, side_effect, should_edit in subtests: - with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): - self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(True, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - if should_edit: - message.edit.assert_called_once() - self.assertIn("content", message.edit.call_args[1]) - - async def test_sync_confirmation_context_redirect(self): - """If ctx is given, a new message should be sent and author should be ctx's author.""" - mock_member = helpers.MockMember() - subtests = ( - (None, self.bot.user, None), - (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), - ) - - for ctx, author, message in subtests: - with self.subTest(ctx=ctx, author=author, message=message): - if ctx is not None: - ctx.send.return_value = message - - # Make sure `_get_diff` returns a MagicMock, not an AsyncMock - self.syncer._get_diff.return_value = mock.MagicMock() - - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild, ctx) - - if ctx is not None: - ctx.send.assert_called_once() - - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_small_diff(self): - """Should always return True and the given message if the diff size is too small.""" - author = helpers.MockMember() - expected_message = helpers.MockMessage() - - for size in (3, 2): # pragma: no cover - with self.subTest(size=size): - self.syncer._send_prompt = mock.AsyncMock() - self.syncer._wait_for_confirmation = mock.AsyncMock() - - coro = self.syncer._get_confirmation_result(size, author, expected_message) - result, actual_message = await coro - - self.assertTrue(result) - self.assertEqual(actual_message, expected_message) - self.syncer._send_prompt.assert_not_called() - self.syncer._wait_for_confirmation.assert_not_called() - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_large_diff(self): - """Should return True if confirmed and False if _send_prompt fails or aborted.""" - author = helpers.MockMember() - mock_message = helpers.MockMessage() - - subtests = ( - (True, mock_message, True, "confirmed"), - (False, None, False, "_send_prompt failed"), - (False, mock_message, False, "aborted"), - ) - - for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover - with self.subTest(msg=msg): - self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) - - coro = self.syncer._get_confirmation_result(4, author) - actual_result, actual_message = await coro - - self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None - self.assertIs(actual_result, expected_result) - self.assertEqual(actual_message, expected_message) - - if expected_message: - self.syncer._wait_for_confirmation.assert_called_once_with( - author, expected_message - ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py deleted file mode 100644 index ea7d090ba..000000000 --- a/tests/bot/cogs/sync/test_cog.py +++ /dev/null @@ -1,415 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend import sync -from bot.cogs.backend.sync import Syncer -from tests import helpers -from tests.base import CommandTestCase - - -class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): - """Tests for the sync extension.""" - - @staticmethod - def test_extension_setup(): - """The Sync cog should be added.""" - bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() - - -class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): - """Base class for Sync cog tests. Sets up patches for syncers.""" - - def setUp(self): - self.bot = helpers.MockBot() - - self.role_syncer_patcher = mock.patch( - "bot.cogs.sync.syncers.RoleSyncer", - autospec=Syncer, - spec_set=True - ) - self.user_syncer_patcher = mock.patch( - "bot.cogs.sync.syncers.UserSyncer", - autospec=Syncer, - spec_set=True - ) - self.RoleSyncer = self.role_syncer_patcher.start() - self.UserSyncer = self.user_syncer_patcher.start() - - self.cog = sync.Sync(self.bot) - - def tearDown(self): - self.role_syncer_patcher.stop() - self.user_syncer_patcher.stop() - - @staticmethod - def response_error(status: int) -> ResponseCodeError: - """Fixture to return a ResponseCodeError with the given status code.""" - response = mock.MagicMock() - response.status = status - - return ResponseCodeError(response) - - -class SyncCogTests(SyncCogTestCase): - """Tests for the Sync cog.""" - - @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. - self.RoleSyncer.reset_mock() - self.UserSyncer.reset_mock() - self.bot.loop.create_task = mock.MagicMock() - - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - sync.Sync(self.bot) - - self.RoleSyncer.assert_called_once_with(self.bot) - self.UserSyncer.assert_called_once_with(self.bot) - sync_guild.assert_called_once_with() - self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) - - async def test_sync_cog_sync_guild(self): - """Roles and users should be synced only if a guild is successfully retrieved.""" - for guild in (helpers.MockGuild(), None): - with self.subTest(guild=guild): - self.bot.reset_mock() - self.cog.role_syncer.reset_mock() - self.cog.user_syncer.reset_mock() - - self.bot.get_guild = mock.MagicMock(return_value=guild) - - await self.cog.sync_guild() - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.get_guild.assert_called_once_with(constants.Guild.id) - - if guild is None: - self.cog.role_syncer.sync.assert_not_called() - self.cog.user_syncer.sync.assert_not_called() - else: - self.cog.role_syncer.sync.assert_called_once_with(guild) - self.cog.user_syncer.sync.assert_called_once_with(guild) - - async def patch_user_helper(self, side_effect: BaseException) -> None: - """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" - self.bot.api_client.patch.reset_mock(side_effect=True) - self.bot.api_client.patch.side_effect = side_effect - - user_id, updated_information = 5, {"key": 123} - await self.cog.patch_user(user_id, updated_information) - - self.bot.api_client.patch.assert_called_once_with( - f"bot/users/{user_id}", - json=updated_information, - ) - - async def test_sync_cog_patch_user(self): - """A PATCH request should be sent and 404 errors ignored.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - await self.patch_user_helper(side_effect) - - async def test_sync_cog_patch_user_non_404(self): - """A PATCH request should be sent and the error raised if it's not a 404.""" - with self.assertRaises(ResponseCodeError): - await self.patch_user_helper(self.response_error(500)) - - -class SyncCogListenerTests(SyncCogTestCase): - """Tests for the listeners of the Sync cog.""" - - def setUp(self): - super().setUp() - self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - - self.guild_id_patcher = mock.patch("bot.cogs.sync.cog.constants.Guild.id", 5) - self.guild_id = self.guild_id_patcher.start() - - self.guild = helpers.MockGuild(id=self.guild_id) - self.other_guild = helpers.MockGuild(id=0) - - def tearDown(self): - self.guild_id_patcher.stop() - - async def test_sync_cog_on_guild_role_create(self): - """A POST request should be sent with the new role's data.""" - self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - role = helpers.MockRole(**role_data, guild=self.guild) - await self.cog.on_guild_role_create(role) - - self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) - - async def test_sync_cog_on_guild_role_create_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_create(role) - self.bot.api_client.post.assert_not_awaited() - - async def test_sync_cog_on_guild_role_delete(self): - """A DELETE request should be sent.""" - self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - - role = helpers.MockRole(id=99, guild=self.guild) - await self.cog.on_guild_role_delete(role) - - self.bot.api_client.delete.assert_called_once_with("bot/roles/99") - - async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_delete(role) - self.bot.api_client.delete.assert_not_awaited() - - async def test_sync_cog_on_guild_role_update(self): - """A PUT request should be sent if the colour, name, permissions, or position changes.""" - self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - subtests = ( - (True, ("colour", "name", "permissions", "position")), - (False, ("hoist", "mentionable")), - ) - - for should_put, attributes in subtests: - for attribute in attributes: - with self.subTest(should_put=should_put, changed_attribute=attribute): - self.bot.api_client.put.reset_mock() - - after_role_data = role_data.copy() - after_role_data[attribute] = 876 - - before_role = helpers.MockRole(**role_data, guild=self.guild) - after_role = helpers.MockRole(**after_role_data, guild=self.guild) - - await self.cog.on_guild_role_update(before_role, after_role) - - if should_put: - self.bot.api_client.put.assert_called_once_with( - f"bot/roles/{after_role.id}", - json=after_role_data - ) - else: - self.bot.api_client.put.assert_not_called() - - async def test_sync_cog_on_guild_role_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_update(role, role) - self.bot.api_client.put.assert_not_awaited() - - async def test_sync_cog_on_member_remove(self): - """Member should be patched to set in_guild as False.""" - self.assertTrue(self.cog.on_member_remove.__cog_listener__) - - member = helpers.MockMember(guild=self.guild) - await self.cog.on_member_remove(member) - - self.cog.patch_user.assert_called_once_with( - member.id, - json={"in_guild": False} - ) - - async def test_sync_cog_on_member_remove_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_remove(member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_member_update_roles(self): - """Members should be patched if their roles have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - # Roles are intentionally unsorted. - before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles, guild=self.guild) - after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - data = {"roles": sorted(role.id for role in after_member.roles)} - self.cog.patch_user.assert_called_once_with(after_member.id, json=data) - - async def test_sync_cog_on_member_update_other(self): - """Members should not be patched if other attributes have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - subtests = ( - ("activities", discord.Game("Pong"), discord.Game("Frogger")), - ("nick", "old nick", "new nick"), - ("status", discord.Status.online, discord.Status.offline), - ) - - for attribute, old_value, new_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) - after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - self.cog.patch_user.assert_not_called() - - async def test_sync_cog_on_member_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_update(member, member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_user_update(self): - """A user should be patched only if the name, discriminator, or avatar changes.""" - self.assertTrue(self.cog.on_user_update.__cog_listener__) - - before_data = { - "name": "old name", - "discriminator": "1234", - "bot": False, - } - - subtests = ( - (True, "name", "name", "new name", "new name"), - (True, "discriminator", "discriminator", "8765", 8765), - (False, "bot", "bot", True, True), - ) - - for should_patch, attribute, api_field, value, api_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - after_data = before_data.copy() - after_data[attribute] = value - before_user = helpers.MockUser(**before_data) - after_user = helpers.MockUser(**after_data) - - await self.cog.on_user_update(before_user, after_user) - - if should_patch: - self.cog.patch_user.assert_called_once() - - # Don't care if *all* keys are present; only the changed one is required - call_args = self.cog.patch_user.call_args - self.assertEqual(call_args.args[0], after_user.id) - self.assertIn("json", call_args.kwargs) - - self.assertIn("ignore_404", call_args.kwargs) - self.assertTrue(call_args.kwargs["ignore_404"]) - - json = call_args.kwargs["json"] - self.assertIn(api_field, json) - self.assertEqual(json[api_field], api_value) - else: - self.cog.patch_user.assert_not_called() - - async def on_member_join_helper(self, side_effect: Exception) -> dict: - """ - Helper to set `side_effect` for on_member_join and assert a PUT request was sent. - - The request data for the mock member is returned. All exceptions will be re-raised. - """ - member = helpers.MockMember( - discriminator="1234", - roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], - guild=self.guild, - ) - - data = { - "discriminator": int(member.discriminator), - "id": member.id, - "in_guild": True, - "name": member.name, - "roles": sorted(role.id for role in member.roles) - } - - self.bot.api_client.put.reset_mock(side_effect=True) - self.bot.api_client.put.side_effect = side_effect - - try: - await self.cog.on_member_join(member) - except Exception: - raise - finally: - self.bot.api_client.put.assert_called_once_with( - f"bot/users/{member.id}", - json=data - ) - - return data - - async def test_sync_cog_on_member_join(self): - """Should PUT user's data or POST it if the user doesn't exist.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - self.bot.api_client.post.reset_mock() - data = await self.on_member_join_helper(side_effect) - - if side_effect: - self.bot.api_client.post.assert_called_once_with("bot/users", json=data) - else: - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_non_404(self): - """ResponseCodeError should be re-raised if status code isn't a 404.""" - with self.assertRaises(ResponseCodeError): - await self.on_member_join_helper(self.response_error(500)) - - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_join(member) - self.bot.api_client.post.assert_not_awaited() - self.bot.api_client.put.assert_not_awaited() - - -class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): - """Tests for the commands in the Sync cog.""" - - async def test_sync_roles_command(self): - """sync() should be called on the RoleSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_roles_command.callback(self.cog, ctx) - - self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_sync_users_command(self): - """sync() should be called on the UserSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_users_command.callback(self.cog, ctx) - - self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_commands_require_admin(self): - """The sync commands should only run if the author has the administrator permission.""" - cmds = ( - self.cog.sync_group, - self.cog.sync_roles_command, - self.cog.sync_users_command, - ) - - for cmd in cmds: - with self.subTest(cmd=cmd): - await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py deleted file mode 100644 index 888c49ca8..000000000 --- a/tests/bot/cogs/sync/test_roles.py +++ /dev/null @@ -1,157 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot.cogs.backend.sync import RoleSyncer, _Diff, _Role -from tests import helpers - - -def fake_role(**kwargs): - """Fixture to return a dictionary representing a role with default values set.""" - kwargs.setdefault("id", 9) - kwargs.setdefault("name", "fake role") - kwargs.setdefault("colour", 7) - kwargs.setdefault("permissions", 0) - kwargs.setdefault("position", 55) - - return kwargs - - -class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between roles in the DB and roles in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - @staticmethod - def get_guild(*roles): - """Fixture to return a guild object with the given roles.""" - guild = helpers.MockGuild() - guild.roles = [] - - for role in roles: - mock_role = helpers.MockRole(**role) - mock_role.colour = discord.Colour(role["colour"]) - mock_role.permissions = discord.Permissions(role["permissions"]) - guild.roles.append(mock_role) - - return guild - - async def test_empty_diff_for_identical_roles(self): - """No differences should be found if the roles in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_roles(self): - """Only updated roles should be added to the 'updated' set of the diff.""" - updated_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] - guild = self.get_guild(updated_role, fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_Role(**updated_role)}, set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_roles(self): - """Only new roles should be added to the 'created' set of the diff.""" - new_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role(), new_role) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new_role)}, set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_deleted_roles(self): - """Only deleted roles should be added to the 'deleted' set of the diff.""" - deleted_role = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [fake_role(), deleted_role] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), {_Role(**deleted_role)}) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_deleted_roles(self): - """When roles are added, updated, and removed, all of them are returned properly.""" - new = fake_role(id=41, name="new") - updated = fake_role(id=71, name="updated") - deleted = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [ - fake_role(), - fake_role(id=71, name="updated name"), - deleted, - ] - guild = self.get_guild(fake_role(), new, updated) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) - - self.assertEqual(actual_diff, expected_diff) - - -class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync roles.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - async def test_sync_created_roles(self): - """Only POST requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(role_tuples, set(), set()) - await self.syncer._sync(diff) - - calls = [mock.call("bot/roles", json=role) for role in roles] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(roles)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_roles(self): - """Only PUT requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), role_tuples, set()) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_deleted_roles(self): - """Only DELETE requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), set(), role_tuples) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] - self.bot.api_client.delete.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py deleted file mode 100644 index 71f4b134c..000000000 --- a/tests/bot/cogs/sync/test_users.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest -from unittest import mock - -from bot.cogs.backend.sync import UserSyncer, _Diff, _User -from tests import helpers - - -def fake_user(**kwargs): - """Fixture to return a dictionary representing a user with default values set.""" - kwargs.setdefault("id", 43) - kwargs.setdefault("name", "bob the test man") - kwargs.setdefault("discriminator", 1337) - kwargs.setdefault("roles", (666,)) - kwargs.setdefault("in_guild", True) - - return kwargs - - -class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between users in the DB and users in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - @staticmethod - def get_guild(*members): - """Fixture to return a guild object with the given members.""" - guild = helpers.MockGuild() - guild.members = [] - - for member in members: - member = member.copy() - del member["in_guild"] - - mock_member = helpers.MockMember(**member) - mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] - - guild.members.append(mock_member) - - return guild - - async def test_empty_diff_for_no_users(self): - """When no users are given, an empty diff should be returned.""" - guild = self.get_guild() - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_identical_users(self): - """No differences should be found if the users in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_users(self): - """Only updated users should be added to the 'updated' set of the diff.""" - updated_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] - guild = self.get_guild(updated_user, fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**updated_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_users(self): - """Only new users should be added to the 'created' set of the diff.""" - new_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user(), new_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_sets_in_guild_false_for_leaving_users(self): - """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_leaving_users(self): - """When users are added, updated, and removed, all of them are returned properly.""" - new_user = fake_user(id=99, name="new") - updated_user = fake_user(id=55, name="updated") - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] - guild = self.get_guild(fake_user(), new_user, updated_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_db_users_not_in_guild(self): - """When the DB knows a user the guild doesn't, no difference is found.""" - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - -class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync users.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - async def test_sync_created_users(self): - """Only POST requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(user_tuples, set(), None) - await self.syncer._sync(diff) - - calls = [mock.call("bot/users", json=user) for user in users] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(users)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_users(self): - """Only PUT requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(set(), user_tuples, None) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(users)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py deleted file mode 100644 index b00211f47..000000000 --- a/tests/bot/cogs/test_antimalware.py +++ /dev/null @@ -1,165 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.cogs.filters import antimalware -from bot.constants import Channels, STAFF_ROLES -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): - """Test the AntiMalware cog.""" - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = MockBot() - self.bot.filter_list_cache = { - "FILE_FORMAT.True": { - ".first": {}, - ".second": {}, - ".third": {}, - } - } - self.cog = antimalware.AntiMalware(self.bot) - self.message = MockMessage() - self.whitelist = [".first", ".second", ".third"] - - async def test_message_with_allowed_attachment(self): - """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_message_without_attachment(self): - """Messages without attachments should result in no action.""" - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_direct_message_with_attachment(self): - """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.guild = None - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_message_with_illegal_extension_gets_deleted(self): - """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_called_once() - - async def test_message_send_by_staff(self): - """A message send by a member of staff should be ignored.""" - staff_role = MockRole(id=STAFF_ROLES[0]) - self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_python_file_redirect_embed_description(self): - """A message containing a .py file should result in an embed redirecting the user to our paste site""" - attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - - self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - - async def test_txt_file_redirect_embed_description(self): - """A message containing a .txt file should result in the correct embed.""" - attachment = MockAttachment(filename="python.txt") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.TXT_EMBED_DESCRIPTION = Mock() - antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - cmd_channel = self.bot.get_channel(Channels.bot_commands) - - self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) - antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - - async def test_other_disallowed_extension_embed_description(self): - """Test the description for a non .py/.txt disallowed extension.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - meta_channel = self.bot.get_channel(Channels.meta) - - self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( - joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", - meta_channel_mention=meta_channel.mention - ) - - async def test_removing_deleted_message_logs(self): - """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - self.message.delete.assert_called_once() - - async def test_message_with_illegal_attachment_logs(self): - """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - - async def test_get_disallowed_extensions(self): - """The return value should include all non-whitelisted extensions.""" - test_values = ( - ([], []), - (self.whitelist, []), - ([".first"], []), - ([".first", ".disallowed"], [".disallowed"]), - ([".disallowed"], [".disallowed"]), - ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), - ) - - for extensions, expected_disallowed_extensions in test_values: - with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog._get_disallowed_extensions(self.message) - self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.TestCase): - """Tests setup of the `AntiMalware` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py deleted file mode 100644 index 8a3d8d02e..000000000 --- a/tests/bot/cogs/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.cogs.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): - """Tests validation of the antispam cog configuration.""" - - def test_default_antispam_config_is_valid(self): - """The default antispam configuration is valid.""" - validation_errors = antispam.validate_config() - self.assertEqual(validation_errors, {}) - - def test_unknown_rule_returns_error(self): - """Configuring an unknown rule returns an error.""" - self.assertEqual( - antispam.validate_config({'invalid-rule': {}}), - {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} - ) - - def test_missing_keys_returns_error(self): - """Not configuring required keys returns an error.""" - keys = (('interval', 'max'), ('max', 'interval')) - for configured_key, unconfigured_key in keys: - with self.subTest( - configured_key=configured_key, - unconfigured_key=unconfigured_key - ): - config = {'burst': {configured_key: 10}} - error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - - self.assertEqual( - antispam.validate_config(config), - {'burst': error} - ) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py deleted file mode 100644 index 305a2bad9..000000000 --- a/tests/bot/cogs/test_information.py +++ /dev/null @@ -1,584 +0,0 @@ -import asyncio -import textwrap -import unittest -import unittest.mock - -import discord - -from bot import constants -from bot.cogs.info import information -from bot.utils.checks import InWhitelistCheckFailure -from tests import helpers - -COG_PATH = "bot.cogs.information.Information" - - -class InformationCogTests(unittest.TestCase): - """Tests the Information cog.""" - - @classmethod - def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = helpers.MockBot() - - self.cog = information.Information(self.bot) - - self.ctx = helpers.MockContext() - self.ctx.author.roles.append(self.moderator_role) - - def test_roles_command_command(self): - """Test if the `role_info` command correctly returns the `moderator_role`.""" - self.ctx.guild.roles.append(self.moderator_role) - - self.cog.roles_info.can_run = unittest.mock.AsyncMock() - self.cog.roles_info.can_run.return_value = True - - coroutine = self.cog.roles_info.callback(self.cog, self.ctx) - - self.assertIsNone(asyncio.run(coroutine)) - self.ctx.send.assert_called_once() - - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - - self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - - def test_role_info_command(self): - """Tests the `role info` command.""" - dummy_role = helpers.MockRole( - name="Dummy", - id=112233445566778899, - colour=discord.Colour.blurple(), - position=10, - members=[self.ctx.author], - permissions=discord.Permissions(0) - ) - - admin_role = helpers.MockRole( - name="Admins", - id=998877665544332211, - colour=discord.Colour.red(), - position=3, - members=[self.ctx.author], - permissions=discord.Permissions(0), - ) - - self.ctx.guild.roles.append([dummy_role, admin_role]) - - self.cog.role_info.can_run = unittest.mock.AsyncMock() - self.cog.role_info.can_run.return_value = True - - coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) - - self.assertIsNone(asyncio.run(coroutine)) - - self.assertEqual(self.ctx.send.call_count, 2) - - (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list - - dummy_embed = dummy_kwargs["embed"] - admin_embed = admin_kwargs["embed"] - - self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) - - self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) - self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") - self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") - self.assertEqual(dummy_embed.fields[3].value, "1") - self.assertEqual(dummy_embed.fields[4].value, "10") - self.assertEqual(dummy_embed.fields[5].value, "0") - - self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, discord.Colour.red()) - - @unittest.mock.patch('bot.cogs.information.time_since') - def test_server_info_command(self, time_since_patch): - time_since_patch.return_value = '2 days ago' - - self.ctx.guild = helpers.MockGuild( - features=('lemons', 'apples'), - region="The Moon", - roles=[self.moderator_role], - channels=[ - discord.TextChannel( - state={}, - guild=self.ctx.guild, - data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} - ), - discord.CategoryChannel( - state={}, - guild=self.ctx.guild, - data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} - ), - discord.VoiceChannel( - state={}, - guild=self.ctx.guild, - data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} - ) - ], - members=[ - *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), - *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), - *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), - *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), - ], - member_count=1_234, - icon_url='a-lemon.jpg', - ) - - coroutine = self.cog.server_info.callback(self.cog, self.ctx) - self.assertIsNone(asyncio.run(coroutine)) - - time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual( - embed.description, - textwrap.dedent( - f""" - **Server information** - Created: {time_since_patch.return_value} - Voice region: {self.ctx.guild.region} - Features: {', '.join(self.ctx.guild.features)} - - **Channel counts** - Category channels: 1 - Text channels: 1 - Voice channels: 1 - Staff channels: 0 - - **Member counts** - Members: {self.ctx.guild.member_count:,} - Staff members: 0 - Roles: {len(self.ctx.guild.roles)} - - **Member statuses** - {constants.Emojis.status_online} 2 - {constants.Emojis.status_idle} 1 - {constants.Emojis.status_dnd} 4 - {constants.Emojis.status_offline} 3 - """ - ) - ) - self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') - - -class UserInfractionHelperMethodTests(unittest.TestCase): - """Tests for the helper methods of the `!user` command.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - self.member = helpers.MockMember(id=1234) - - def test_user_command_helper_method_get_requests(self): - """The helper methods should form the correct get requests.""" - test_values = ( - { - "helper_method": self.cog.basic_user_infraction_counts, - "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.expanded_user_infraction_counts, - "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.user_nomination_counts, - "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), - }, - ) - - for test_value in test_values: - helper_method = test_value["helper_method"] - endpoint, params = test_value["expected_args"] - - with self.subTest(method=helper_method, endpoint=endpoint, params=params): - asyncio.run(helper_method(self.member)) - self.bot.api_client.get.assert_called_once_with(endpoint, params=params) - self.bot.api_client.get.reset_mock() - - def _method_subtests(self, method, test_values, default_header): - """Helper method that runs the subtests for the different helper methods.""" - for test_value in test_values: - api_response = test_value["api response"] - expected_lines = test_value["expected_lines"] - - with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): - self.bot.api_client.get.return_value = api_response - - expected_output = "\n".join(default_header + expected_lines) - actual_output = asyncio.run(method(self.member)) - - self.assertEqual(expected_output, actual_output) - - def test_basic_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list both the total and active number of non-hidden infractions.""" - test_values = ( - # No infractions means zero counts - { - "api response": [], - "expected_lines": ["Total: 0", "Active: 0"], - }, - # Simple, single-infraction dictionaries - { - "api response": [{"type": "ban", "active": True}], - "expected_lines": ["Total: 1", "Active: 1"], - }, - { - "api response": [{"type": "ban", "active": False}], - "expected_lines": ["Total: 1", "Active: 0"], - }, - # Multiple infractions with various `active` status - { - "api response": [ - {"type": "ban", "active": True}, - {"type": "kick", "active": False}, - {"type": "ban", "active": True}, - {"type": "ban", "active": False}, - ], - "expected_lines": ["Total: 4", "Active: 2"], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) - - def test_expanded_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list the total and active number of all infractions split by infraction type.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never received an infraction."], - }, - # Shows non-hidden inactive infraction as expected - { - "api response": [{"type": "kick", "active": False, "hidden": False}], - "expected_lines": ["Kicks: 1"], - }, - # Shows non-hidden active infraction as expected - { - "api response": [{"type": "mute", "active": True, "hidden": False}], - "expected_lines": ["Mutes: 1 (1 active)"], - }, - # Shows hidden inactive infraction as expected - { - "api response": [{"type": "superstar", "active": False, "hidden": True}], - "expected_lines": ["Superstars: 1"], - }, - # Shows hidden active infraction as expected - { - "api response": [{"type": "ban", "active": True, "hidden": True}], - "expected_lines": ["Bans: 1 (1 active)"], - }, - # Correctly displays tally of multiple infractions of mixed properties in alphabetical order - { - "api response": [ - {"type": "kick", "active": False, "hidden": True}, - {"type": "ban", "active": True, "hidden": True}, - {"type": "superstar", "active": True, "hidden": True}, - {"type": "mute", "active": True, "hidden": True}, - {"type": "ban", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - {"type": "note", "active": False, "hidden": True}, - {"type": "warn", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - ], - "expected_lines": [ - "Bans: 2 (1 active)", - "Kicks: 1", - "Mutes: 1 (1 active)", - "Notes: 3", - "Superstars: 1 (1 active)", - "Warns: 1", - ], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) - - def test_user_nomination_counts_returns_correct_strings(self): - """The method should list the number of active and historical nominations for the user.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never been nominated."], - }, - { - "api response": [{'active': True}], - "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], - }, - { - "api response": [{'active': True}, {'active': False}], - "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], - }, - { - "api response": [{'active': False}], - "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], - }, - { - "api response": [{'active': False}, {'active': False}], - "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], - }, - - ) - - header = ["**Nominations**"] - - self._method_subtests(self.cog.user_nomination_counts, test_values, header) - - -@unittest.mock.patch("bot.cogs.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) -@unittest.mock.patch("bot.cogs.information.constants.MODERATION_CHANNELS", new=[50]) -class UserEmbedTests(unittest.TestCase): - """Tests for the creation of the `!user` embed.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): - """The embed should use the string representation of the user if they don't have a nick.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = None - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Mr. Hemlock") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_nick_in_title_if_available(self): - """The embed should use the nick if it's available.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = "Cat lover" - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_ignores_everyone_role(self): - """Created `!user` embeds should not contain mention of the @everyone-role.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - admins_role = helpers.MockRole(name='Admins') - admins_role.colour = 100 - - # A `MockMember` has the @Everyone role by default; we add the Admins to that. - user = helpers.MockMember(roles=[admins_role], top_role=admins_role) - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertIn("&Admins", embed.description) - self.assertNotIn("&Everyone", embed.description) - - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): - """The embed should contain expanded infractions and nomination info in mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "expanded infractions info" - nomination_counts.return_value = "nomination info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - nomination_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - expanded infractions info - - nomination info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): - """The embed should contain only basic infraction data outside of mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "basic infractions info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - basic infractions info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): - """The embed should be created with the colour of the top role, if a top role is available.""" - ctx = helpers.MockContext() - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): - """The embed should be created with a blurple colour if the user has no assigned roles.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour.blurple()) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): - """The embed thumbnail should be set to the user's avatar in `png` format.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - user.avatar_url_as.return_value = "avatar url" - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - user.avatar_url_as.assert_called_once_with(static_format="png") - self.assertEqual(embed.thumbnail.url, "avatar url") - - -@unittest.mock.patch("bot.cogs.information.constants") -class UserCommandTests(unittest.TestCase): - """Tests for the `!user` command.""" - - def setUp(self): - """Set up steps executed before each test is run.""" - self.bot = helpers.MockBot() - self.cog = information.Information(self.bot) - - self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) - self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) - self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) - - self.author = helpers.MockMember(id=1, name="syntaxaire") - self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) - self.target = helpers.MockMember(id=3, name="__fluzz__") - - def test_regular_member_cannot_target_another_member(self, constants): - """A regular user should not be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.author) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") - - def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): - """A regular user should not be able to use this command outside of bot-commands.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) - - msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InWhitelistCheckFailure, msg=msg): - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): - """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): - """A user should target itself with `!user` when a `user` argument was not provided.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): - """Staff members should be able to bypass the bot-commands channel restriction.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.moderator) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) - def test_moderators_can_target_another_member(self, create_embed, constants): - """A moderator should be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - create_embed.assert_called_once_with(ctx, self.target) - ctx.send.assert_called_once() diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/cogs/test_jams.py deleted file mode 100644 index b4ad8535f..000000000 --- a/tests/bot/cogs/test_jams.py +++ /dev/null @@ -1,173 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, create_autospec - -from discord import CategoryChannel - -from bot.cogs import jams -from bot.constants import Roles -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel - - -def get_mock_category(channel_count: int, name: str) -> CategoryChannel: - """Return a mocked code jam category.""" - category = create_autospec(CategoryChannel, spec_set=True, instance=True) - category.name = name - category.channels = [MockTextChannel() for _ in range(channel_count)] - - return category - - -class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): - """Tests for `createteam` command.""" - - def setUp(self): - self.bot = MockBot() - self.admin_role = MockRole(name="Admins", id=Roles.admins) - self.command_user = MockMember([self.admin_role]) - self.guild = MockGuild([self.admin_role]) - self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) - self.cog = jams.CodeJams(self.bot) - - async def test_too_small_amount_of_team_members_passed(self): - """Should `ctx.send` and exit early when too small amount of members.""" - for case in (1, 2): - with self.subTest(amount_of_members=case): - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - self.ctx.reset_mock() - members = (MockMember() for _ in range(case)) - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_duplicate_members_provided(self): - """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - member = MockMember() - await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_result_sending(self): - """Should call `ctx.send` when everything goes right.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - members = [MockMember() for _ in range(5)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.cog.create_channels.assert_awaited_once() - self.cog.add_roles.assert_awaited_once() - self.ctx.send.assert_awaited_once() - - async def test_category_doesnt_exist(self): - """Should create a new code jam category.""" - subtests = ( - [], - [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], - [get_mock_category(jams.MAX_CHANNELS - 2, "other")], - ) - - for categories in subtests: - self.guild.reset_mock() - self.guild.categories = categories - - with self.subTest(categories=categories): - actual_category = await self.cog.get_category(self.guild) - - self.guild.create_category_channel.assert_awaited_once() - category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - - self.assertFalse(category_overwrites[self.guild.default_role].read_messages) - self.assertTrue(category_overwrites[self.guild.me].read_messages) - self.assertEqual(self.guild.create_category_channel.return_value, actual_category) - - async def test_category_channel_exist(self): - """Should not try to create category channel.""" - expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) - self.guild.categories = [ - get_mock_category(jams.MAX_CHANNELS - 2, "other"), - expected_category, - get_mock_category(0, jams.CATEGORY_NAME), - ] - - actual_category = await self.cog.get_category(self.guild) - self.assertEqual(expected_category, actual_category) - - async def test_channel_overwrites(self): - """Should have correct permission overwrites for users and roles.""" - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - overwrites = self.cog.get_overwrites(members, self.guild) - - # Leader permission overwrites - self.assertTrue(overwrites[leader].manage_messages) - self.assertTrue(overwrites[leader].read_messages) - self.assertTrue(overwrites[leader].manage_webhooks) - self.assertTrue(overwrites[leader].connect) - - # Other members permission overwrites - for member in members[1:]: - self.assertTrue(overwrites[member].read_messages) - self.assertTrue(overwrites[member].connect) - - # Everyone and verified role overwrite - self.assertFalse(overwrites[self.guild.default_role].read_messages) - self.assertFalse(overwrites[self.guild.default_role].connect) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) - - async def test_team_channels_creation(self): - """Should create new voice and text channel for team.""" - members = [MockMember() for _ in range(5)] - - self.cog.get_overwrites = MagicMock() - self.cog.get_category = AsyncMock() - self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") - actual = await self.cog.create_channels(self.guild, "my-team", members) - - self.assertEqual("foobar-channel", actual) - self.cog.get_overwrites.assert_called_once_with(members, self.guild) - self.cog.get_category.assert_awaited_once_with(self.guild) - - self.guild.create_text_channel.assert_awaited_once_with( - "my-team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - self.guild.create_voice_channel.assert_awaited_once_with( - "My Team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - - async def test_jam_roles_adding(self): - """Should add team leader role to leader and jam role to every team member.""" - leader_role = MockRole(name="Team Leader") - jam_role = MockRole(name="Jammer") - self.guild.get_role.side_effect = [leader_role, jam_role] - - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - await self.cog.add_roles(self.guild, members) - - leader.add_roles.assert_any_await(leader_role) - for member in members: - member.add_roles.assert_any_await(jam_role) - - -class CodeJamSetup(unittest.TestCase): - """Test for `setup` function of `CodeJam` cog.""" - - def test_setup(self): - """Should call `bot.add_cog`.""" - bot = MockBot() - jams.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_logging.py b/tests/bot/cogs/test_logging.py deleted file mode 100644 index 8a18fdcd6..000000000 --- a/tests/bot/cogs/test_logging.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot import constants -from bot.cogs.logging import Logging -from tests.helpers import MockBot, MockTextChannel - - -class LoggingTests(unittest.IsolatedAsyncioTestCase): - """Test cases for connected login.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Logging(self.bot) - self.dev_log = MockTextChannel(id=1234, name="dev-log") - - @patch("bot.cogs.logging.DEBUG_MODE", False) - async def test_debug_mode_false(self): - """Should send connected message to dev-log.""" - self.bot.get_channel.return_value = self.dev_log - - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) - self.dev_log.send.assert_awaited_once() - - @patch("bot.cogs.logging.DEBUG_MODE", True) - async def test_debug_mode_true(self): - """Should not send anything to dev-log.""" - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_not_called() diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py deleted file mode 100644 index 82679f69c..000000000 --- a/tests/bot/cogs/test_security.py +++ /dev/null @@ -1,54 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from discord.ext.commands import NoPrivateMessage - -from bot.cogs.filters import security -from tests.helpers import MockBot, MockContext - - -class SecurityCogTests(unittest.TestCase): - """Tests the `Security` cog.""" - - def setUp(self): - """Attach an instance of the cog to the class for tests.""" - self.bot = MockBot() - self.cog = security.Security(self.bot) - self.ctx = MockContext() - - def test_check_additions(self): - """The cog should add its checks after initialization.""" - self.bot.check.assert_any_call(self.cog.check_on_guild) - self.bot.check.assert_any_call(self.cog.check_not_bot) - - def test_check_not_bot_returns_false_for_humans(self): - """The bot check should return `True` when invoked with human authors.""" - self.ctx.author.bot = False - self.assertTrue(self.cog.check_not_bot(self.ctx)) - - def test_check_not_bot_returns_true_for_robots(self): - """The bot check should return `False` when invoked with robotic authors.""" - self.ctx.author.bot = True - self.assertFalse(self.cog.check_not_bot(self.ctx)) - - def test_check_on_guild_raises_when_outside_of_guild(self): - """When invoked outside of a guild, `check_on_guild` should cause an error.""" - self.ctx.guild = None - - with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): - self.cog.check_on_guild(self.ctx) - - def test_check_on_guild_returns_true_inside_of_guild(self): - """When invoked inside of a guild, `check_on_guild` should return `True`.""" - self.ctx.guild = "lemon's lemonade stand" - self.assertTrue(self.cog.check_on_guild(self.ctx)) - - -class SecurityCogLoadTests(unittest.TestCase): - """Tests loading the `Security` cog.""" - - def test_security_cog_load(self): - """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/cogs/test_slowmode.py deleted file mode 100644 index f442814c8..000000000 --- a/tests/bot/cogs/test_slowmode.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest -from unittest import mock - -from dateutil.relativedelta import relativedelta - -from bot.cogs.moderation.slowmode import Slowmode -from bot.constants import Emojis -from tests.helpers import MockBot, MockContext, MockTextChannel - - -class SlowmodeTests(unittest.IsolatedAsyncioTestCase): - - def setUp(self) -> None: - self.bot = MockBot() - self.cog = Slowmode(self.bot) - self.ctx = MockContext() - - async def test_get_slowmode_no_channel(self) -> None: - """Get slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) - - await self.cog.get_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") - - async def test_get_slowmode_with_channel(self) -> None: - """Get slowmode with a given channel.""" - text_channel = MockTextChannel(name='python-language', slowmode_delay=2) - - await self.cog.get_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') - - async def test_set_slowmode_no_channel(self) -> None: - """Set slowmode without a given channel.""" - test_cases = ( - ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), - ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), - ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - self.ctx.channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) - - if edited: - self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - self.ctx.channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_set_slowmode_with_channel(self) -> None: - """Set slowmode with a given channel.""" - test_cases = ( - ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), - ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), - ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - text_channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) - - if edited: - text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - text_channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_reset_slowmode_no_channel(self) -> None: - """Reset slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) - - await self.cog.reset_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' - ) - - async def test_reset_slowmode_with_channel(self) -> None: - """Reset slowmode with a given channel.""" - text_channel = MockTextChannel(name='meta', slowmode_delay=1) - - await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' - ) - - @mock.patch("bot.cogs.moderation.slowmode.with_role_check") - @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py deleted file mode 100644 index c7bac3ab3..000000000 --- a/tests/bot/cogs/test_snekbox.py +++ /dev/null @@ -1,409 +0,0 @@ -import asyncio -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch - -from discord.ext import commands - -from bot import constants -from bot.cogs.utils import snekbox -from bot.cogs.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser - - -class SnekboxTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Add mocked bot and cog to the instance.""" - self.bot = MockBot() - self.cog = Snekbox(bot=self.bot) - - async def test_post_eval(self): - """Post the eval code to the URLs.snekbox_eval_api endpoint.""" - resp = MagicMock() - resp.json = AsyncMock(return_value="return") - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual(await self.cog.post_eval("import random"), "return") - self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json={"input": "import random"}, - raise_for_status=True - ) - resp.json.assert_awaited_once() - - async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) - self.assertEqual(result, "too long to upload") - - async def test_upload_output(self): - """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True - ) - - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - log = logging.getLogger("bot.cogs.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - - def test_prepare_input(self): - cases = ( - ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), - ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), - ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), - ) - for case, expected, testname in cases: - with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) - - def test_get_results_message(self): - """Return error and message according to the eval result.""" - cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_signals: Mock): - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127', '') - ) - - @patch('bot.cogs.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = 'SIGTEST' - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127 (SIGTEST)', '') - ) - - def test_get_status_emoji(self): - """Return emoji according to the eval result.""" - cases = ( - (' ', -1, ':warning:'), - ('Hello world!', 0, ':white_check_mark:'), - ('Invalid beard size', -1, ':x:') - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - async def test_format_output(self): - """Test output formatting.""" - self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') - - too_many_lines = ( - '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' - '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' - ) - too_long_too_many_lines = ( - "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" - ) - - cases = ( - ('', ('[No output]', None), 'No output'), - ('My awesome output', ('My awesome output', None), 'One line output'), - ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), - (' CategoryChannel: + """Return a mocked code jam category.""" + category = create_autospec(CategoryChannel, spec_set=True, instance=True) + category.name = name + category.channels = [MockTextChannel() for _ in range(channel_count)] + + return category + + +class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): + """Tests for `createteam` command.""" + + def setUp(self): + self.bot = MockBot() + self.admin_role = MockRole(name="Admins", id=Roles.admins) + self.command_user = MockMember([self.admin_role]) + self.guild = MockGuild([self.admin_role]) + self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) + self.cog = jams.CodeJams(self.bot) + + async def test_too_small_amount_of_team_members_passed(self): + """Should `ctx.send` and exit early when too small amount of members.""" + for case in (1, 2): + with self.subTest(amount_of_members=case): + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + self.ctx.reset_mock() + members = (MockMember() for _ in range(case)) + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_duplicate_members_provided(self): + """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + member = MockMember() + await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_result_sending(self): + """Should call `ctx.send` when everything goes right.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + members = [MockMember() for _ in range(5)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.cog.create_channels.assert_awaited_once() + self.cog.add_roles.assert_awaited_once() + self.ctx.send.assert_awaited_once() + + async def test_category_doesnt_exist(self): + """Should create a new code jam category.""" + subtests = ( + [], + [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], + [get_mock_category(jams.MAX_CHANNELS - 2, "other")], + ) + + for categories in subtests: + self.guild.reset_mock() + self.guild.categories = categories + + with self.subTest(categories=categories): + actual_category = await self.cog.get_category(self.guild) + + self.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + + self.assertFalse(category_overwrites[self.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.guild.me].read_messages) + self.assertEqual(self.guild.create_category_channel.return_value, actual_category) + + async def test_category_channel_exist(self): + """Should not try to create category channel.""" + expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) + self.guild.categories = [ + get_mock_category(jams.MAX_CHANNELS - 2, "other"), + expected_category, + get_mock_category(0, jams.CATEGORY_NAME), + ] + + actual_category = await self.cog.get_category(self.guild) + self.assertEqual(expected_category, actual_category) + + async def test_channel_overwrites(self): + """Should have correct permission overwrites for users and roles.""" + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + overwrites = self.cog.get_overwrites(members, self.guild) + + # Leader permission overwrites + self.assertTrue(overwrites[leader].manage_messages) + self.assertTrue(overwrites[leader].read_messages) + self.assertTrue(overwrites[leader].manage_webhooks) + self.assertTrue(overwrites[leader].connect) + + # Other members permission overwrites + for member in members[1:]: + self.assertTrue(overwrites[member].read_messages) + self.assertTrue(overwrites[member].connect) + + # Everyone and verified role overwrite + self.assertFalse(overwrites[self.guild.default_role].read_messages) + self.assertFalse(overwrites[self.guild.default_role].connect) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) + + async def test_team_channels_creation(self): + """Should create new voice and text channel for team.""" + members = [MockMember() for _ in range(5)] + + self.cog.get_overwrites = MagicMock() + self.cog.get_category = AsyncMock() + self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") + actual = await self.cog.create_channels(self.guild, "my-team", members) + + self.assertEqual("foobar-channel", actual) + self.cog.get_overwrites.assert_called_once_with(members, self.guild) + self.cog.get_category.assert_awaited_once_with(self.guild) + + self.guild.create_text_channel.assert_awaited_once_with( + "my-team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + self.guild.create_voice_channel.assert_awaited_once_with( + "My Team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + + async def test_jam_roles_adding(self): + """Should add team leader role to leader and jam role to every team member.""" + leader_role = MockRole(name="Team Leader") + jam_role = MockRole(name="Jammer") + self.guild.get_role.side_effect = [leader_role, jam_role] + + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + await self.cog.add_roles(self.guild, members) + + leader.add_roles.assert_any_await(leader_role) + for member in members: + member.add_roles.assert_any_await(jam_role) + + +class CodeJamSetup(unittest.TestCase): + """Test for `setup` function of `CodeJam` cog.""" + + def test_setup(self): + """Should call `bot.add_cog`.""" + bot = MockBot() + jams.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/utils/test_snekbox.py b/tests/bot/cogs/utils/test_snekbox.py new file mode 100644 index 000000000..3e447f319 --- /dev/null +++ b/tests/bot/cogs/utils/test_snekbox.py @@ -0,0 +1,409 @@ +import asyncio +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch + +from discord.ext import commands + +from bot import constants +from bot.cogs.utils import snekbox +from bot.cogs.utils.snekbox import Snekbox +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + + async def test_post_eval(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( + constants.URLs.snekbox_eval_api, + json={"input": "import random"}, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LEN.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + self.assertEqual(result, "too long to upload") + + async def test_upload_output(self): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual( + await self.cog.upload_output("My awesome output"), + constants.URLs.paste_service.format(key=key) + ) + self.bot.http_session.post.assert_called_with( + constants.URLs.paste_service.format(key="documents"), + data="My awesome output", + raise_for_status=True + ) + + async def test_upload_output_gracefully_fallback_if_exception_during_request(self): + """Output upload gracefully fallback if the upload fail.""" + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + log = logging.getLogger("bot.cogs.utils.snekbox") + with self.assertLogs(logger=log, level='ERROR'): + await self.cog.upload_output('My awesome output!') + + async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): + """Output upload gracefully fallback if there is no key entry in the response body.""" + self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + + def test_prepare_input(self): + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual(self.cog.prepare_input(case), expected) + + def test_get_results_message(self): + """Return error and message according to the eval result.""" + cases = ( + ('ERROR', None, ('Your eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), + ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + @patch('bot.cogs.utils.snekbox.Signals', side_effect=ValueError) + def test_get_results_message_invalid_signal(self, mock_signals: Mock): + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127', '') + ) + + @patch('bot.cogs.utils.snekbox.Signals') + def test_get_results_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = 'SIGTEST' + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127 (SIGTEST)', '') + ) + + def test_get_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... (truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + (' Date: Wed, 12 Aug 2020 22:31:08 -0700 Subject: Prefix names of non-extension modules with _ This naming scheme will make them easy to distinguish from extensions. --- bot/cogs/backend/sync/__init__.py | 2 +- bot/cogs/backend/sync/_cog.py | 180 ++++++++ bot/cogs/backend/sync/_syncers.py | 347 +++++++++++++++ bot/cogs/backend/sync/cog.py | 180 -------- bot/cogs/backend/sync/syncers.py | 347 --------------- bot/cogs/moderation/__init__.py | 19 - bot/cogs/moderation/incidents.py | 5 + bot/cogs/moderation/infraction/_scheduler.py | 463 +++++++++++++++++++++ bot/cogs/moderation/infraction/_utils.py | 201 +++++++++ bot/cogs/moderation/infraction/infractions.py | 31 +- bot/cogs/moderation/infraction/management.py | 11 +- bot/cogs/moderation/infraction/scheduler.py | 463 --------------------- bot/cogs/moderation/infraction/superstarify.py | 29 +- bot/cogs/moderation/infraction/utils.py | 201 --------- bot/cogs/moderation/modlog.py | 5 + bot/cogs/moderation/silence.py | 5 + bot/cogs/moderation/watchchannels/__init__.py | 9 - bot/cogs/moderation/watchchannels/_watchchannel.py | 348 ++++++++++++++++ bot/cogs/moderation/watchchannels/bigbrother.py | 9 +- bot/cogs/moderation/watchchannels/talentpool.py | 7 +- bot/cogs/moderation/watchchannels/watchchannel.py | 348 ---------------- tests/bot/cogs/backend/sync/test_base.py | 2 +- tests/bot/cogs/backend/sync/test_cog.py | 15 +- tests/bot/cogs/backend/sync/test_roles.py | 2 +- tests/bot/cogs/backend/sync/test_users.py | 2 +- .../cogs/moderation/infraction/test_infractions.py | 6 +- 26 files changed, 1625 insertions(+), 1612 deletions(-) create mode 100644 bot/cogs/backend/sync/_cog.py create mode 100644 bot/cogs/backend/sync/_syncers.py delete mode 100644 bot/cogs/backend/sync/cog.py delete mode 100644 bot/cogs/backend/sync/syncers.py create mode 100644 bot/cogs/moderation/infraction/_scheduler.py create mode 100644 bot/cogs/moderation/infraction/_utils.py delete mode 100644 bot/cogs/moderation/infraction/scheduler.py delete mode 100644 bot/cogs/moderation/infraction/utils.py create mode 100644 bot/cogs/moderation/watchchannels/_watchchannel.py delete mode 100644 bot/cogs/moderation/watchchannels/watchchannel.py (limited to 'tests') diff --git a/bot/cogs/backend/sync/__init__.py b/bot/cogs/backend/sync/__init__.py index fe7df4e9b..fb640a1cf 100644 --- a/bot/cogs/backend/sync/__init__.py +++ b/bot/cogs/backend/sync/__init__.py @@ -1,5 +1,5 @@ from bot.bot import Bot -from .cog import Sync +from ._cog import Sync def setup(bot: Bot) -> None: diff --git a/bot/cogs/backend/sync/_cog.py b/bot/cogs/backend/sync/_cog.py new file mode 100644 index 000000000..b6068f328 --- /dev/null +++ b/bot/cogs/backend/sync/_cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from . import _syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = _syncers.RoleSyncer(self.bot) + self.user_syncer = _syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/_syncers.py b/bot/cogs/backend/sync/_syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/cogs/backend/sync/_syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/backend/sync/cog.py b/bot/cogs/backend/sync/cog.py deleted file mode 100644 index 274845a50..000000000 --- a/bot/cogs/backend/sync/cog.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from typing import Any, Dict - -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from . import syncers - -log = logging.getLogger(__name__) - - -class Sync(Cog): - """Captures relevant events and sends them to the site.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.role_syncer = syncers.RoleSyncer(self.bot) - self.user_syncer = syncers.UserSyncer(self.bot) - - self.bot.loop.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(constants.Guild.id) - if guild is None: - return - - for syncer in (self.role_syncer, self.user_syncer): - await syncer.sync(guild) - - async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: - """Send a PATCH request to partially update a user in the database.""" - try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) - except ResponseCodeError as e: - if e.response.status != 404: - raise - if not ignore_404: - log.warning("Unable to update user, got 404. Assuming race condition from join event.") - - @Cog.listener() - async def on_guild_role_create(self, role: Role) -> None: - """Adds newly create role to the database table over the API.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.post( - 'bot/roles', - json={ - 'colour': role.colour.value, - 'id': role.id, - 'name': role.name, - 'permissions': role.permissions.value, - 'position': role.position, - } - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: Role) -> None: - """Deletes role from the database when it's deleted from the guild.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - @Cog.listener() - async def on_guild_role_update(self, before: Role, after: Role) -> None: - """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild.id != constants.Guild.id: - return - - was_updated = ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ) - - if was_updated: - await self.bot.api_client.put( - f'bot/roles/{after.id}', - json={ - 'colour': after.colour.value, - 'id': after.id, - 'name': after.name, - 'permissions': after.permissions.value, - 'position': after.position, - } - ) - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """ - Adds a new user or updates existing user to the database when a member joins the guild. - - If the joining member is a user that is already known to the database (i.e., a user that - previously left), it will update the user's information. If the user is not yet known by - the database, the user is added. - """ - if member.guild.id != constants.Guild.id: - return - - packed = { - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': True, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - - got_error = False - - try: - # First try an update of the user to set the `in_guild` field and other - # fields that may have changed since the last time we've seen them. - await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) - - except ResponseCodeError as e: - # If we didn't get 404, something else broke - propagate it up. - if e.response.status != 404: - raise - - got_error = True # yikes - - if got_error: - # If we got `404`, the user is new. Create them. - await self.bot.api_client.post('bot/users', json=packed) - - @Cog.listener() - async def on_member_remove(self, member: Member) -> None: - """Set the in_guild field to False when a member leaves the guild.""" - if member.guild.id != constants.Guild.id: - return - - await self.patch_user(member.id, json={"in_guild": False}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Update the roles of the member in the database if a change is detected.""" - if after.guild.id != constants.Guild.id: - return - - if before.roles != after.roles: - updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, json=updated_information) - - @Cog.listener() - async def on_user_update(self, before: User, after: User) -> None: - """Update the user information in the database if a relevant change is detected.""" - attrs = ("name", "discriminator") - if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): - updated_information = { - "name": after.name, - "discriminator": int(after.discriminator), - } - # A 404 likely means the user is in another guild. - await self.patch_user(after.id, json=updated_information, ignore_404=True) - - @commands.group(name='sync') - @commands.has_permissions(administrator=True) - async def sync_group(self, ctx: Context) -> None: - """Run synchronizations between the bot and site manually.""" - - @sync_group.command(name='roles') - @commands.has_permissions(administrator=True) - async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) - - @sync_group.command(name='users') - @commands.has_permissions(administrator=True) - async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/syncers.py b/bot/cogs/backend/sync/syncers.py deleted file mode 100644 index f7ba811bc..000000000 --- a/bot/cogs/backend/sync/syncers.py +++ /dev/null @@ -1,347 +0,0 @@ -import abc -import asyncio -import logging -import typing as t -from collections import namedtuple -from functools import partial - -import discord -from discord import Guild, HTTPException, Member, Message, Reaction, User -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# These objects are declared as namedtuples because tuples are hashable, -# something that we make use of when diffing site roles against guild roles. -_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) -_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - - -class Syncer(abc.ABC): - """Base class for synchronising the database with objects in the Discord cache.""" - - _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " - _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @property - @abc.abstractmethod - def name(self) -> str: - """The name of the syncer; used in output messages and logging.""" - raise NotImplementedError # pragma: no cover - - async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: - """ - Send a prompt to confirm or abort a sync using reactions and return the sent message. - - If a message is given, it is edited to display the prompt and reactions. Otherwise, a new - message is sent to the dev-core channel and mentions the core developers role. If the - channel cannot be retrieved, return None. - """ - log.trace(f"Sending {self.name} sync confirmation prompt.") - - msg_content = ( - f'Possible cache issue while syncing {self.name}s. ' - f'More than {constants.Sync.max_diff} {self.name}s were changed. ' - f'React to confirm or abort the sync.' - ) - - # Send to core developers if it's an automatic sync. - if not message: - log.trace("Message not provided for confirmation; creating a new one in dev-core.") - channel = self.bot.get_channel(constants.Channels.dev_core) - - if not channel: - log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") - try: - channel = await self.bot.fetch_channel(constants.Channels.dev_core) - except HTTPException: - log.exception( - f"Failed to fetch channel for sending sync confirmation prompt; " - f"aborting {self.name} sync." - ) - return None - - allowed_roles = [discord.Object(constants.Roles.core_developers)] - message = await channel.send( - f"{self._CORE_DEV_MENTION}{msg_content}", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - else: - await message.edit(content=msg_content) - - # Add the initial reactions. - log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") - for emoji in self._REACTION_EMOJIS: - await message.add_reaction(emoji) - - return message - - def _reaction_check( - self, - author: Member, - message: Message, - reaction: Reaction, - user: t.Union[Member, User] - ) -> bool: - """ - Return True if the `reaction` is a valid confirmation or abort reaction on `message`. - - If the `author` of the prompt is a bot, then a reaction by any core developer will be - considered valid. Otherwise, the author of the reaction (`user`) will have to be the - `author` of the prompt. - """ - # For automatic syncs, check for the core dev role instead of an exact author - has_role = any(constants.Roles.core_developers == role.id for role in user.roles) - return ( - reaction.message.id == message.id - and not user.bot - and (has_role if author.bot else user == author) - and str(reaction.emoji) in self._REACTION_EMOJIS - ) - - async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: - """ - Wait for a confirmation reaction by `author` on `message` and return True if confirmed. - - Uses the `_reaction_check` function to determine if a reaction is valid. - - If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. - To acknowledge the reaction (or lack thereof), `message` will be edited. - """ - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - reaction = None - try: - log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") - reaction, _ = await self.bot.wait_for( - 'reaction_add', - check=partial(self._reaction_check, author, message), - timeout=constants.Sync.confirm_timeout - ) - except asyncio.TimeoutError: - # reaction will remain none thus sync will be aborted in the finally block below. - log.debug(f"The {self.name} syncer confirmation prompt timed out.") - - if str(reaction) == constants.Emojis.check_mark: - log.trace(f"The {self.name} syncer was confirmed.") - await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') - return True - else: - log.info(f"The {self.name} syncer was aborted or timed out!") - await message.edit( - content=f':warning: {mention}{self.name} sync aborted or timed out!' - ) - return False - - @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference between the cache of `guild` and the database.""" - raise NotImplementedError # pragma: no cover - - @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: - """Perform the API calls for synchronisation.""" - raise NotImplementedError # pragma: no cover - - async def _get_confirmation_result( - self, - diff_size: int, - author: Member, - message: t.Optional[Message] = None - ) -> t.Tuple[bool, t.Optional[Message]]: - """ - Prompt for confirmation and return a tuple of the result and the prompt message. - - `diff_size` is the size of the diff of the sync. If it is greater than - `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the - sync and the `message` is an extant message to edit to display the prompt. - - If confirmed or no confirmation was needed, the result is True. The returned message will - either be the given `message` or a new one which was created when sending the prompt. - """ - log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") - if diff_size > constants.Sync.max_diff: - message = await self._send_prompt(message) - if not message: - return False, None # Couldn't get channel. - - confirmed = await self._wait_for_confirmation(author, message) - if not confirmed: - return False, message # Sync aborted. - - return True, message - - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If the differences between the cache and the database are greater than - `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core - channel. The confirmation can be optionally redirect to `ctx` instead. - """ - log.info(f"Starting {self.name} syncer.") - - message = None - author = self.bot.user - if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") - author = ctx.author - - diff = await self._get_diff(guild) - diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict - totals = {k: len(v) for k, v in diff_dict.items() if v is not None} - diff_size = sum(totals.values()) - - confirmed, message = await self._get_confirmation_result(diff_size, author, message) - if not confirmed: - return - - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - try: - await self._sync(diff) - except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") - - # Don't show response text because it's probably some really long HTML. - results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" - else: - results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" - - if message: - await message.edit(content=content) - - -class RoleSyncer(Syncer): - """Synchronise the database with roles in the cache.""" - - name = "role" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of roles between the cache of `guild` and the database.""" - log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_roles = {_Role(**role_dict) for role_dict in roles} - guild_roles = { - _Role( - id=role.id, - name=role.name, - colour=role.colour.value, - permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in db_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # DB guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - db_roles - roles_to_create - roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} - - return _Diff(roles_to_create, roles_to_update, roles_to_delete) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the role cache of `guild`.""" - log.trace("Syncing created roles...") - for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) - - log.trace("Syncing updated roles...") - for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) - - log.trace("Syncing deleted roles...") - for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - -class UserSyncer(Syncer): - """Synchronise the database with users in the cache.""" - - name = "user" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of users between the cache of `guild` and the database.""" - log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } - - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - if guild_user is not None: - if db_user != guild_user: - users_to_update.add(guild_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return _Diff(users_to_create, users_to_update, None) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the user cache of `guild`.""" - log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index aad1f3c26..e69de29bb 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,19 +0,0 @@ -from bot.bot import Bot -from .incidents import Incidents -from .infraction.infractions import Infractions -from .infraction.management import ModManagement -from .infraction.superstarify import Superstarify -from .modlog import ModLog -from .silence import Silence -from .slowmode import Slowmode - - -def setup(bot: Bot) -> None: - """Load the Incidents, Infractions, ModManagement, ModLog, Silence, Slowmode and Superstarify cogs.""" - bot.add_cog(Incidents(bot)) - bot.add_cog(Infractions(bot)) - bot.add_cog(ModLog(bot)) - bot.add_cog(ModManagement(bot)) - bot.add_cog(Silence(bot)) - bot.add_cog(Slowmode(bot)) - bot.add_cog(Superstarify(bot)) diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 3605ab1d2..e49913552 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -405,3 +405,8 @@ class Incidents(Cog): """Pass `message` to `add_signals` if and only if it satisfies `is_incident`.""" if is_incident(message): await add_signals(message) + + +def setup(bot: Bot) -> None: + """Load the Incidents cog.""" + bot.add_cog(Incidents(bot)) diff --git a/bot/cogs/moderation/infraction/_scheduler.py b/bot/cogs/moderation/infraction/_scheduler.py new file mode 100644 index 000000000..33944a8db --- /dev/null +++ b/bot/cogs/moderation/infraction/_scheduler.py @@ -0,0 +1,463 @@ +import logging +import textwrap +import typing as t +from abc import abstractmethod +from datetime import datetime +from gettext import ngettext + +import dateutil.parser +import discord +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation.modlog import ModLog +from bot.constants import Colours, STAFF_CHANNELS +from bot.utils import time +from bot.utils.scheduling import Scheduler +from . import _utils +from ._utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class InfractionScheduler: + """Handles the application, pardoning, and expiration of infractions.""" + + def __init__(self, bot: Bot, supported_infractions: t.Container[str]): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + @property + def mod_log(self) -> ModLog: + """Get the currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: + """Schedule expiration for previous infractions.""" + await self.bot.wait_until_guild_available() + + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={'active': 'true'} + ) + for infraction in infractions: + if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: + self.schedule_expiration(infraction) + + async def reapply_infraction( + self, + infraction: _utils.Infraction, + apply_coro: t.Optional[t.Awaitable] + ) -> None: + """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" + # Calculate the time remaining, in seconds, for the mute. + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + delta = (expiry - datetime.utcnow()).total_seconds() + + # Mark as inactive if less than a minute remains. + if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) + await self.deactivate_infraction(infraction) + return + + # Allowing mod log since this is a passive action that should be logged. + await apply_coro + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + + async def apply_infraction( + self, + ctx: Context, + infraction: _utils.Infraction, + user: UserSnowflake, + action_coro: t.Optional[t.Awaitable] = None + ) -> None: + """Apply an infraction to the user, log the infraction, and optionally notify the user.""" + infr_type = infraction["type"] + icon = _utils.INFRACTION_ICONS[infr_type][0] + reason = infraction["reason"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") + + # Default values for the confirmation message and mod log. + confirm_msg = ":ok_hand: applied" + + # Specifying an expiry for a note or warning makes no sense. + if infr_type in ("note", "warning"): + expiry_msg = "" + else: + expiry_msg = f" until {expiry}" if expiry else " permanently" + + dm_result = "" + dm_log_text = "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" + log_title = "applied" + log_content = None + failed = False + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # This needs to happen before we apply the infraction, as the bot cannot + # send DMs to user that it doesn't share a guild with. If we were to + # apply kick/ban infractions first, this would mean that we'd make it + # impossible for us to deliver a DM. See python-discord/bot#982. + if not infraction["hidden"]: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + end_msg = "" + if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + else: + log.trace(f"Fetching total infraction count for {user}.") + + infractions = await self.bot.api_client.get( + "bot/infractions", + params={"user__id": str(user.id)} + ) + total = len(infractions) + end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" + + # Execute the necessary actions to apply the infraction on Discord. + if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") + try: + await action_coro + if expiry: + # Schedule the expiration of the infraction. + self.schedule_expiration(infraction) + except discord.HTTPException as e: + # Accordingly display that applying the infraction failed. + confirm_msg = ":x: failed to apply" + expiry_msg = "" + log_content = ctx.author.mention + log_title = "failed to apply" + + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + failed = True + + if failed: + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" + + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") + + # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=icon, + colour=Colours.soft_red, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {user.mention} (`{user.id}`) + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} + Reason: {reason} + """), + content=log_content, + footer=f"ID {infraction['id']}" + ) + + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + + async def pardon_infraction( + self, + ctx: Context, + infr_type: str, + user: UserSnowflake, + send_msg: bool = True + ) -> None: + """ + Prematurely end an infraction for a user and log the action in the mod log. + + If `send_msg` is True, then a pardoning confirmation message will be sent to + the context channel. Otherwise, no such message will be sent. + """ + log.trace(f"Pardoning {infr_type} infraction for {user}.") + + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") + response = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': user.id + } + ) + + if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") + await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") + return + + # Deactivate the infraction and cancel its scheduled expiration task. + log_text = await self.deactivate_infraction(response[0], send_log=False) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["Actor"] = str(ctx.message.author) + log_content = None + id_ = response[0]['id'] + footer = f"ID: {id_}" + + # If multiple active infractions were found, mark them as inactive in the database + # and cancel their expiration tasks. + if len(response) > 1: + log.info( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) + + footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" + + log_note = f"Found multiple **active** {infr_type} infractions in the database." + if "Note" in log_text: + log_text["Note"] = f" {log_note}" + else: + log_text["Note"] = log_note + + # deactivate_infraction() is not called again because: + # 1. Discord cannot store multiple active bans or assign multiples of the same role + # 2. It would send a pardon DM for each active infraction, which is redundant + for infraction in response[1:]: + id_ = infraction['id'] + try: + # Mark infraction as inactive in the database. + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError: + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") + # This is simpler and cleaner than trying to concatenate all the errors. + log_text["Failure"] = "See bot's logs for details." + + # Cancel pending expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Accordingly display whether the user was successfully notified via DM. + dm_emoji = "" + if log_text.get("DM") == "Sent": + dm_emoji = ":incoming_envelope: " + elif "DM" in log_text: + dm_emoji = f"{constants.Emojis.failmail} " + + # Accordingly display whether the pardon failed. + if "Failure" in log_text: + confirm_msg = ":x: failed to pardon" + log_title = "pardon failed" + log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") + else: + confirm_msg = ":ok_hand: pardoned" + log_title = "pardoned" + + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + + # Send a confirmation message to the invoking context. + if send_msg: + log.trace(f"Sending infraction #{id_} pardon confirmation message.") + await ctx.send( + f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " + f"{log_text.get('Failure', '')}" + ) + + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + + # Send a log message to the mod log. + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[infr_type][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=footer, + content=log_content, + ) + + async def deactivate_infraction( + self, + infraction: _utils.Infraction, + send_log: bool = True + ) -> t.Dict[str, str]: + """ + Deactivate an active infraction and return a dictionary of lines to send in a mod log. + + The infraction is removed from Discord, marked as inactive in the database, and has its + expiration task cancelled. If `send_log` is True, a mod log is sent for the + deactivation of the infraction. + + Infractions of unsupported types will raise a ValueError. + """ + guild = self.bot.get_guild(constants.Guild.id) + mod_role = guild.get_role(constants.Roles.moderators) + user_id = infraction["user"] + actor = infraction["actor"] + type_ = infraction["type"] + id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] + + log.info(f"Marking infraction #{id_} as inactive (expired).") + + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + + log_content = None + log_text = { + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, + } + + try: + log.trace("Awaiting the pardon action coroutine.") + returned_log = await self._pardon_action(infraction) + + if returned_log is not None: + log_text = {**log_text, **returned_log} # Merge the logs together + else: + raise ValueError( + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" + ) + except discord.Forbidden: + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") + log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" + log_content = mod_role.mention + except discord.HTTPException as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." + log_content = mod_role.mention + + # Check if the user is currently being watched by Big Brother. + try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + + active_watch = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "watch", + "user__id": user_id + } + ) + + log_text["Watching"] = "Yes" if active_watch else "No" + except ResponseCodeError: + log.exception(f"Failed to fetch watch status for user {user_id}") + log_text["Watching"] = "Unknown - failed to fetch watch status." + + try: + # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_line = f"API request failed with code {e.status}." + log_content = mod_role.mention + + # Append to an existing failure message if possible + if "Failure" in log_text: + log_text["Failure"] += f" {log_line}" + else: + log_text["Failure"] = log_line + + # Cancel the expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Send a log message to the mod log. + if send_log: + log_title = "expiration failed" if "Failure" in log_text else "expired" + + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + + log.trace(f"Sending deactivation mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[type_][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=f"ID: {id_}", + content=log_content, + ) + + return log_text + + @abstractmethod + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + raise NotImplementedError + + def schedule_expiration(self, infraction: _utils.Infraction) -> None: + """ + Marks an infraction expired after the delay from time of scheduling to time of expiration. + + At the time of expiration, the infraction is marked as inactive on the website and the + expiration task is cancelled. + """ + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/_utils.py b/bot/cogs/moderation/infraction/_utils.py new file mode 100644 index 000000000..fb55287b6 --- /dev/null +++ b/bot/cogs/moderation/infraction/_utils.py @@ -0,0 +1,201 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext.commands import Context + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons + +log = logging.getLogger(__name__) + +# apply icon, pardon icon +INFRACTION_ICONS = { + "ban": (Icons.user_ban, Icons.user_unban), + "kick": (Icons.sign_out, None), + "mute": (Icons.user_mute, Icons.user_unmute), + "note": (Icons.user_warn, None), + "superstar": (Icons.superstarify, Icons.unsuperstarify), + "warning": (Icons.user_warn, None), +} +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEALABLE_INFRACTIONS = ("ban", "mute") + +# Type aliases +UserObject = t.Union[discord.Member, discord.User] +UserSnowflake = t.Union[UserObject, discord.Object] +Infraction = t.Dict[str, t.Union[str, int, bool]] + + +async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: + """ + Create a new user in the database. + + Used when an infraction needs to be applied on a user absent in the guild. + """ + log.trace(f"Attempting to add user {user.id} to the database.") + + if not isinstance(user, (discord.Member, discord.User)): + log.debug("The user being added to the DB is not a Member or User object.") + + payload = { + 'discriminator': int(getattr(user, 'discriminator', 0)), + 'id': user.id, + 'in_guild': False, + 'name': getattr(user, 'name', 'Name unknown'), + 'roles': [] + } + + try: + response = await ctx.bot.api_client.post('bot/users', json=payload) + log.info(f"User {user.id} added to the DB.") + return response + except ResponseCodeError as e: + log.error(f"Failed to add user {user.id} to the DB. {e}") + await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") + + +async def post_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + reason: str, + expires_at: datetime = None, + hidden: bool = False, + active: bool = True +) -> t.Optional[dict]: + """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + + payload = { + "actor": ctx.message.author.id, + "hidden": hidden, + "reason": reason, + "type": infr_type, + "user": user.id, + "active": active + } + if expires_at: + payload['expires_at'] = expires_at.isoformat() + + # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. + for should_post_user in (True, False): + try: + response = await ctx.bot.api_client.post('bot/infractions', json=payload) + return response + except ResponseCodeError as e: + if e.status == 400 and 'user' in e.response_json: + # Only one attempt to add the user to the database, not two: + if not should_post_user or await post_user(ctx, user) is None: + return + else: + log.exception(f"Unexpected error while adding an infraction for {user}:") + await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") + return + + +async def get_active_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + send_msg: bool = True +) -> t.Optional[dict]: + """ + Retrieves an active infraction of the given type for the user. + + If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, + then a message for the moderator will be sent to the context channel letting them know. + Otherwise, no message will be sent. + """ + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + + active_infractions = await ctx.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': str(user.id) + } + ) + if active_infractions: + # Checks to see if the moderator should be told there is an active infraction + if send_msg: + log.trace(f"{user} has active infractions of type {infr_type}.") + await ctx.send( + f":x: According to my records, this user already has a {infr_type} infraction. " + f"See infraction **#{active_infractions[0]['id']}**." + ) + return active_infractions[0] + else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") + + +async def notify_infraction( + user: UserObject, + infr_type: str, + expires_at: t.Optional[str] = None, + reason: t.Optional[str] = None, + icon_url: str = Icons.token_removed +) -> bool: + """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + + embed = discord.Embed( + description=textwrap.shorten(text, width=2048, placeholder="..."), + colour=Colours.soft_red + ) + + embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) + embed.title = f"Please review our rules over at {RULES_URL}" + embed.url = RULES_URL + + if infr_type in APPEALABLE_INFRACTIONS: + embed.set_footer( + text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + ) + + return await send_private_embed(user, embed) + + +async def notify_pardon( + user: UserObject, + title: str, + content: str, + icon_url: str = Icons.user_verified +) -> bool: + """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + + embed = discord.Embed( + description=content, + colour=Colours.soft_green + ) + + embed.set_author(name=title, icon_url=icon_url) + + return await send_private_embed(user, embed) + + +async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: + """ + A helper method for sending an embed to a user's DMs. + + Returns a boolean indicator of DM success. + """ + try: + await user.send(embed=embed) + return True + except (discord.HTTPException, discord.Forbidden, discord.NotFound): + log.debug( + f"Infraction-related information could not be sent to user {user} ({user.id}). " + "The user either could not be retrieved or probably disabled their DMs." + ) + return False diff --git a/bot/cogs/moderation/infraction/infractions.py b/bot/cogs/moderation/infraction/infractions.py index 8df642428..cb459b447 100644 --- a/bot/cogs/moderation/infraction/infractions.py +++ b/bot/cogs/moderation/infraction/infractions.py @@ -13,9 +13,9 @@ from bot.constants import Event from bot.converters import Expiry, FetchedMember from bot.decorators import respect_role_hierarchy from bot.utils.checks import with_role_check -from . import utils -from .scheduler import InfractionScheduler -from .utils import UserSnowflake +from . import _utils +from ._scheduler import InfractionScheduler +from ._utils import UserSnowflake log = logging.getLogger(__name__) @@ -55,7 +55,7 @@ class Infractions(InfractionScheduler, commands.Cog): @command() async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: """Warn a user for the given reason.""" - infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False) + infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) if infraction is None: return @@ -125,7 +125,7 @@ class Infractions(InfractionScheduler, commands.Cog): @command(hidden=True) async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: """Create a private note for a user with the given reason without notifying the user.""" - infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) + infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) if infraction is None: return @@ -213,10 +213,10 @@ class Infractions(InfractionScheduler, commands.Cog): async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await utils.get_active_infraction(ctx, user, "mute"): + if await _utils.get_active_infraction(ctx, user, "mute"): return - infraction = await utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) if infraction is None: return @@ -233,7 +233,7 @@ class Infractions(InfractionScheduler, commands.Cog): @respect_role_hierarchy() async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: """Apply a kick infraction with kwargs passed to `post_infraction`.""" - infraction = await utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) if infraction is None: return @@ -254,7 +254,7 @@ class Infractions(InfractionScheduler, commands.Cog): """ # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active is_temporary = kwargs.get("expires_at") is not None - active_infraction = await utils.get_active_infraction(ctx, user, "ban", is_temporary) + active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary) if active_infraction: if is_temporary: @@ -269,7 +269,7 @@ class Infractions(InfractionScheduler, commands.Cog): log.trace("Old tempban is being replaced by new permaban.") await self.pardon_infraction(ctx, "ban", user, is_temporary) - infraction = await utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) if infraction is None: return @@ -309,11 +309,11 @@ class Infractions(InfractionScheduler, commands.Cog): await user.remove_roles(self._muted_role, reason=reason) # DM the user about the expiration. - notified = await utils.notify_pardon( + notified = await _utils.notify_pardon( user=user, title="You have been unmuted", content="You may now send messages in the server.", - icon_url=utils.INFRACTION_ICONS["mute"][1] + icon_url=_utils.INFRACTION_ICONS["mute"][1] ) log_text["Member"] = f"{user.mention}(`{user.id}`)" @@ -339,7 +339,7 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: """ Execute deactivation steps specific to the infraction's type and return a log dict. @@ -368,3 +368,8 @@ class Infractions(InfractionScheduler, commands.Cog): if discord.User in error.converters or discord.Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Infractions cog.""" + bot.add_cog(Infractions(bot)) diff --git a/bot/cogs/moderation/infraction/management.py b/bot/cogs/moderation/infraction/management.py index 791585b6e..9e7ae8113 100644 --- a/bot/cogs/moderation/infraction/management.py +++ b/bot/cogs/moderation/infraction/management.py @@ -14,7 +14,7 @@ from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_whitelist_check, with_role_check -from . import utils +from . import _utils from .infractions import Infractions log = logging.getLogger(__name__) @@ -220,7 +220,7 @@ class ModManagement(commands.Cog): self, ctx: Context, embed: discord.Embed, - infractions: t.Iterable[utils.Infraction] + infractions: t.Iterable[_utils.Infraction] ) -> None: """Send a paginated embed of infractions for the specified user.""" if not infractions: @@ -241,7 +241,7 @@ class ModManagement(commands.Cog): max_size=1000 ) - def infraction_to_string(self, infraction: utils.Infraction) -> str: + def infraction_to_string(self, infraction: _utils.Infraction) -> str: """Convert the infraction object to a string representation.""" actor_id = infraction["actor"] guild = self.bot.get_guild(constants.Guild.id) @@ -303,3 +303,8 @@ class ModManagement(commands.Cog): if discord.User in error.converters: await ctx.send(str(error.errors[0])) error.handled = True + + +def setup(bot: Bot) -> None: + """Load the ModManagement cog.""" + bot.add_cog(ModManagement(bot)) diff --git a/bot/cogs/moderation/infraction/scheduler.py b/bot/cogs/moderation/infraction/scheduler.py deleted file mode 100644 index b3d27fe76..000000000 --- a/bot/cogs/moderation/infraction/scheduler.py +++ /dev/null @@ -1,463 +0,0 @@ -import logging -import textwrap -import typing as t -from abc import abstractmethod -from datetime import datetime -from gettext import ngettext - -import dateutil.parser -import discord -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Colours, STAFF_CHANNELS -from bot.utils import time -from bot.utils.scheduling import Scheduler -from . import utils -from .utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class InfractionScheduler: - """Handles the application, pardoning, and expiration of infractions.""" - - def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - @property - def mod_log(self) -> ModLog: - """Get the currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: - """Schedule expiration for previous infractions.""" - await self.bot.wait_until_guild_available() - - log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") - - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={'active': 'true'} - ) - for infraction in infractions: - if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_expiration(infraction) - - async def reapply_infraction( - self, - infraction: utils.Infraction, - apply_coro: t.Optional[t.Awaitable] - ) -> None: - """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" - # Calculate the time remaining, in seconds, for the mute. - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - delta = (expiry - datetime.utcnow()).total_seconds() - - # Mark as inactive if less than a minute remains. - if delta < 60: - log.info( - "Infraction will be deactivated instead of re-applied " - "because less than 1 minute remains." - ) - await self.deactivate_infraction(infraction) - return - - # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") - - async def apply_infraction( - self, - ctx: Context, - infraction: utils.Infraction, - user: UserSnowflake, - action_coro: t.Optional[t.Awaitable] = None - ) -> None: - """Apply an infraction to the user, log the infraction, and optionally notify the user.""" - infr_type = infraction["type"] - icon = utils.INFRACTION_ICONS[infr_type][0] - reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) - id_ = infraction['id'] - - log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - - # Default values for the confirmation message and mod log. - confirm_msg = ":ok_hand: applied" - - # Specifying an expiry for a note or warning makes no sense. - if infr_type in ("note", "warning"): - expiry_msg = "" - else: - expiry_msg = f" until {expiry}" if expiry else " permanently" - - dm_result = "" - dm_log_text = "" - expiry_log_text = f"\nExpires: {expiry}" if expiry else "" - log_title = "applied" - log_content = None - failed = False - - # DM the user about the infraction if it's not a shadow/hidden infraction. - # This needs to happen before we apply the infraction, as the bot cannot - # send DMs to user that it doesn't share a guild with. If we were to - # apply kick/ban infractions first, this would mean that we'd make it - # impossible for us to deliver a DM. See python-discord/bot#982. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" - - end_msg = "" - if infraction["actor"] == self.bot.user.id: - log.trace( - f"Infraction #{id_} actor is bot; including the reason in the confirmation message." - ) - if reason: - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" - elif ctx.channel.id not in STAFF_CHANNELS: - log.trace( - f"Infraction #{id_} context is not in a staff channel; omitting infraction count." - ) - else: - log.trace(f"Fetching total infraction count for {user}.") - - infractions = await self.bot.api_client.get( - "bot/infractions", - params={"user__id": str(user.id)} - ) - total = len(infractions) - end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" - - # Execute the necessary actions to apply the infraction on Discord. - if action_coro: - log.trace(f"Awaiting the infraction #{id_} application action coroutine.") - try: - await action_coro - if expiry: - # Schedule the expiration of the infraction. - self.schedule_expiration(infraction) - except discord.HTTPException as e: - # Accordingly display that applying the infraction failed. - confirm_msg = ":x: failed to apply" - expiry_msg = "" - log_content = ctx.author.mention - log_title = "failed to apply" - - log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): - log.warning(f"{log_msg}: bot lacks permissions.") - else: - log.exception(log_msg) - failed = True - - if failed: - log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") - try: - await self.bot.api_client.delete(f"bot/infractions/{id_}") - except ResponseCodeError as e: - confirm_msg += " and failed to delete" - log_title += " and failed to delete" - log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") - infr_message = "" - else: - infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" - - # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") - - # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=icon, - colour=Colours.soft_red, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} - Reason: {reason} - """), - content=log_content, - footer=f"ID {infraction['id']}" - ) - - log.info(f"Applied {infr_type} infraction #{id_} to {user}.") - - async def pardon_infraction( - self, - ctx: Context, - infr_type: str, - user: UserSnowflake, - send_msg: bool = True - ) -> None: - """ - Prematurely end an infraction for a user and log the action in the mod log. - - If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. - """ - log.trace(f"Pardoning {infr_type} infraction for {user}.") - - # Check the current active infraction - log.trace(f"Fetching active {infr_type} infractions for {user}.") - response = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': user.id - } - ) - - if not response: - log.debug(f"No active {infr_type} infraction found for {user}.") - await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") - return - - # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["Actor"] = str(ctx.message.author) - log_content = None - id_ = response[0]['id'] - footer = f"ID: {id_}" - - # If multiple active infractions were found, mark them as inactive in the database - # and cancel their expiration tasks. - if len(response) > 1: - log.info( - f"Found more than one active {infr_type} infraction for user {user.id}; " - "deactivating the extra active infractions too." - ) - - footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" - - log_note = f"Found multiple **active** {infr_type} infractions in the database." - if "Note" in log_text: - log_text["Note"] = f" {log_note}" - else: - log_text["Note"] = log_note - - # deactivate_infraction() is not called again because: - # 1. Discord cannot store multiple active bans or assign multiples of the same role - # 2. It would send a pardon DM for each active infraction, which is redundant - for infraction in response[1:]: - id_ = infraction['id'] - try: - # Mark infraction as inactive in the database. - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") - # This is simpler and cleaner than trying to concatenate all the errors. - log_text["Failure"] = "See bot's logs for details." - - # Cancel pending expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Accordingly display whether the user was successfully notified via DM. - dm_emoji = "" - if log_text.get("DM") == "Sent": - dm_emoji = ":incoming_envelope: " - elif "DM" in log_text: - dm_emoji = f"{constants.Emojis.failmail} " - - # Accordingly display whether the pardon failed. - if "Failure" in log_text: - confirm_msg = ":x: failed to pardon" - log_title = "pardon failed" - log_content = ctx.author.mention - - log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") - else: - confirm_msg = ":ok_hand: pardoned" - log_title = "pardoned" - - log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") - - # Send a confirmation message to the invoking context. - if send_msg: - log.trace(f"Sending infraction #{id_} pardon confirmation message.") - await ctx.send( - f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " - f"{log_text.get('Failure', '')}" - ) - - # Move reason to end of entry to avoid cutting out some keys - log_text["Reason"] = log_text.pop("Reason") - - # Send a log message to the mod log. - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[infr_type][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=footer, - content=log_content, - ) - - async def deactivate_infraction( - self, - infraction: utils.Infraction, - send_log: bool = True - ) -> t.Dict[str, str]: - """ - Deactivate an active infraction and return a dictionary of lines to send in a mod log. - - The infraction is removed from Discord, marked as inactive in the database, and has its - expiration task cancelled. If `send_log` is True, a mod log is sent for the - deactivation of the infraction. - - Infractions of unsupported types will raise a ValueError. - """ - guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderators) - user_id = infraction["user"] - actor = infraction["actor"] - type_ = infraction["type"] - id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] - - log.info(f"Marking infraction #{id_} as inactive (expired).") - - expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) - - log_content = None - log_text = { - "Member": f"<@{user_id}>", - "Actor": str(self.bot.get_user(actor) or actor), - "Reason": infraction["reason"], - "Created": created, - } - - try: - log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) - - if returned_log is not None: - log_text = {**log_text, **returned_log} # Merge the logs together - else: - raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" - ) - except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") - log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" - log_content = mod_role.mention - except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." - log_content = mod_role.mention - - # Check if the user is currently being watched by Big Brother. - try: - log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") - - active_watch = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "watch", - "user__id": user_id - } - ) - - log_text["Watching"] = "Yes" if active_watch else "No" - except ResponseCodeError: - log.exception(f"Failed to fetch watch status for user {user_id}") - log_text["Watching"] = "Unknown - failed to fetch watch status." - - try: - # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{id_} as inactive in the database.") - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_line = f"API request failed with code {e.status}." - log_content = mod_role.mention - - # Append to an existing failure message if possible - if "Failure" in log_text: - log_text["Failure"] += f" {log_line}" - else: - log_text["Failure"] = log_line - - # Cancel the expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Send a log message to the mod log. - if send_log: - log_title = "expiration failed" if "Failure" in log_text else "expired" - - user = self.bot.get_user(user_id) - avatar = user.avatar_url_as(static_format="png") if user else None - - # Move reason to end so when reason is too long, this is not gonna cut out required items. - log_text["Reason"] = log_text.pop("Reason") - - log.trace(f"Sending deactivation mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[type_][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {type_}", - thumbnail=avatar, - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {id_}", - content=log_content, - ) - - return log_text - - @abstractmethod - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - raise NotImplementedError - - def schedule_expiration(self, infraction: utils.Infraction) -> None: - """ - Marks an infraction expired after the delay from time of scheduling to time of expiration. - - At the time of expiration, the infraction is marked as inactive on the website and the - expiration task is cancelled. - """ - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/superstarify.py b/bot/cogs/moderation/infraction/superstarify.py index 867de815a..7dc5b4691 100644 --- a/bot/cogs/moderation/infraction/superstarify.py +++ b/bot/cogs/moderation/infraction/superstarify.py @@ -13,8 +13,8 @@ from bot.bot import Bot from bot.converters import Expiry from bot.utils.checks import with_role_check from bot.utils.time import format_infraction -from . import utils -from .scheduler import InfractionScheduler +from . import _utils +from ._scheduler import InfractionScheduler log = logging.getLogger(__name__) NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" @@ -67,7 +67,7 @@ class Superstarify(InfractionScheduler, Cog): reason=f"Superstarified member tried to escape the prison: {infraction['id']}" ) - notified = await utils.notify_infraction( + notified = await _utils.notify_infraction( user=after, infr_type="Superstarify", expires_at=format_infraction(infraction["expires_at"]), @@ -76,7 +76,7 @@ class Superstarify(InfractionScheduler, Cog): f"from **{before.display_name}** to **{after.display_name}**, but as you " "are currently in superstar-prison, you do not have permission to do so." ), - icon_url=utils.INFRACTION_ICONS["superstar"][0] + icon_url=_utils.INFRACTION_ICONS["superstar"][0] ) if not notified: @@ -130,12 +130,12 @@ class Superstarify(InfractionScheduler, Cog): An optional reason can be provided. If no reason is given, the original name will be shown in a generated reason. """ - if await utils.get_active_infraction(ctx, member, "superstar"): + if await _utils.get_active_infraction(ctx, member, "superstar"): return # Post the infraction to the API reason = reason or f"old nick: {member.display_name}" - infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) + infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) id_ = infraction["id"] old_nick = member.display_name @@ -149,11 +149,11 @@ class Superstarify(InfractionScheduler, Cog): self.schedule_expiration(infraction) # Send a DM to the user to notify them of their new infraction. - await utils.notify_infraction( + await _utils.notify_infraction( user=member, infr_type="Superstarify", expires_at=expiry_str, - icon_url=utils.INFRACTION_ICONS["superstar"][0], + icon_url=_utils.INFRACTION_ICONS["superstar"][0], reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." ) @@ -176,7 +176,7 @@ class Superstarify(InfractionScheduler, Cog): # Log to the mod log channel. log.trace(f"Sending apply mod log for superstar #{id_}.") await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS["superstar"][0], + icon_url=_utils.INFRACTION_ICONS["superstar"][0], colour=Colour.gold(), title="Member achieved superstardom", thumbnail=member.avatar_url_as(static_format="png"), @@ -196,7 +196,7 @@ class Superstarify(InfractionScheduler, Cog): """Remove the superstarify infraction and allow the user to change their nickname.""" await self.pardon_infraction(ctx, "superstar", member) - async def _pardon_action(self, infraction: utils.Infraction) -> t.Optional[t.Dict[str, str]]: + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: """Pardon a superstar infraction and return a log dict.""" if infraction["type"] != "superstar": return @@ -213,11 +213,11 @@ class Superstarify(InfractionScheduler, Cog): return {} # DM the user about the expiration. - notified = await utils.notify_pardon( + notified = await _utils.notify_pardon( user=user, title="You are no longer superstarified", content="You may now change your nickname on the server.", - icon_url=utils.INFRACTION_ICONS["superstar"][1] + icon_url=_utils.INFRACTION_ICONS["superstar"][1] ) return { @@ -237,3 +237,8 @@ class Superstarify(InfractionScheduler, Cog): def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Superstarify cog.""" + bot.add_cog(Superstarify(bot)) diff --git a/bot/cogs/moderation/infraction/utils.py b/bot/cogs/moderation/infraction/utils.py deleted file mode 100644 index fb55287b6..000000000 --- a/bot/cogs/moderation/infraction/utils.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext.commands import Context - -from bot.api import ResponseCodeError -from bot.constants import Colours, Icons - -log = logging.getLogger(__name__) - -# apply icon, pardon icon -INFRACTION_ICONS = { - "ban": (Icons.user_ban, Icons.user_unban), - "kick": (Icons.sign_out, None), - "mute": (Icons.user_mute, Icons.user_unmute), - "note": (Icons.user_warn, None), - "superstar": (Icons.superstarify, Icons.unsuperstarify), - "warning": (Icons.user_warn, None), -} -RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") - -# Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object] -Infraction = t.Dict[str, t.Union[str, int, bool]] - - -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: - """ - Create a new user in the database. - - Used when an infraction needs to be applied on a user absent in the guild. - """ - log.trace(f"Attempting to add user {user.id} to the database.") - - if not isinstance(user, (discord.Member, discord.User)): - log.debug("The user being added to the DB is not a Member or User object.") - - payload = { - 'discriminator': int(getattr(user, 'discriminator', 0)), - 'id': user.id, - 'in_guild': False, - 'name': getattr(user, 'name', 'Name unknown'), - 'roles': [] - } - - try: - response = await ctx.bot.api_client.post('bot/users', json=payload) - log.info(f"User {user.id} added to the DB.") - return response - except ResponseCodeError as e: - log.error(f"Failed to add user {user.id} to the DB. {e}") - await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") - - -async def post_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - reason: str, - expires_at: datetime = None, - hidden: bool = False, - active: bool = True -) -> t.Optional[dict]: - """Posts an infraction to the API.""" - log.trace(f"Posting {infr_type} infraction for {user} to the API.") - - payload = { - "actor": ctx.message.author.id, - "hidden": hidden, - "reason": reason, - "type": infr_type, - "user": user.id, - "active": active - } - if expires_at: - payload['expires_at'] = expires_at.isoformat() - - # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. - for should_post_user in (True, False): - try: - response = await ctx.bot.api_client.post('bot/infractions', json=payload) - return response - except ResponseCodeError as e: - if e.status == 400 and 'user' in e.response_json: - # Only one attempt to add the user to the database, not two: - if not should_post_user or await post_user(ctx, user) is None: - return - else: - log.exception(f"Unexpected error while adding an infraction for {user}:") - await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") - return - - -async def get_active_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - send_msg: bool = True -) -> t.Optional[dict]: - """ - Retrieves an active infraction of the given type for the user. - - If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, - then a message for the moderator will be sent to the context channel letting them know. - Otherwise, no message will be sent. - """ - log.trace(f"Checking if {user} has active infractions of type {infr_type}.") - - active_infractions = await ctx.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': str(user.id) - } - ) - if active_infractions: - # Checks to see if the moderator should be told there is an active infraction - if send_msg: - log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) - return active_infractions[0] - else: - log.trace(f"{user} does not have active infractions of type {infr_type}.") - - -async def notify_infraction( - user: UserObject, - infr_type: str, - expires_at: t.Optional[str] = None, - reason: t.Optional[str] = None, - icon_url: str = Icons.token_removed -) -> bool: - """DM a user about their new infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) - - embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), - colour=Colours.soft_red - ) - - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" - embed.url = RULES_URL - - if infr_type in APPEALABLE_INFRACTIONS: - embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" - ) - - return await send_private_embed(user, embed) - - -async def notify_pardon( - user: UserObject, - title: str, - content: str, - icon_url: str = Icons.user_verified -) -> bool: - """DM a user about their pardoned infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their pardoned infraction.") - - embed = discord.Embed( - description=content, - colour=Colours.soft_green - ) - - embed.set_author(name=title, icon_url=icon_url) - - return await send_private_embed(user, embed) - - -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: - """ - A helper method for sending an embed to a user's DMs. - - Returns a boolean indicator of DM success. - """ - try: - await user.send(embed=embed) - return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): - log.debug( - f"Infraction-related information could not be sent to user {user} ({user.id}). " - "The user either could not be retrieved or probably disabled their DMs." - ) - return False diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 0a63f57b8..c86f04b9d 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -830,3 +830,8 @@ class ModLog(Cog, name="ModLog"): thumbnail=member.avatar_url_as(static_format="png"), channel_id=Channels.voice_log ) + + +def setup(bot: Bot) -> None: + """Load the ModLog cog.""" + bot.add_cog(ModLog(bot)) diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index f8a6592bc..4af87c724 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -163,3 +163,8 @@ class Silence(commands.Cog): def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" return with_role_check(ctx, *MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Silence cog.""" + bot.add_cog(Silence(bot)) diff --git a/bot/cogs/moderation/watchchannels/__init__.py b/bot/cogs/moderation/watchchannels/__init__.py index 69d118df6..e69de29bb 100644 --- a/bot/cogs/moderation/watchchannels/__init__.py +++ b/bot/cogs/moderation/watchchannels/__init__.py @@ -1,9 +0,0 @@ -from bot.bot import Bot -from .bigbrother import BigBrother -from .talentpool import TalentPool - - -def setup(bot: Bot) -> None: - """Load the BigBrother and TalentPool cogs.""" - bot.add_cog(BigBrother(bot)) - bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/moderation/watchchannels/_watchchannel.py b/bot/cogs/moderation/watchchannels/_watchchannel.py new file mode 100644 index 000000000..044077350 --- /dev/null +++ b/bot/cogs/moderation/watchchannels/_watchchannel.py @@ -0,0 +1,348 @@ +import asyncio +import logging +import re +import textwrap +from abc import abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Optional + +import dateutil.parser +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.cogs.moderation import ModLog +from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons +from bot.pagination import LinePaginator +from bot.utils import CogABCMeta, messages +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + +URL_RE = re.compile(r"(https?://[^\s]+)") + + +@dataclass +class MessageHistory: + """Represents a watch channel's message history.""" + + last_author: Optional[int] = None + last_channel: Optional[int] = None + message_count: int = 0 + + +class WatchChannel(metaclass=CogABCMeta): + """ABC with functionality for relaying users' messages to a certain channel.""" + + @abstractmethod + def __init__( + self, + bot: Bot, + destination: int, + webhook_id: int, + api_endpoint: str, + api_default_params: dict, + logger: logging.Logger + ) -> None: + self.bot = bot + + self.destination = destination # E.g., Channels.big_brother_logs + self.webhook_id = webhook_id # E.g., Webhooks.big_brother + self.api_endpoint = api_endpoint # E.g., 'bot/infractions' + self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} + self.log = logger # Logger of the child cog for a correct name in the logs + + self._consume_task = None + self.watched_users = defaultdict(dict) + self.message_queue = defaultdict(lambda: defaultdict(deque)) + self.consumption_queue = {} + self.retries = 5 + self.retry_delay = 10 + self.channel = None + self.webhook = None + self.message_history = MessageHistory() + + self._start = self.bot.loop.create_task(self.start_watchchannel()) + + @property + def modlog(self) -> ModLog: + """Provides access to the ModLog cog for alert purposes.""" + return self.bot.get_cog("ModLog") + + @property + def consuming_messages(self) -> bool: + """Checks if a consumption task is currently running.""" + if self._consume_task is None: + return False + + if self._consume_task.done(): + exc = self._consume_task.exception() + if exc: + self.log.exception( + "The message queue consume task has failed with:", + exc_info=exc + ) + return False + + return True + + async def start_watchchannel(self) -> None: + """Starts the watch channel by getting the channel, webhook, and user cache ready.""" + await self.bot.wait_until_guild_available() + + try: + self.channel = await self.bot.fetch_channel(self.destination) + except HTTPException: + self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + if self.channel is None or self.webhook is None: + self.log.error("Failed to start the watch channel; unloading the cog.") + + message = textwrap.dedent( + f""" + An error occurred while loading the text channel or webhook. + + TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} + Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} + + The Cog has been unloaded. + """ + ) + + await self.modlog.send_log_message( + title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", + text=message, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + if not await self.fetch_user_cache(): + await self.modlog.send_log_message( + title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", + text="Could not retrieve the list of watched users from the API and messages will not be relayed.", + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + async def fetch_user_cache(self) -> bool: + """ + Fetches watched users from the API and updates the watched user cache accordingly. + + This function returns `True` if the update succeeded. + """ + try: + data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) + except ResponseCodeError as err: + self.log.exception("Failed to fetch the watched users from the API", exc_info=err) + return False + + self.watched_users = defaultdict(dict) + + for entry in data: + user_id = entry.pop('user') + self.watched_users[user_id] = entry + + return True + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Queues up messages sent by watched users.""" + if msg.author.id in self.watched_users: + if not self.consuming_messages: + self._consume_task = self.bot.loop.create_task(self.consume_messages()) + + self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") + self.message_queue[msg.author.id][msg.channel.id].append(msg) + + async def consume_messages(self, delay_consumption: bool = True) -> None: + """Consumes the message queues to log watched users' messages.""" + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) + + self.log.trace("Started consuming the message queue") + + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() + + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() + + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) + + self.consumption_queue.clear() + + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + + async def webhook_send( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Sends a message to the webhook with the specified kwargs.""" + username = messages.sub_clyde(username) + try: + await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send a message to the webhook", + exc_info=exc + ) + + async def relay_message(self, msg: Message) -> None: + """Relays the message to the relevant watch channel.""" + limit = BigBrotherConfig.header_message_limit + + if ( + msg.author.id != self.message_history.last_author + or msg.channel.id != self.message_history.last_channel + or self.message_history.message_count >= limit + ): + self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) + + await self.send_header(msg) + + cleaned_content = msg.clean_content + + if cleaned_content: + # Put all non-media URLs in a code block to prevent embeds + media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} + for url in URL_RE.findall(cleaned_content): + if url not in media_urls: + cleaned_content = cleaned_content.replace(url, f"`{url}`") + await self.webhook_send( + cleaned_content, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + + if msg.attachments: + try: + await messages.send_attachments(msg, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.webhook_send( + embed=e, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send an attachment to the webhook", + exc_info=exc + ) + + self.message_history.message_count += 1 + + async def send_header(self, msg: Message) -> None: + """Sends a header embed with information about the relayed messages to the watch channel.""" + user_id = msg.author.id + + guild = self.bot.get_guild(GuildConfig.id) + actor = guild.get_member(self.watched_users[user_id]['actor']) + actor = actor.display_name if actor else self.watched_users[user_id]['actor'] + + inserted_at = self.watched_users[user_id]['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + + reason = self.watched_users[user_id]['reason'] + + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + footer = f"Added {time_delta} by {actor} | Reason: {reason}" + embed = Embed(description=f"{msg.author.mention} {message_jump}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) + + await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) + + async def list_watched_users( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Gives an overview of the watched user list for this channel. + + The optional kwarg `oldest_first` orders the list by oldest entry. + + The optional kwarg `update_cache` specifies whether the cache should + be refreshed by polling the API. + """ + if update_cache: + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") + update_cache = False + + lines = [] + for user_id, user_data in self.watched_users.items(): + inserted_at = user_data['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + lines.append(f"• <@{user_id}> (added {time_delta})") + + if oldest_first: + lines.reverse() + + lines = lines or ("There's nothing here yet.",) + + embed = Embed( + title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", + color=Color.blue() + ) + await LinePaginator.paginate(lines, ctx, embed, empty=False) + + @staticmethod + def _get_time_delta(time_string: str) -> str: + """Returns the time in human-readable time delta format.""" + date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) + time_delta = time_since(date_time, precision="minutes", max_units=1) + + return time_delta + + def _remove_user(self, user_id: int) -> None: + """Removes a user from a watch channel.""" + self.watched_users.pop(user_id, None) + self.message_queue.pop(user_id, None) + self.consumption_queue.pop(user_id, None) + + def cog_unload(self) -> None: + """Takes care of unloading the cog and canceling the consumption task.""" + self.log.trace("Unloading the cog") + if self._consume_task and not self._consume_task.done(): + self._consume_task.cancel() + try: + self._consume_task.result() + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) diff --git a/bot/cogs/moderation/watchchannels/bigbrother.py b/bot/cogs/moderation/watchchannels/bigbrother.py index 0c72e88f7..7db34bcf2 100644 --- a/bot/cogs/moderation/watchchannels/bigbrother.py +++ b/bot/cogs/moderation/watchchannels/bigbrother.py @@ -5,11 +5,11 @@ from collections import ChainMap from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.cogs.moderation.infraction.utils import post_infraction +from bot.cogs.moderation.infraction._utils import post_infraction from bot.constants import Channels, MODERATION_ROLES, Webhooks from bot.converters import FetchedMember from bot.decorators import with_role -from .watchchannel import WatchChannel +from ._watchchannel import WatchChannel log = logging.getLogger(__name__) @@ -163,3 +163,8 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): message = ":x: The specified user is currently not being watched." await ctx.send(message) + + +def setup(bot: Bot) -> None: + """Load the BigBrother cog.""" + bot.add_cog(BigBrother(bot)) diff --git a/bot/cogs/moderation/watchchannels/talentpool.py b/bot/cogs/moderation/watchchannels/talentpool.py index 89256e92e..2972f56e1 100644 --- a/bot/cogs/moderation/watchchannels/talentpool.py +++ b/bot/cogs/moderation/watchchannels/talentpool.py @@ -12,7 +12,7 @@ from bot.converters import FetchedMember from bot.decorators import with_role from bot.pagination import LinePaginator from bot.utils import time -from .watchchannel import WatchChannel +from ._watchchannel import WatchChannel log = logging.getLogger(__name__) @@ -262,3 +262,8 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): ) return lines.strip() + + +def setup(bot: Bot) -> None: + """Load the TalentPool cog.""" + bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/moderation/watchchannels/watchchannel.py b/bot/cogs/moderation/watchchannels/watchchannel.py deleted file mode 100644 index 044077350..000000000 --- a/bot/cogs/moderation/watchchannels/watchchannel.py +++ /dev/null @@ -1,348 +0,0 @@ -import asyncio -import logging -import re -import textwrap -from abc import abstractmethod -from collections import defaultdict, deque -from dataclasses import dataclass -from typing import Optional - -import dateutil.parser -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation import ModLog -from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - -URL_RE = re.compile(r"(https?://[^\s]+)") - - -@dataclass -class MessageHistory: - """Represents a watch channel's message history.""" - - last_author: Optional[int] = None - last_channel: Optional[int] = None - message_count: int = 0 - - -class WatchChannel(metaclass=CogABCMeta): - """ABC with functionality for relaying users' messages to a certain channel.""" - - @abstractmethod - def __init__( - self, - bot: Bot, - destination: int, - webhook_id: int, - api_endpoint: str, - api_default_params: dict, - logger: logging.Logger - ) -> None: - self.bot = bot - - self.destination = destination # E.g., Channels.big_brother_logs - self.webhook_id = webhook_id # E.g., Webhooks.big_brother - self.api_endpoint = api_endpoint # E.g., 'bot/infractions' - self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} - self.log = logger # Logger of the child cog for a correct name in the logs - - self._consume_task = None - self.watched_users = defaultdict(dict) - self.message_queue = defaultdict(lambda: defaultdict(deque)) - self.consumption_queue = {} - self.retries = 5 - self.retry_delay = 10 - self.channel = None - self.webhook = None - self.message_history = MessageHistory() - - self._start = self.bot.loop.create_task(self.start_watchchannel()) - - @property - def modlog(self) -> ModLog: - """Provides access to the ModLog cog for alert purposes.""" - return self.bot.get_cog("ModLog") - - @property - def consuming_messages(self) -> bool: - """Checks if a consumption task is currently running.""" - if self._consume_task is None: - return False - - if self._consume_task.done(): - exc = self._consume_task.exception() - if exc: - self.log.exception( - "The message queue consume task has failed with:", - exc_info=exc - ) - return False - - return True - - async def start_watchchannel(self) -> None: - """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_guild_available() - - try: - self.channel = await self.bot.fetch_channel(self.destination) - except HTTPException: - self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - if self.channel is None or self.webhook is None: - self.log.error("Failed to start the watch channel; unloading the cog.") - - message = textwrap.dedent( - f""" - An error occurred while loading the text channel or webhook. - - TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} - Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} - - The Cog has been unloaded. - """ - ) - - await self.modlog.send_log_message( - title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", - text=message, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - if not await self.fetch_user_cache(): - await self.modlog.send_log_message( - title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", - text="Could not retrieve the list of watched users from the API and messages will not be relayed.", - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - async def fetch_user_cache(self) -> bool: - """ - Fetches watched users from the API and updates the watched user cache accordingly. - - This function returns `True` if the update succeeded. - """ - try: - data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except ResponseCodeError as err: - self.log.exception("Failed to fetch the watched users from the API", exc_info=err) - return False - - self.watched_users = defaultdict(dict) - - for entry in data: - user_id = entry.pop('user') - self.watched_users[user_id] = entry - - return True - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Queues up messages sent by watched users.""" - if msg.author.id in self.watched_users: - if not self.consuming_messages: - self._consume_task = self.bot.loop.create_task(self.consume_messages()) - - self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") - self.message_queue[msg.author.id][msg.channel.id].append(msg) - - async def consume_messages(self, delay_consumption: bool = True) -> None: - """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) - - self.log.trace("Started consuming the message queue") - - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() - - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() - - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) - - self.consumption_queue.clear() - - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") - - async def webhook_send( - self, - content: Optional[str] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - embed: Optional[Embed] = None, - ) -> None: - """Sends a message to the webhook with the specified kwargs.""" - username = messages.sub_clyde(username) - try: - await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send a message to the webhook", - exc_info=exc - ) - - async def relay_message(self, msg: Message) -> None: - """Relays the message to the relevant watch channel.""" - limit = BigBrotherConfig.header_message_limit - - if ( - msg.author.id != self.message_history.last_author - or msg.channel.id != self.message_history.last_channel - or self.message_history.message_count >= limit - ): - self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) - - await self.send_header(msg) - - cleaned_content = msg.clean_content - - if cleaned_content: - # Put all non-media URLs in a code block to prevent embeds - media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} - for url in URL_RE.findall(cleaned_content): - if url not in media_urls: - cleaned_content = cleaned_content.replace(url, f"`{url}`") - await self.webhook_send( - cleaned_content, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - - if msg.attachments: - try: - await messages.send_attachments(msg, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.webhook_send( - embed=e, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send an attachment to the webhook", - exc_info=exc - ) - - self.message_history.message_count += 1 - - async def send_header(self, msg: Message) -> None: - """Sends a header embed with information about the relayed messages to the watch channel.""" - user_id = msg.author.id - - guild = self.bot.get_guild(GuildConfig.id) - actor = guild.get_member(self.watched_users[user_id]['actor']) - actor = actor.display_name if actor else self.watched_users[user_id]['actor'] - - inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - - reason = self.watched_users[user_id]['reason'] - - if isinstance(msg.channel, DMChannel): - # If a watched user DMs the bot there won't be a channel name or jump URL - # This could technically include a GroupChannel but bot's can't be in those - message_jump = "via DM" - else: - message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" - - footer = f"Added {time_delta} by {actor} | Reason: {reason}" - embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) - - await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) - - async def list_watched_users( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Gives an overview of the watched user list for this channel. - - The optional kwarg `oldest_first` orders the list by oldest entry. - - The optional kwarg `update_cache` specifies whether the cache should - be refreshed by polling the API. - """ - if update_cache: - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") - update_cache = False - - lines = [] - for user_id, user_data in self.watched_users.items(): - inserted_at = user_data['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - lines.append(f"• <@{user_id}> (added {time_delta})") - - if oldest_first: - lines.reverse() - - lines = lines or ("There's nothing here yet.",) - - embed = Embed( - title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", - color=Color.blue() - ) - await LinePaginator.paginate(lines, ctx, embed, empty=False) - - @staticmethod - def _get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) - time_delta = time_since(date_time, precision="minutes", max_units=1) - - return time_delta - - def _remove_user(self, user_id: int) -> None: - """Removes a user from a watch channel.""" - self.watched_users.pop(user_id, None) - self.message_queue.pop(user_id, None) - self.consumption_queue.pop(user_id, None) - - def cog_unload(self) -> None: - """Takes care of unloading the cog and canceling the consumption task.""" - self.log.trace("Unloading the cog") - if self._consume_task and not self._consume_task.done(): - self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) diff --git a/tests/bot/cogs/backend/sync/test_base.py b/tests/bot/cogs/backend/sync/test_base.py index 0d0a8299d..3009aacb6 100644 --- a/tests/bot/cogs/backend/sync/test_base.py +++ b/tests/bot/cogs/backend/sync/test_base.py @@ -6,7 +6,7 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs.backend.sync.syncers import Syncer, _Diff +from bot.cogs.backend.sync._syncers import Syncer, _Diff from tests import helpers diff --git a/tests/bot/cogs/backend/sync/test_cog.py b/tests/bot/cogs/backend/sync/test_cog.py index 199747051..e40552817 100644 --- a/tests/bot/cogs/backend/sync/test_cog.py +++ b/tests/bot/cogs/backend/sync/test_cog.py @@ -6,7 +6,8 @@ import discord from bot import constants from bot.api import ResponseCodeError from bot.cogs.backend import sync -from bot.cogs.backend.sync.syncers import Syncer +from bot.cogs.backend.sync._cog import Sync +from bot.cogs.backend.sync._syncers import Syncer from tests import helpers from tests.base import CommandTestCase @@ -29,19 +30,19 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): self.bot = helpers.MockBot() self.role_syncer_patcher = mock.patch( - "bot.cogs.backend.sync.syncers.RoleSyncer", + "bot.cogs.backend.sync._syncers.RoleSyncer", autospec=Syncer, spec_set=True ) self.user_syncer_patcher = mock.patch( - "bot.cogs.backend.sync.syncers.UserSyncer", + "bot.cogs.backend.sync._syncers.UserSyncer", autospec=Syncer, spec_set=True ) self.RoleSyncer = self.role_syncer_patcher.start() self.UserSyncer = self.user_syncer_patcher.start() - self.cog = sync.Sync(self.bot) + self.cog = Sync(self.bot) def tearDown(self): self.role_syncer_patcher.stop() @@ -59,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) + @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild): """Should instantiate syncers and run a sync for the guild.""" # Reset because a Sync cog was already instantiated in setUp. @@ -70,7 +71,7 @@ class SyncCogTests(SyncCogTestCase): mock_sync_guild_coro = mock.MagicMock() sync_guild.return_value = mock_sync_guild_coro - sync.Sync(self.bot) + Sync(self.bot) self.RoleSyncer.assert_called_once_with(self.bot) self.UserSyncer.assert_called_once_with(self.bot) @@ -131,7 +132,7 @@ class SyncCogListenerTests(SyncCogTestCase): super().setUp() self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - self.guild_id_patcher = mock.patch("bot.cogs.backend.sync.cog.constants.Guild.id", 5) + self.guild_id_patcher = mock.patch("bot.cogs.backend.sync._cog.constants.Guild.id", 5) self.guild_id = self.guild_id_patcher.start() self.guild = helpers.MockGuild(id=self.guild_id) diff --git a/tests/bot/cogs/backend/sync/test_roles.py b/tests/bot/cogs/backend/sync/test_roles.py index cc2e51c7f..99d682ede 100644 --- a/tests/bot/cogs/backend/sync/test_roles.py +++ b/tests/bot/cogs/backend/sync/test_roles.py @@ -3,7 +3,7 @@ from unittest import mock import discord -from bot.cogs.backend.sync.syncers import RoleSyncer, _Diff, _Role +from bot.cogs.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers diff --git a/tests/bot/cogs/backend/sync/test_users.py b/tests/bot/cogs/backend/sync/test_users.py index 490ea9e06..51dcbe48a 100644 --- a/tests/bot/cogs/backend/sync/test_users.py +++ b/tests/bot/cogs/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from bot.cogs.backend.sync.syncers import UserSyncer, _Diff, _User +from bot.cogs.backend.sync._syncers import UserSyncer, _Diff, _User from tests import helpers diff --git a/tests/bot/cogs/moderation/infraction/test_infractions.py b/tests/bot/cogs/moderation/infraction/test_infractions.py index a79042557..2df61d431 100644 --- a/tests/bot/cogs/moderation/infraction/test_infractions.py +++ b/tests/bot/cogs/moderation/infraction/test_infractions.py @@ -17,8 +17,8 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.guild = MockGuild(id=4567) self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - @patch("bot.cogs.moderation.infraction.utils.get_active_infraction") - @patch("bot.cogs.moderation.infraction.utils.post_infraction") + @patch("bot.cogs.moderation.infraction._utils.get_active_infraction") + @patch("bot.cogs.moderation.infraction._utils.post_infraction") async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" get_active_mock.return_value = None @@ -39,7 +39,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value ) - @patch("bot.cogs.moderation.infraction.utils.post_infraction") + @patch("bot.cogs.moderation.infraction._utils.post_infraction") async def test_apply_kick_reason_truncation(self, post_infraction_mock): """Should truncate reason for `Member.kick`.""" post_infraction_mock.return_value = {"foo": "bar"} -- cgit v1.2.3 From aaee0f86e99f8dfdc454c52516fbdf7f0030168a Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 12 Aug 2020 23:07:30 -0700 Subject: Fix ModLog imports Bunch of modules still rely on importing the cog directly from the moderation package. --- bot/cogs/filters/antispam.py | 2 +- bot/cogs/filters/filtering.py | 2 +- bot/cogs/filters/token_remover.py | 2 +- bot/cogs/moderation/defcon.py | 2 +- bot/cogs/moderation/verification.py | 2 +- bot/cogs/moderation/watchchannels/_watchchannel.py | 2 +- bot/cogs/utils/clean.py | 2 +- tests/bot/cogs/filters/test_token_remover.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/bot/cogs/filters/antispam.py b/bot/cogs/filters/antispam.py index 0bcca578d..d2dccea06 100644 --- a/bot/cogs/filters/antispam.py +++ b/bot/cogs/filters/antispam.py @@ -11,7 +11,7 @@ from discord.ext.commands import Cog from bot import rules from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import ( AntiSpam as AntiSpamConfig, Channels, Colours, DEBUG_MODE, Event, Filter, diff --git a/bot/cogs/filters/filtering.py b/bot/cogs/filters/filtering.py index 93cc1c655..556b466ef 100644 --- a/bot/cogs/filters/filtering.py +++ b/bot/cogs/filters/filtering.py @@ -12,7 +12,7 @@ from discord.ext.commands import Cog from discord.utils import escape_markdown from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import ( Channels, Colours, Filter, Icons, URLs diff --git a/bot/cogs/filters/token_remover.py b/bot/cogs/filters/token_remover.py index ef979f222..8eace07b6 100644 --- a/bot/cogs/filters/token_remover.py +++ b/bot/cogs/filters/token_remover.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog from bot import utils from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import Channels, Colours, Event, Icons log = logging.getLogger(__name__) diff --git a/bot/cogs/moderation/defcon.py b/bot/cogs/moderation/defcon.py index 4c0ad5914..e78435a7d 100644 --- a/bot/cogs/moderation/defcon.py +++ b/bot/cogs/moderation/defcon.py @@ -9,7 +9,7 @@ from discord import Colour, Embed, Member from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles from bot.decorators import with_role diff --git a/bot/cogs/moderation/verification.py b/bot/cogs/moderation/verification.py index ae156cf70..ba95ab5e4 100644 --- a/bot/cogs/moderation/verification.py +++ b/bot/cogs/moderation/verification.py @@ -6,7 +6,7 @@ from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.decorators import in_whitelist, without_role from bot.utils.checks import InWhitelistCheckFailure, without_role_check diff --git a/bot/cogs/moderation/watchchannels/_watchchannel.py b/bot/cogs/moderation/watchchannels/_watchchannel.py index 044077350..488ae704d 100644 --- a/bot/cogs/moderation/watchchannels/_watchchannel.py +++ b/bot/cogs/moderation/watchchannels/_watchchannel.py @@ -14,7 +14,7 @@ from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.pagination import LinePaginator from bot.utils import CogABCMeta, messages diff --git a/bot/cogs/utils/clean.py b/bot/cogs/utils/clean.py index f436e531a..c156ff02e 100644 --- a/bot/cogs/utils/clean.py +++ b/bot/cogs/utils/clean.py @@ -8,7 +8,7 @@ from discord.ext import commands from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from bot.constants import ( Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES ) diff --git a/tests/bot/cogs/filters/test_token_remover.py b/tests/bot/cogs/filters/test_token_remover.py index 5c527ed94..55b284ef9 100644 --- a/tests/bot/cogs/filters/test_token_remover.py +++ b/tests/bot/cogs/filters/test_token_remover.py @@ -8,7 +8,7 @@ from discord import Colour, NotFound from bot import constants from bot.cogs.filters import token_remover from bot.cogs.filters.token_remover import Token, TokenRemover -from bot.cogs.moderation import ModLog +from bot.cogs.moderation.modlog import ModLog from tests.helpers import MockBot, MockMessage, autospec -- cgit v1.2.3 From 1c2b384915f4a7ba070c95c86126746bae2f7279 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 14 Aug 2020 09:59:56 -0700 Subject: Rename "cogs" directory to "exts" The directory contains modules, which are extensions. It only indirectly contains cogs through the extensions. Therefore, a technically more accurate name is "extensions", or "exts" when abbreviated. Furthermore, "exts" is consistent with SeasonalBot. --- bot/__main__.py | 90 +- bot/cogs/__init__.py | 0 bot/cogs/alias.py | 153 ---- bot/cogs/backend/__init__.py | 0 bot/cogs/backend/config_verifier.py | 40 - bot/cogs/backend/error_handler.py | 287 ------- bot/cogs/backend/logging.py | 42 - bot/cogs/backend/sync/__init__.py | 7 - bot/cogs/backend/sync/_cog.py | 180 ---- bot/cogs/backend/sync/_syncers.py | 347 -------- bot/cogs/dm_relay.py | 124 --- bot/cogs/duck_pond.py | 166 ---- bot/cogs/filters/__init__.py | 0 bot/cogs/filters/antimalware.py | 98 --- bot/cogs/filters/antispam.py | 288 ------- bot/cogs/filters/filter_lists.py | 273 ------ bot/cogs/filters/filtering.py | 575 ------------- bot/cogs/filters/security.py | 31 - bot/cogs/filters/token_remover.py | 182 ---- bot/cogs/filters/webhook_remover.py | 84 -- bot/cogs/help_channels.py | 944 --------------------- bot/cogs/info/__init__.py | 0 bot/cogs/info/doc.py | 511 ----------- bot/cogs/info/help.py | 375 -------- bot/cogs/info/information.py | 422 --------- bot/cogs/info/python_news.py | 232 ----- bot/cogs/info/reddit.py | 304 ------- bot/cogs/info/site.py | 146 ---- bot/cogs/info/source.py | 141 --- bot/cogs/info/stats.py | 129 --- bot/cogs/info/tags.py | 277 ------ bot/cogs/info/wolfram.py | 280 ------ bot/cogs/moderation/__init__.py | 0 bot/cogs/moderation/defcon.py | 258 ------ bot/cogs/moderation/incidents.py | 412 --------- bot/cogs/moderation/infraction/__init__.py | 0 bot/cogs/moderation/infraction/_scheduler.py | 463 ---------- bot/cogs/moderation/infraction/_utils.py | 201 ----- bot/cogs/moderation/infraction/infractions.py | 375 -------- bot/cogs/moderation/infraction/management.py | 310 ------- bot/cogs/moderation/infraction/superstarify.py | 244 ------ bot/cogs/moderation/modlog.py | 837 ------------------ bot/cogs/moderation/silence.py | 170 ---- bot/cogs/moderation/slowmode.py | 97 --- bot/cogs/moderation/verification.py | 191 ----- bot/cogs/moderation/watchchannels/__init__.py | 0 bot/cogs/moderation/watchchannels/_watchchannel.py | 348 -------- bot/cogs/moderation/watchchannels/bigbrother.py | 170 ---- bot/cogs/moderation/watchchannels/talentpool.py | 269 ------ bot/cogs/off_topic_names.py | 162 ---- bot/cogs/utils/__init__.py | 0 bot/cogs/utils/bot.py | 385 --------- bot/cogs/utils/clean.py | 272 ------ bot/cogs/utils/eval.py | 202 ----- bot/cogs/utils/extensions.py | 289 ------- bot/cogs/utils/jams.py | 150 ---- bot/cogs/utils/reminders.py | 427 ---------- bot/cogs/utils/snekbox.py | 349 -------- bot/cogs/utils/utils.py | 265 ------ bot/exts/__init__.py | 0 bot/exts/alias.py | 153 ++++ bot/exts/backend/__init__.py | 0 bot/exts/backend/config_verifier.py | 40 + bot/exts/backend/error_handler.py | 287 +++++++ bot/exts/backend/logging.py | 42 + bot/exts/backend/sync/__init__.py | 7 + bot/exts/backend/sync/_cog.py | 180 ++++ bot/exts/backend/sync/_syncers.py | 347 ++++++++ bot/exts/dm_relay.py | 124 +++ bot/exts/duck_pond.py | 166 ++++ bot/exts/filters/__init__.py | 0 bot/exts/filters/antimalware.py | 98 +++ bot/exts/filters/antispam.py | 288 +++++++ bot/exts/filters/filter_lists.py | 273 ++++++ bot/exts/filters/filtering.py | 575 +++++++++++++ bot/exts/filters/security.py | 31 + bot/exts/filters/token_remover.py | 182 ++++ bot/exts/filters/webhook_remover.py | 84 ++ bot/exts/help_channels.py | 944 +++++++++++++++++++++ bot/exts/info/__init__.py | 0 bot/exts/info/doc.py | 511 +++++++++++ bot/exts/info/help.py | 375 ++++++++ bot/exts/info/information.py | 422 +++++++++ bot/exts/info/python_news.py | 232 +++++ bot/exts/info/reddit.py | 304 +++++++ bot/exts/info/site.py | 146 ++++ bot/exts/info/source.py | 141 +++ bot/exts/info/stats.py | 129 +++ bot/exts/info/tags.py | 277 ++++++ bot/exts/info/wolfram.py | 280 ++++++ bot/exts/moderation/__init__.py | 0 bot/exts/moderation/defcon.py | 258 ++++++ bot/exts/moderation/incidents.py | 412 +++++++++ bot/exts/moderation/infraction/__init__.py | 0 bot/exts/moderation/infraction/_scheduler.py | 463 ++++++++++ bot/exts/moderation/infraction/_utils.py | 201 +++++ bot/exts/moderation/infraction/infractions.py | 375 ++++++++ bot/exts/moderation/infraction/management.py | 310 +++++++ bot/exts/moderation/infraction/superstarify.py | 244 ++++++ bot/exts/moderation/modlog.py | 837 ++++++++++++++++++ bot/exts/moderation/silence.py | 170 ++++ bot/exts/moderation/slowmode.py | 97 +++ bot/exts/moderation/verification.py | 191 +++++ bot/exts/moderation/watchchannels/__init__.py | 0 bot/exts/moderation/watchchannels/_watchchannel.py | 348 ++++++++ bot/exts/moderation/watchchannels/bigbrother.py | 170 ++++ bot/exts/moderation/watchchannels/talentpool.py | 269 ++++++ bot/exts/off_topic_names.py | 162 ++++ bot/exts/utils/__init__.py | 0 bot/exts/utils/bot.py | 385 +++++++++ bot/exts/utils/clean.py | 272 ++++++ bot/exts/utils/eval.py | 202 +++++ bot/exts/utils/extensions.py | 289 +++++++ bot/exts/utils/jams.py | 150 ++++ bot/exts/utils/reminders.py | 427 ++++++++++ bot/exts/utils/snekbox.py | 349 ++++++++ bot/exts/utils/utils.py | 265 ++++++ tests/bot/cogs/__init__.py | 0 tests/bot/cogs/backend/__init__.py | 0 tests/bot/cogs/backend/sync/__init__.py | 0 tests/bot/cogs/backend/sync/test_base.py | 404 --------- tests/bot/cogs/backend/sync/test_cog.py | 416 --------- tests/bot/cogs/backend/sync/test_roles.py | 157 ---- tests/bot/cogs/backend/sync/test_users.py | 158 ---- tests/bot/cogs/backend/test_logging.py | 32 - tests/bot/cogs/filters/__init__.py | 0 tests/bot/cogs/filters/test_antimalware.py | 165 ---- tests/bot/cogs/filters/test_antispam.py | 35 - tests/bot/cogs/filters/test_security.py | 54 -- tests/bot/cogs/filters/test_token_remover.py | 310 ------- tests/bot/cogs/info/__init__.py | 0 tests/bot/cogs/info/test_information.py | 584 ------------- tests/bot/cogs/moderation/__init__.py | 0 tests/bot/cogs/moderation/infraction/__init__.py | 0 .../cogs/moderation/infraction/test_infractions.py | 55 -- tests/bot/cogs/moderation/test_incidents.py | 770 ----------------- tests/bot/cogs/moderation/test_modlog.py | 29 - tests/bot/cogs/moderation/test_silence.py | 261 ------ tests/bot/cogs/moderation/test_slowmode.py | 111 --- tests/bot/cogs/test_cogs.py | 80 -- tests/bot/cogs/test_duck_pond.py | 548 ------------ tests/bot/cogs/utils/__init__.py | 0 tests/bot/cogs/utils/test_jams.py | 173 ---- tests/bot/cogs/utils/test_snekbox.py | 409 --------- tests/bot/exts/__init__.py | 0 tests/bot/exts/backend/__init__.py | 0 tests/bot/exts/backend/sync/__init__.py | 0 tests/bot/exts/backend/sync/test_base.py | 404 +++++++++ tests/bot/exts/backend/sync/test_cog.py | 416 +++++++++ tests/bot/exts/backend/sync/test_roles.py | 157 ++++ tests/bot/exts/backend/sync/test_users.py | 158 ++++ tests/bot/exts/backend/test_logging.py | 32 + tests/bot/exts/filters/__init__.py | 0 tests/bot/exts/filters/test_antimalware.py | 165 ++++ tests/bot/exts/filters/test_antispam.py | 35 + tests/bot/exts/filters/test_security.py | 54 ++ tests/bot/exts/filters/test_token_remover.py | 310 +++++++ tests/bot/exts/info/__init__.py | 0 tests/bot/exts/info/test_information.py | 584 +++++++++++++ tests/bot/exts/moderation/__init__.py | 0 tests/bot/exts/moderation/infraction/__init__.py | 0 .../exts/moderation/infraction/test_infractions.py | 55 ++ tests/bot/exts/moderation/test_incidents.py | 770 +++++++++++++++++ tests/bot/exts/moderation/test_modlog.py | 29 + tests/bot/exts/moderation/test_silence.py | 261 ++++++ tests/bot/exts/moderation/test_slowmode.py | 111 +++ tests/bot/exts/test_cogs.py | 81 ++ tests/bot/exts/test_duck_pond.py | 548 ++++++++++++ tests/bot/exts/utils/__init__.py | 0 tests/bot/exts/utils/test_jams.py | 173 ++++ tests/bot/exts/utils/test_snekbox.py | 409 +++++++++ 171 files changed, 18281 insertions(+), 18280 deletions(-) delete mode 100644 bot/cogs/__init__.py delete mode 100644 bot/cogs/alias.py delete mode 100644 bot/cogs/backend/__init__.py delete mode 100644 bot/cogs/backend/config_verifier.py delete mode 100644 bot/cogs/backend/error_handler.py delete mode 100644 bot/cogs/backend/logging.py delete mode 100644 bot/cogs/backend/sync/__init__.py delete mode 100644 bot/cogs/backend/sync/_cog.py delete mode 100644 bot/cogs/backend/sync/_syncers.py delete mode 100644 bot/cogs/dm_relay.py delete mode 100644 bot/cogs/duck_pond.py delete mode 100644 bot/cogs/filters/__init__.py delete mode 100644 bot/cogs/filters/antimalware.py delete mode 100644 bot/cogs/filters/antispam.py delete mode 100644 bot/cogs/filters/filter_lists.py delete mode 100644 bot/cogs/filters/filtering.py delete mode 100644 bot/cogs/filters/security.py delete mode 100644 bot/cogs/filters/token_remover.py delete mode 100644 bot/cogs/filters/webhook_remover.py delete mode 100644 bot/cogs/help_channels.py delete mode 100644 bot/cogs/info/__init__.py delete mode 100644 bot/cogs/info/doc.py delete mode 100644 bot/cogs/info/help.py delete mode 100644 bot/cogs/info/information.py delete mode 100644 bot/cogs/info/python_news.py delete mode 100644 bot/cogs/info/reddit.py delete mode 100644 bot/cogs/info/site.py delete mode 100644 bot/cogs/info/source.py delete mode 100644 bot/cogs/info/stats.py delete mode 100644 bot/cogs/info/tags.py delete mode 100644 bot/cogs/info/wolfram.py delete mode 100644 bot/cogs/moderation/__init__.py delete mode 100644 bot/cogs/moderation/defcon.py delete mode 100644 bot/cogs/moderation/incidents.py delete mode 100644 bot/cogs/moderation/infraction/__init__.py delete mode 100644 bot/cogs/moderation/infraction/_scheduler.py delete mode 100644 bot/cogs/moderation/infraction/_utils.py delete mode 100644 bot/cogs/moderation/infraction/infractions.py delete mode 100644 bot/cogs/moderation/infraction/management.py delete mode 100644 bot/cogs/moderation/infraction/superstarify.py delete mode 100644 bot/cogs/moderation/modlog.py delete mode 100644 bot/cogs/moderation/silence.py delete mode 100644 bot/cogs/moderation/slowmode.py delete mode 100644 bot/cogs/moderation/verification.py delete mode 100644 bot/cogs/moderation/watchchannels/__init__.py delete mode 100644 bot/cogs/moderation/watchchannels/_watchchannel.py delete mode 100644 bot/cogs/moderation/watchchannels/bigbrother.py delete mode 100644 bot/cogs/moderation/watchchannels/talentpool.py delete mode 100644 bot/cogs/off_topic_names.py delete mode 100644 bot/cogs/utils/__init__.py delete mode 100644 bot/cogs/utils/bot.py delete mode 100644 bot/cogs/utils/clean.py delete mode 100644 bot/cogs/utils/eval.py delete mode 100644 bot/cogs/utils/extensions.py delete mode 100644 bot/cogs/utils/jams.py delete mode 100644 bot/cogs/utils/reminders.py delete mode 100644 bot/cogs/utils/snekbox.py delete mode 100644 bot/cogs/utils/utils.py create mode 100644 bot/exts/__init__.py create mode 100644 bot/exts/alias.py create mode 100644 bot/exts/backend/__init__.py create mode 100644 bot/exts/backend/config_verifier.py create mode 100644 bot/exts/backend/error_handler.py create mode 100644 bot/exts/backend/logging.py create mode 100644 bot/exts/backend/sync/__init__.py create mode 100644 bot/exts/backend/sync/_cog.py create mode 100644 bot/exts/backend/sync/_syncers.py create mode 100644 bot/exts/dm_relay.py create mode 100644 bot/exts/duck_pond.py create mode 100644 bot/exts/filters/__init__.py create mode 100644 bot/exts/filters/antimalware.py create mode 100644 bot/exts/filters/antispam.py create mode 100644 bot/exts/filters/filter_lists.py create mode 100644 bot/exts/filters/filtering.py create mode 100644 bot/exts/filters/security.py create mode 100644 bot/exts/filters/token_remover.py create mode 100644 bot/exts/filters/webhook_remover.py create mode 100644 bot/exts/help_channels.py create mode 100644 bot/exts/info/__init__.py create mode 100644 bot/exts/info/doc.py create mode 100644 bot/exts/info/help.py create mode 100644 bot/exts/info/information.py create mode 100644 bot/exts/info/python_news.py create mode 100644 bot/exts/info/reddit.py create mode 100644 bot/exts/info/site.py create mode 100644 bot/exts/info/source.py create mode 100644 bot/exts/info/stats.py create mode 100644 bot/exts/info/tags.py create mode 100644 bot/exts/info/wolfram.py create mode 100644 bot/exts/moderation/__init__.py create mode 100644 bot/exts/moderation/defcon.py create mode 100644 bot/exts/moderation/incidents.py create mode 100644 bot/exts/moderation/infraction/__init__.py create mode 100644 bot/exts/moderation/infraction/_scheduler.py create mode 100644 bot/exts/moderation/infraction/_utils.py create mode 100644 bot/exts/moderation/infraction/infractions.py create mode 100644 bot/exts/moderation/infraction/management.py create mode 100644 bot/exts/moderation/infraction/superstarify.py create mode 100644 bot/exts/moderation/modlog.py create mode 100644 bot/exts/moderation/silence.py create mode 100644 bot/exts/moderation/slowmode.py create mode 100644 bot/exts/moderation/verification.py create mode 100644 bot/exts/moderation/watchchannels/__init__.py create mode 100644 bot/exts/moderation/watchchannels/_watchchannel.py create mode 100644 bot/exts/moderation/watchchannels/bigbrother.py create mode 100644 bot/exts/moderation/watchchannels/talentpool.py create mode 100644 bot/exts/off_topic_names.py create mode 100644 bot/exts/utils/__init__.py create mode 100644 bot/exts/utils/bot.py create mode 100644 bot/exts/utils/clean.py create mode 100644 bot/exts/utils/eval.py create mode 100644 bot/exts/utils/extensions.py create mode 100644 bot/exts/utils/jams.py create mode 100644 bot/exts/utils/reminders.py create mode 100644 bot/exts/utils/snekbox.py create mode 100644 bot/exts/utils/utils.py delete mode 100644 tests/bot/cogs/__init__.py delete mode 100644 tests/bot/cogs/backend/__init__.py delete mode 100644 tests/bot/cogs/backend/sync/__init__.py delete mode 100644 tests/bot/cogs/backend/sync/test_base.py delete mode 100644 tests/bot/cogs/backend/sync/test_cog.py delete mode 100644 tests/bot/cogs/backend/sync/test_roles.py delete mode 100644 tests/bot/cogs/backend/sync/test_users.py delete mode 100644 tests/bot/cogs/backend/test_logging.py delete mode 100644 tests/bot/cogs/filters/__init__.py delete mode 100644 tests/bot/cogs/filters/test_antimalware.py delete mode 100644 tests/bot/cogs/filters/test_antispam.py delete mode 100644 tests/bot/cogs/filters/test_security.py delete mode 100644 tests/bot/cogs/filters/test_token_remover.py delete mode 100644 tests/bot/cogs/info/__init__.py delete mode 100644 tests/bot/cogs/info/test_information.py delete mode 100644 tests/bot/cogs/moderation/__init__.py delete mode 100644 tests/bot/cogs/moderation/infraction/__init__.py delete mode 100644 tests/bot/cogs/moderation/infraction/test_infractions.py delete mode 100644 tests/bot/cogs/moderation/test_incidents.py delete mode 100644 tests/bot/cogs/moderation/test_modlog.py delete mode 100644 tests/bot/cogs/moderation/test_silence.py delete mode 100644 tests/bot/cogs/moderation/test_slowmode.py delete mode 100644 tests/bot/cogs/test_cogs.py delete mode 100644 tests/bot/cogs/test_duck_pond.py delete mode 100644 tests/bot/cogs/utils/__init__.py delete mode 100644 tests/bot/cogs/utils/test_jams.py delete mode 100644 tests/bot/cogs/utils/test_snekbox.py create mode 100644 tests/bot/exts/__init__.py create mode 100644 tests/bot/exts/backend/__init__.py create mode 100644 tests/bot/exts/backend/sync/__init__.py create mode 100644 tests/bot/exts/backend/sync/test_base.py create mode 100644 tests/bot/exts/backend/sync/test_cog.py create mode 100644 tests/bot/exts/backend/sync/test_roles.py create mode 100644 tests/bot/exts/backend/sync/test_users.py create mode 100644 tests/bot/exts/backend/test_logging.py create mode 100644 tests/bot/exts/filters/__init__.py create mode 100644 tests/bot/exts/filters/test_antimalware.py create mode 100644 tests/bot/exts/filters/test_antispam.py create mode 100644 tests/bot/exts/filters/test_security.py create mode 100644 tests/bot/exts/filters/test_token_remover.py create mode 100644 tests/bot/exts/info/__init__.py create mode 100644 tests/bot/exts/info/test_information.py create mode 100644 tests/bot/exts/moderation/__init__.py create mode 100644 tests/bot/exts/moderation/infraction/__init__.py create mode 100644 tests/bot/exts/moderation/infraction/test_infractions.py create mode 100644 tests/bot/exts/moderation/test_incidents.py create mode 100644 tests/bot/exts/moderation/test_modlog.py create mode 100644 tests/bot/exts/moderation/test_silence.py create mode 100644 tests/bot/exts/moderation/test_slowmode.py create mode 100644 tests/bot/exts/test_cogs.py create mode 100644 tests/bot/exts/test_duck_pond.py create mode 100644 tests/bot/exts/utils/__init__.py create mode 100644 tests/bot/exts/utils/test_jams.py create mode 100644 tests/bot/exts/utils/test_snekbox.py (limited to 'tests') diff --git a/bot/__main__.py b/bot/__main__.py index 4b0f6dfe4..555847357 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -34,67 +34,67 @@ bot = Bot( ) # Backend -bot.load_extension("bot.cogs.backend.config_verifier") -bot.load_extension("bot.cogs.backend.error_handler") -bot.load_extension("bot.cogs.backend.logging") -bot.load_extension("bot.cogs.backend.sync") +bot.load_extension("bot.exts.backend.config_verifier") +bot.load_extension("bot.exts.backend.error_handler") +bot.load_extension("bot.exts.backend.logging") +bot.load_extension("bot.exts.backend.sync") # Filters -bot.load_extension("bot.cogs.filters.antimalware") -bot.load_extension("bot.cogs.filters.antispam") -bot.load_extension("bot.cogs.filters.filter_lists") -bot.load_extension("bot.cogs.filters.filtering") -bot.load_extension("bot.cogs.filters.security") -bot.load_extension("bot.cogs.filters.token_remover") -bot.load_extension("bot.cogs.filters.webhook_remover") +bot.load_extension("bot.exts.filters.antimalware") +bot.load_extension("bot.exts.filters.antispam") +bot.load_extension("bot.exts.filters.filter_lists") +bot.load_extension("bot.exts.filters.filtering") +bot.load_extension("bot.exts.filters.security") +bot.load_extension("bot.exts.filters.token_remover") +bot.load_extension("bot.exts.filters.webhook_remover") # Info -bot.load_extension("bot.cogs.info.doc") -bot.load_extension("bot.cogs.info.help") -bot.load_extension("bot.cogs.info.information") -bot.load_extension("bot.cogs.info.python_news") -bot.load_extension("bot.cogs.info.reddit") -bot.load_extension("bot.cogs.info.site") -bot.load_extension("bot.cogs.info.source") -bot.load_extension("bot.cogs.info.stats") -bot.load_extension("bot.cogs.info.tags") -bot.load_extension("bot.cogs.info.wolfram") +bot.load_extension("bot.exts.info.doc") +bot.load_extension("bot.exts.info.help") +bot.load_extension("bot.exts.info.information") +bot.load_extension("bot.exts.info.python_news") +bot.load_extension("bot.exts.info.reddit") +bot.load_extension("bot.exts.info.site") +bot.load_extension("bot.exts.info.source") +bot.load_extension("bot.exts.info.stats") +bot.load_extension("bot.exts.info.tags") +bot.load_extension("bot.exts.info.wolfram") # Moderation -bot.load_extension("bot.cogs.moderation.defcon") -bot.load_extension("bot.cogs.moderation.incidents") -bot.load_extension("bot.cogs.moderation.modlog") -bot.load_extension("bot.cogs.moderation.silence") -bot.load_extension("bot.cogs.moderation.slowmode") -bot.load_extension("bot.cogs.moderation.verification") +bot.load_extension("bot.exts.moderation.defcon") +bot.load_extension("bot.exts.moderation.incidents") +bot.load_extension("bot.exts.moderation.modlog") +bot.load_extension("bot.exts.moderation.silence") +bot.load_extension("bot.exts.moderation.slowmode") +bot.load_extension("bot.exts.moderation.verification") # Moderation - Infraction -bot.load_extension("bot.cogs.moderation.infraction.infractions") -bot.load_extension("bot.cogs.moderation.infraction.management") -bot.load_extension("bot.cogs.moderation.infraction.superstarify") +bot.load_extension("bot.exts.moderation.infraction.infractions") +bot.load_extension("bot.exts.moderation.infraction.management") +bot.load_extension("bot.exts.moderation.infraction.superstarify") # Moderation - Watchchannels -bot.load_extension("bot.cogs.moderation.watchchannels.bigbrother") -bot.load_extension("bot.cogs.moderation.watchchannels.talentpool") +bot.load_extension("bot.exts.moderation.watchchannels.bigbrother") +bot.load_extension("bot.exts.moderation.watchchannels.talentpool") # Utils -bot.load_extension("bot.cogs.utils.bot") -bot.load_extension("bot.cogs.utils.clean") -bot.load_extension("bot.cogs.utils.eval") -bot.load_extension("bot.cogs.utils.extensions") -bot.load_extension("bot.cogs.utils.jams") -bot.load_extension("bot.cogs.utils.reminders") -bot.load_extension("bot.cogs.utils.snekbox") -bot.load_extension("bot.cogs.utils.utils") +bot.load_extension("bot.exts.utils.bot") +bot.load_extension("bot.exts.utils.clean") +bot.load_extension("bot.exts.utils.eval") +bot.load_extension("bot.exts.utils.extensions") +bot.load_extension("bot.exts.utils.jams") +bot.load_extension("bot.exts.utils.reminders") +bot.load_extension("bot.exts.utils.snekbox") +bot.load_extension("bot.exts.utils.utils") # Misc -bot.load_extension("bot.cogs.alias") -bot.load_extension("bot.cogs.dm_relay") -bot.load_extension("bot.cogs.duck_pond") -bot.load_extension("bot.cogs.off_topic_names") +bot.load_extension("bot.exts.alias") +bot.load_extension("bot.exts.dm_relay") +bot.load_extension("bot.exts.duck_pond") +bot.load_extension("bot.exts.off_topic_names") if constants.HelpChannels.enable: - bot.load_extension("bot.cogs.help_channels") + bot.load_extension("bot.exts.help_channels") # Apply `message_edited_at` patch if discord.py did not yet release a bug fix. if not hasattr(discord.message.Message, '_handle_edited_timestamp'): diff --git a/bot/cogs/__init__.py b/bot/cogs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py deleted file mode 100644 index 3c5a35c24..000000000 --- a/bot/cogs/alias.py +++ /dev/null @@ -1,153 +0,0 @@ -import inspect -import logging - -from discord import Colour, Embed -from discord.ext.commands import ( - Cog, Command, Context, Greedy, - clean_content, command, group, -) - -from bot.bot import Bot -from bot.cogs.utils.extensions import Extension -from bot.converters import FetchedMember, TagNameConverter -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - - -class Alias (Cog): - """Aliases for commonly used commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: - """Invokes a command with args and kwargs.""" - log.debug(f"{cmd_name} was invoked through an alias") - cmd = self.bot.get_command(cmd_name) - if not cmd: - return log.info(f'Did not find command "{cmd_name}" to invoke.') - elif not await cmd.can_run(ctx): - return log.info( - f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' - ) - - await ctx.invoke(cmd, *args, **kwargs) - - @command(name='aliases') - async def aliases_command(self, ctx: Context) -> None: - """Show configured aliases on the bot.""" - embed = Embed( - title='Configured aliases', - colour=Colour.blue() - ) - await LinePaginator.paginate( - ( - f"• `{ctx.prefix}{value.name}` " - f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" - for name, value in inspect.getmembers(self) - if isinstance(value, Command) and name.endswith('_alias') - ), - ctx, embed, empty=False, max_lines=20 - ) - - @command(name="resources", aliases=("resource",), hidden=True) - async def site_resources_alias(self, ctx: Context) -> None: - """Alias for invoking site resources.""" - await self.invoke(ctx, "site resources") - - @command(name="tools", hidden=True) - async def site_tools_alias(self, ctx: Context) -> None: - """Alias for invoking site tools.""" - await self.invoke(ctx, "site tools") - - @command(name="watch", hidden=True) - async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother watch [user] [reason].""" - await self.invoke(ctx, "bigbrother watch", user, reason=reason) - - @command(name="unwatch", hidden=True) - async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother unwatch [user] [reason].""" - await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) - - @command(name="home", hidden=True) - async def site_home_alias(self, ctx: Context) -> None: - """Alias for invoking site home.""" - await self.invoke(ctx, "site home") - - @command(name="faq", hidden=True) - async def site_faq_alias(self, ctx: Context) -> None: - """Alias for invoking site faq.""" - await self.invoke(ctx, "site faq") - - @command(name="rules", aliases=("rule",), hidden=True) - async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: - """Alias for invoking site rules.""" - await self.invoke(ctx, "site rules", *rules) - - @command(name="reload", hidden=True) - async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: - """Alias for invoking extensions reload [extensions...].""" - await self.invoke(ctx, "extensions reload", *extensions) - - @command(name="defon", hidden=True) - async def defcon_enable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon enable.""" - await self.invoke(ctx, "defcon enable") - - @command(name="defoff", hidden=True) - async def defcon_disable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon disable.""" - await self.invoke(ctx, "defcon disable") - - @command(name="exception", hidden=True) - async def tags_get_traceback_alias(self, ctx: Context) -> None: - """Alias for invoking tags get traceback.""" - await self.invoke(ctx, "tags get", tag_name="traceback") - - @group(name="get", - aliases=("show", "g"), - hidden=True, - invoke_without_command=True) - async def get_group_alias(self, ctx: Context) -> None: - """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" - pass - - @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) - async def tags_get_alias( - self, ctx: Context, *, tag_name: TagNameConverter = None - ) -> None: - """ - Alias for invoking tags get [tag_name]. - - tag_name: str - tag to be viewed. - """ - await self.invoke(ctx, "tags get", tag_name=tag_name) - - @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) - async def docs_get_alias( - self, ctx: Context, symbol: clean_content = None - ) -> None: - """Alias for invoking docs get [symbol].""" - await self.invoke(ctx, "docs get", symbol) - - @command(name="nominate", hidden=True) - async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking talentpool add [user] [reason].""" - await self.invoke(ctx, "talentpool add", user, reason=reason) - - @command(name="unnominate", hidden=True) - async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking nomination end [user] [reason].""" - await self.invoke(ctx, "nomination end", user, reason=reason) - - @command(name="nominees", hidden=True) - async def nominees_alias(self, ctx: Context) -> None: - """Alias for invoking tp watched.""" - await self.invoke(ctx, "talentpool watched") - - -def setup(bot: Bot) -> None: - """Load the Alias cog.""" - bot.add_cog(Alias(bot)) diff --git a/bot/cogs/backend/__init__.py b/bot/cogs/backend/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/backend/config_verifier.py b/bot/cogs/backend/config_verifier.py deleted file mode 100644 index d72c6c22e..000000000 --- a/bot/cogs/backend/config_verifier.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging - -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot - - -log = logging.getLogger(__name__) - - -class ConfigVerifier(Cog): - """Verify config on startup.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) - - async def verify_channels(self) -> None: - """ - Verify channels. - - If any channels in config aren't present in server, log them in a warning. - """ - await self.bot.wait_until_guild_available() - server = self.bot.get_guild(constants.Guild.id) - - server_channel_ids = {channel.id for channel in server.channels} - invalid_channels = [ - channel_name for channel_name, channel_id in constants.Channels - if channel_id not in server_channel_ids - ] - - if invalid_channels: - log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") - - -def setup(bot: Bot) -> None: - """Load the ConfigVerifier cog.""" - bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/backend/error_handler.py b/bot/cogs/backend/error_handler.py deleted file mode 100644 index f9d4de638..000000000 --- a/bot/cogs/backend/error_handler.py +++ /dev/null @@ -1,287 +0,0 @@ -import contextlib -import logging -import typing as t - -from discord import Embed -from discord.ext.commands import Cog, Context, errors -from sentry_sdk import push_scope - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Colours -from bot.converters import TagNameConverter -from bot.utils.checks import InWhitelistCheckFailure - -log = logging.getLogger(__name__) - - -class ErrorHandler(Cog): - """Handles errors emitted from commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_error_embed(self, title: str, body: str) -> Embed: - """Return an embed that contains the exception.""" - return Embed( - title=title, - colour=Colours.soft_red, - description=body - ) - - @Cog.listener() - async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: - """ - Provide generic command error handling. - - Error handling is deferred to any local error handler, if present. This is done by - checking for the presence of a `handled` attribute on the error. - - Error handling emits a single error message in the invoking context `ctx` and a log message, - prioritised as follows: - - 1. If the name fails to match a command: - * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. - Otherwise if it matches a tag, the tag is invoked - * If CommandNotFound is raised when invoking the tag (determined by the presence of the - `invoked_from_error_handler` attribute), this error is treated as being unexpected - and therefore sends an error message - * Commands in the verification channel are ignored - 2. UserInputError: see `handle_user_input_error` - 3. CheckFailure: see `handle_check_failure` - 4. CommandOnCooldown: send an error message in the invoking context - 5. ResponseCodeError: see `handle_api_error` - 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` - """ - command = ctx.command - - if hasattr(e, "handled"): - log.trace(f"Command {command} had its error already handled locally; ignoring.") - return - - if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): - if await self.try_silence(ctx): - return - if ctx.channel.id != Channels.verification: - # Try to look for a tag with the command's name - await self.try_get_tag(ctx) - return # Exit early to avoid logging. - elif isinstance(e, errors.UserInputError): - await self.handle_user_input_error(ctx, e) - elif isinstance(e, errors.CheckFailure): - await self.handle_check_failure(ctx, e) - elif isinstance(e, errors.CommandOnCooldown): - await ctx.send(e) - elif isinstance(e, errors.CommandInvokeError): - if isinstance(e.original, ResponseCodeError): - await self.handle_api_error(ctx, e.original) - else: - await self.handle_unexpected_error(ctx, e.original) - return # Exit early to avoid logging. - elif not isinstance(e, errors.DisabledCommand): - # ConversionError, MaxConcurrencyReached, ExtensionError - await self.handle_unexpected_error(ctx, e) - return # Exit early to avoid logging. - - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - - @staticmethod - def get_help_command(ctx: Context) -> t.Coroutine: - """Return a prepared `help` command invocation coroutine.""" - if ctx.command: - return ctx.send_help(ctx.command) - - return ctx.send_help() - - async def try_silence(self, ctx: Context) -> bool: - """ - Attempt to invoke the silence or unsilence command if invoke with matches a pattern. - - Respecting the checks if: - * invoked with `shh+` silence channel for amount of h's*2 with max of 15. - * invoked with `unshh+` unsilence channel - Return bool depending on success of command. - """ - command = ctx.invoked_with.lower() - silence_command = self.bot.get_command("silence") - ctx.invoked_from_error_handler = True - try: - if not await silence_command.can_run(ctx): - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - except errors.CommandError: - log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") - return False - if command.startswith("shh"): - await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) - return True - elif command.startswith("unshh"): - await ctx.invoke(self.bot.get_command("unsilence")) - return True - return False - - async def try_get_tag(self, ctx: Context) -> None: - """ - Attempt to display a tag by interpreting the command name as a tag name. - - The invocation of tags get respects its checks. Any CommandErrors raised will be handled - by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to - the context to prevent infinite recursion in the case of a CommandNotFound exception. - """ - tags_get_command = self.bot.get_command("tags get") - ctx.invoked_from_error_handler = True - - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." - try: - if not await tags_get_command.can_run(ctx): - log.debug(log_msg) - return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - - try: - tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) - except errors.BadArgument: - log.debug( - f"{ctx.author} tried to use an invalid command " - f"and the fallback tag failed validation in TagNameConverter." - ) - else: - with contextlib.suppress(ResponseCodeError): - await ctx.invoke(tags_get_command, tag_name=tag_name) - # Return to not raise the exception - return - - async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: - """ - Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. - - * MissingRequiredArgument: send an error message with arg name and the help command - * TooManyArguments: send an error message and the help command - * BadArgument: send an error message and the help command - * BadUnionArgument: send an error message including the error produced by the last converter - * ArgumentParsingError: send an error message - * Other: send an error message and the help command - """ - prepared_help_command = self.get_help_command(ctx) - - if isinstance(e, errors.MissingRequiredArgument): - embed = self._get_error_embed("Missing required argument", e.param.name) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.missing_required_argument") - elif isinstance(e, errors.TooManyArguments): - embed = self._get_error_embed("Too many arguments", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.too_many_arguments") - elif isinstance(e, errors.BadArgument): - embed = self._get_error_embed("Bad argument", str(e)) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.bad_argument") - elif isinstance(e, errors.BadUnionArgument): - embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") - await ctx.send(embed=embed) - self.bot.stats.incr("errors.bad_union_argument") - elif isinstance(e, errors.ArgumentParsingError): - embed = self._get_error_embed("Argument parsing error", str(e)) - await ctx.send(embed=embed) - self.bot.stats.incr("errors.argument_parsing_error") - else: - embed = self._get_error_embed( - "Input error", - "Something about your input seems off. Check the arguments and try again." - ) - await ctx.send(embed=embed) - await prepared_help_command - self.bot.stats.incr("errors.other_user_input_error") - - @staticmethod - async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: - """ - Send an error message in `ctx` for certain types of CheckFailure. - - The following types are handled: - - * BotMissingPermissions - * BotMissingRole - * BotMissingAnyRole - * NoPrivateMessage - * InWhitelistCheckFailure - """ - bot_missing_errors = ( - errors.BotMissingPermissions, - errors.BotMissingRole, - errors.BotMissingAnyRole - ) - - if isinstance(e, bot_missing_errors): - ctx.bot.stats.incr("errors.bot_permission_error") - await ctx.send( - "Sorry, it looks like I don't have the permissions or roles I need to do that." - ) - elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): - ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") - await ctx.send(e) - - @staticmethod - async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: - """Send an error message in `ctx` for ResponseCodeError and log it.""" - if e.status == 404: - await ctx.send("There does not seem to be anything matching your query.") - log.debug(f"API responded with 404 for command {ctx.command}") - ctx.bot.stats.incr("errors.api_error_404") - elif e.status == 400: - content = await e.response.json() - log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) - await ctx.send("According to the API, your request is malformed.") - ctx.bot.stats.incr("errors.api_error_400") - elif 500 <= e.status < 600: - await ctx.send("Sorry, there seems to be an internal issue with the API.") - log.warning(f"API responded with {e.status} for command {ctx.command}") - ctx.bot.stats.incr("errors.api_internal_server_error") - else: - await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") - log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") - ctx.bot.stats.incr(f"errors.api_error_{e.status}") - - @staticmethod - async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: - """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" - await ctx.send( - f"Sorry, an unexpected error occurred. Please let us know!\n\n" - f"```{e.__class__.__name__}: {e}```" - ) - - ctx.bot.stats.incr("errors.unexpected") - - with push_scope() as scope: - scope.user = { - "id": ctx.author.id, - "username": str(ctx.author) - } - - scope.set_tag("command", ctx.command.qualified_name) - scope.set_tag("message_id", ctx.message.id) - scope.set_tag("channel_id", ctx.channel.id) - - scope.set_extra("full_message", ctx.message.content) - - if ctx.guild is not None: - scope.set_extra( - "jump_to", - f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" - ) - - log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) - - -def setup(bot: Bot) -> None: - """Load the ErrorHandler cog.""" - bot.add_cog(ErrorHandler(bot)) diff --git a/bot/cogs/backend/logging.py b/bot/cogs/backend/logging.py deleted file mode 100644 index 94fa2b139..000000000 --- a/bot/cogs/backend/logging.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging - -from discord import Embed -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, DEBUG_MODE - - -log = logging.getLogger(__name__) - - -class Logging(Cog): - """Debug logging module.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.bot.loop.create_task(self.startup_greeting()) - - async def startup_greeting(self) -> None: - """Announce our presence to the configured devlog channel.""" - await self.bot.wait_until_guild_available() - log.info("Bot connected!") - - embed = Embed(description="Connected!") - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=( - "https://raw.githubusercontent.com/" - "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" - ) - ) - - if not DEBUG_MODE: - await self.bot.get_channel(Channels.dev_log).send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the Logging cog.""" - bot.add_cog(Logging(bot)) diff --git a/bot/cogs/backend/sync/__init__.py b/bot/cogs/backend/sync/__init__.py deleted file mode 100644 index 2541beaa8..000000000 --- a/bot/cogs/backend/sync/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from bot.bot import Bot - - -def setup(bot: Bot) -> None: - """Load the Sync cog.""" - from ._cog import Sync - bot.add_cog(Sync(bot)) diff --git a/bot/cogs/backend/sync/_cog.py b/bot/cogs/backend/sync/_cog.py deleted file mode 100644 index b6068f328..000000000 --- a/bot/cogs/backend/sync/_cog.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -from typing import Any, Dict - -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from . import _syncers - -log = logging.getLogger(__name__) - - -class Sync(Cog): - """Captures relevant events and sends them to the site.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.role_syncer = _syncers.RoleSyncer(self.bot) - self.user_syncer = _syncers.UserSyncer(self.bot) - - self.bot.loop.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(constants.Guild.id) - if guild is None: - return - - for syncer in (self.role_syncer, self.user_syncer): - await syncer.sync(guild) - - async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: - """Send a PATCH request to partially update a user in the database.""" - try: - await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) - except ResponseCodeError as e: - if e.response.status != 404: - raise - if not ignore_404: - log.warning("Unable to update user, got 404. Assuming race condition from join event.") - - @Cog.listener() - async def on_guild_role_create(self, role: Role) -> None: - """Adds newly create role to the database table over the API.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.post( - 'bot/roles', - json={ - 'colour': role.colour.value, - 'id': role.id, - 'name': role.name, - 'permissions': role.permissions.value, - 'position': role.position, - } - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: Role) -> None: - """Deletes role from the database when it's deleted from the guild.""" - if role.guild.id != constants.Guild.id: - return - - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - @Cog.listener() - async def on_guild_role_update(self, before: Role, after: Role) -> None: - """Syncs role with the database if any of the stored attributes were updated.""" - if after.guild.id != constants.Guild.id: - return - - was_updated = ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ) - - if was_updated: - await self.bot.api_client.put( - f'bot/roles/{after.id}', - json={ - 'colour': after.colour.value, - 'id': after.id, - 'name': after.name, - 'permissions': after.permissions.value, - 'position': after.position, - } - ) - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """ - Adds a new user or updates existing user to the database when a member joins the guild. - - If the joining member is a user that is already known to the database (i.e., a user that - previously left), it will update the user's information. If the user is not yet known by - the database, the user is added. - """ - if member.guild.id != constants.Guild.id: - return - - packed = { - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': True, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - - got_error = False - - try: - # First try an update of the user to set the `in_guild` field and other - # fields that may have changed since the last time we've seen them. - await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) - - except ResponseCodeError as e: - # If we didn't get 404, something else broke - propagate it up. - if e.response.status != 404: - raise - - got_error = True # yikes - - if got_error: - # If we got `404`, the user is new. Create them. - await self.bot.api_client.post('bot/users', json=packed) - - @Cog.listener() - async def on_member_remove(self, member: Member) -> None: - """Set the in_guild field to False when a member leaves the guild.""" - if member.guild.id != constants.Guild.id: - return - - await self.patch_user(member.id, json={"in_guild": False}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Update the roles of the member in the database if a change is detected.""" - if after.guild.id != constants.Guild.id: - return - - if before.roles != after.roles: - updated_information = {"roles": sorted(role.id for role in after.roles)} - await self.patch_user(after.id, json=updated_information) - - @Cog.listener() - async def on_user_update(self, before: User, after: User) -> None: - """Update the user information in the database if a relevant change is detected.""" - attrs = ("name", "discriminator") - if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): - updated_information = { - "name": after.name, - "discriminator": int(after.discriminator), - } - # A 404 likely means the user is in another guild. - await self.patch_user(after.id, json=updated_information, ignore_404=True) - - @commands.group(name='sync') - @commands.has_permissions(administrator=True) - async def sync_group(self, ctx: Context) -> None: - """Run synchronizations between the bot and site manually.""" - - @sync_group.command(name='roles') - @commands.has_permissions(administrator=True) - async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) - - @sync_group.command(name='users') - @commands.has_permissions(administrator=True) - async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/backend/sync/_syncers.py b/bot/cogs/backend/sync/_syncers.py deleted file mode 100644 index f7ba811bc..000000000 --- a/bot/cogs/backend/sync/_syncers.py +++ /dev/null @@ -1,347 +0,0 @@ -import abc -import asyncio -import logging -import typing as t -from collections import namedtuple -from functools import partial - -import discord -from discord import Guild, HTTPException, Member, Message, Reaction, User -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# These objects are declared as namedtuples because tuples are hashable, -# something that we make use of when diffing site roles against guild roles. -_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) -_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - - -class Syncer(abc.ABC): - """Base class for synchronising the database with objects in the Discord cache.""" - - _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " - _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @property - @abc.abstractmethod - def name(self) -> str: - """The name of the syncer; used in output messages and logging.""" - raise NotImplementedError # pragma: no cover - - async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: - """ - Send a prompt to confirm or abort a sync using reactions and return the sent message. - - If a message is given, it is edited to display the prompt and reactions. Otherwise, a new - message is sent to the dev-core channel and mentions the core developers role. If the - channel cannot be retrieved, return None. - """ - log.trace(f"Sending {self.name} sync confirmation prompt.") - - msg_content = ( - f'Possible cache issue while syncing {self.name}s. ' - f'More than {constants.Sync.max_diff} {self.name}s were changed. ' - f'React to confirm or abort the sync.' - ) - - # Send to core developers if it's an automatic sync. - if not message: - log.trace("Message not provided for confirmation; creating a new one in dev-core.") - channel = self.bot.get_channel(constants.Channels.dev_core) - - if not channel: - log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") - try: - channel = await self.bot.fetch_channel(constants.Channels.dev_core) - except HTTPException: - log.exception( - f"Failed to fetch channel for sending sync confirmation prompt; " - f"aborting {self.name} sync." - ) - return None - - allowed_roles = [discord.Object(constants.Roles.core_developers)] - message = await channel.send( - f"{self._CORE_DEV_MENTION}{msg_content}", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - else: - await message.edit(content=msg_content) - - # Add the initial reactions. - log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") - for emoji in self._REACTION_EMOJIS: - await message.add_reaction(emoji) - - return message - - def _reaction_check( - self, - author: Member, - message: Message, - reaction: Reaction, - user: t.Union[Member, User] - ) -> bool: - """ - Return True if the `reaction` is a valid confirmation or abort reaction on `message`. - - If the `author` of the prompt is a bot, then a reaction by any core developer will be - considered valid. Otherwise, the author of the reaction (`user`) will have to be the - `author` of the prompt. - """ - # For automatic syncs, check for the core dev role instead of an exact author - has_role = any(constants.Roles.core_developers == role.id for role in user.roles) - return ( - reaction.message.id == message.id - and not user.bot - and (has_role if author.bot else user == author) - and str(reaction.emoji) in self._REACTION_EMOJIS - ) - - async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: - """ - Wait for a confirmation reaction by `author` on `message` and return True if confirmed. - - Uses the `_reaction_check` function to determine if a reaction is valid. - - If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. - To acknowledge the reaction (or lack thereof), `message` will be edited. - """ - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - reaction = None - try: - log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") - reaction, _ = await self.bot.wait_for( - 'reaction_add', - check=partial(self._reaction_check, author, message), - timeout=constants.Sync.confirm_timeout - ) - except asyncio.TimeoutError: - # reaction will remain none thus sync will be aborted in the finally block below. - log.debug(f"The {self.name} syncer confirmation prompt timed out.") - - if str(reaction) == constants.Emojis.check_mark: - log.trace(f"The {self.name} syncer was confirmed.") - await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') - return True - else: - log.info(f"The {self.name} syncer was aborted or timed out!") - await message.edit( - content=f':warning: {mention}{self.name} sync aborted or timed out!' - ) - return False - - @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference between the cache of `guild` and the database.""" - raise NotImplementedError # pragma: no cover - - @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: - """Perform the API calls for synchronisation.""" - raise NotImplementedError # pragma: no cover - - async def _get_confirmation_result( - self, - diff_size: int, - author: Member, - message: t.Optional[Message] = None - ) -> t.Tuple[bool, t.Optional[Message]]: - """ - Prompt for confirmation and return a tuple of the result and the prompt message. - - `diff_size` is the size of the diff of the sync. If it is greater than - `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the - sync and the `message` is an extant message to edit to display the prompt. - - If confirmed or no confirmation was needed, the result is True. The returned message will - either be the given `message` or a new one which was created when sending the prompt. - """ - log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") - if diff_size > constants.Sync.max_diff: - message = await self._send_prompt(message) - if not message: - return False, None # Couldn't get channel. - - confirmed = await self._wait_for_confirmation(author, message) - if not confirmed: - return False, message # Sync aborted. - - return True, message - - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: - """ - Synchronise the database with the cache of `guild`. - - If the differences between the cache and the database are greater than - `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core - channel. The confirmation can be optionally redirect to `ctx` instead. - """ - log.info(f"Starting {self.name} syncer.") - - message = None - author = self.bot.user - if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") - author = ctx.author - - diff = await self._get_diff(guild) - diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict - totals = {k: len(v) for k, v in diff_dict.items() if v is not None} - diff_size = sum(totals.values()) - - confirmed, message = await self._get_confirmation_result(diff_size, author, message) - if not confirmed: - return - - # Preserve the core-dev role mention in the message edits so users aren't confused about - # where notifications came from. - mention = self._CORE_DEV_MENTION if author.bot else "" - - try: - await self._sync(diff) - except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") - - # Don't show response text because it's probably some really long HTML. - results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" - else: - results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" - - if message: - await message.edit(content=content) - - -class RoleSyncer(Syncer): - """Synchronise the database with roles in the cache.""" - - name = "role" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of roles between the cache of `guild` and the database.""" - log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_roles = {_Role(**role_dict) for role_dict in roles} - guild_roles = { - _Role( - id=role.id, - name=role.name, - colour=role.colour.value, - permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in db_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # DB guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - db_roles - roles_to_create - roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} - - return _Diff(roles_to_create, roles_to_update, roles_to_delete) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the role cache of `guild`.""" - log.trace("Syncing created roles...") - for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) - - log.trace("Syncing updated roles...") - for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) - - log.trace("Syncing deleted roles...") - for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') - - -class UserSyncer(Syncer): - """Synchronise the database with users in the cache.""" - - name = "user" - - async def _get_diff(self, guild: Guild) -> _Diff: - """Return the difference of users between the cache of `guild` and the database.""" - log.trace("Getting the diff for users.") - users = await self.bot.api_client.get('bot/users') - - # Pack DB roles and guild roles into one common, hashable format. - # They're hashable so that they're easily comparable with sets later. - db_users = { - user_dict['id']: _User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in users - } - guild_users = { - member.id: _User( - id=member.id, - name=member.name, - discriminator=int(member.discriminator), - roles=tuple(sorted(role.id for role in member.roles)), - in_guild=True - ) - for member in guild.members - } - - users_to_create = set() - users_to_update = set() - - for db_user in db_users.values(): - guild_user = guild_users.get(db_user.id) - if guild_user is not None: - if db_user != guild_user: - users_to_update.add(guild_user) - - elif db_user.in_guild: - # The user is known in the DB but not the guild, and the - # DB currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = db_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(db_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return _Diff(users_to_create, users_to_update, None) - - async def _sync(self, diff: _Diff) -> None: - """Synchronise the database with the user cache of `guild`.""" - log.trace("Syncing created users...") - for user in diff.created: - await self.bot.api_client.post('bot/users', json=user._asdict()) - - log.trace("Syncing updated users...") - for user in diff.updated: - await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/dm_relay.py b/bot/cogs/dm_relay.py deleted file mode 100644 index 0d8f340b4..000000000 --- a/bot/cogs/dm_relay.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -from typing import Optional - -import discord -from discord import Color -from discord.ext import commands -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.converters import UserMentionOrID -from bot.utils import RedisCache -from bot.utils.checks import in_whitelist_check, with_role_check -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DMRelay(Cog): - """Relay direct messages to and from the bot.""" - - # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] - dm_cache = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.dm_log - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - @commands.command(aliases=("reply",)) - async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: - """ - Allows you to send a DM to a user from the bot. - - If `member` is not provided, it will send to the last user who DM'd the bot. - - This feature should be used extremely sparingly. Use ModMail if you need to have a serious - conversation with a user. This is just for responding to extraordinary DMs, having a little - fun with users, and telling people they are DMing the wrong bot. - - NOTE: This feature will be removed if it is overused. - """ - if not member: - user_id = await self.dm_cache.get("last_user") - member = ctx.guild.get_member(user_id) if user_id else None - - # If we still don't have a Member at this point, give up - if not member: - log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") - await ctx.message.add_reaction("❌") - return - - try: - await member.send(message) - except discord.errors.Forbidden: - log.debug("User has disabled DMs.") - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("✅") - self.bot.stats.incr("dm_relay.dm_sent") - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Relays the message's content and attachments to the dm_log channel.""" - # Only relay DMs from humans - if message.author.bot or message.guild or self.webhook is None: - return - - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - await self.dm_cache.set("last_user", message.author.id) - self.bot.stats.incr("dm_relay.dm_received") - - # Handle any attachments - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (discord.errors.Forbidden, discord.errors.NotFound): - e = discord.Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - def cog_check(self, ctx: commands.Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=[constants.Channels.dm_log], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - -def setup(bot: Bot) -> None: - """Load the DMRelay cog.""" - bot.add_cog(DMRelay(bot)) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py deleted file mode 100644 index 7021069fa..000000000 --- a/bot/cogs/duck_pond.py +++ /dev/null @@ -1,166 +0,0 @@ -import logging -from typing import Union - -import discord -from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DuckPond(Cog): - """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.duck_pond - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @staticmethod - def is_staff(member: Union[User, Member]) -> bool: - """Check if a specific member or user is staff.""" - if hasattr(member, "roles"): - for role in member.roles: - if role.id in constants.STAFF_ROLES: - return True - return False - - async def has_green_checkmark(self, message: Message) -> bool: - """Check if the message has a green checkmark reaction.""" - for reaction in message.reactions: - if reaction.emoji == "✅": - async for user in reaction.users(): - if user == self.bot.user: - return True - return False - - async def count_ducks(self, message: Message) -> int: - """ - Count the number of ducks in the reactions of a specific message. - - Only counts ducks added by staff members. - """ - duck_count = 0 - duck_reactors = [] - - for reaction in message.reactions: - async for user in reaction.users(): - - # Is the user a staff member and not already counted as reactor? - if not self.is_staff(user) or user.id in duck_reactors: - continue - - # Is the emoji a duck? - if hasattr(reaction.emoji, "id"): - if reaction.emoji.id in constants.DuckPond.custom_emojis: - duck_count += 1 - duck_reactors.append(user.id) - elif isinstance(reaction.emoji, str): - if reaction.emoji == "🦆": - duck_count += 1 - duck_reactors.append(user.id) - return duck_count - - async def relay_message(self, message: Message) -> None: - """Relays the message's content and attachments to the duck pond channel.""" - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - await message.add_reaction("✅") - - @staticmethod - def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: - """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" - if payload.emoji.is_custom_emoji(): - if payload.emoji.id in constants.DuckPond.custom_emojis: - return True - elif payload.emoji.name == "🦆": - return True - - return False - - @Cog.listener() - async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: - """ - Determine if a message should be sent to the duck pond. - - This will count the number of duck reactions on the message, and if this amount meets the - amount of ducks specified in the config under duck_pond/threshold, it will - send the message off to the duck pond. - """ - # Is the emoji in the reaction a duck? - if not self._payload_has_duckpond_emoji(payload): - return - - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - message = await channel.fetch_message(payload.message_id) - member = discord.utils.get(message.guild.members, id=payload.user_id) - - # Is the member a human and a staff member? - if not self.is_staff(member) or member.bot: - return - - # Does the message already have a green checkmark? - if await self.has_green_checkmark(message): - return - - # Time to count our ducks! - duck_count = await self.count_ducks(message) - - # If we've got more than the required amount of ducks, send the message to the duck_pond. - if duck_count >= constants.DuckPond.threshold: - await self.relay_message(message) - - @Cog.listener() - async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: - """Ensure that people don't remove the green checkmark from duck ponded messages.""" - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - - # Prevent the green checkmark from being removed - if payload.emoji.name == "✅": - message = await channel.fetch_message(payload.message_id) - duck_count = await self.count_ducks(message) - if duck_count >= constants.DuckPond.threshold: - await message.add_reaction("✅") - - -def setup(bot: Bot) -> None: - """Load the DuckPond cog.""" - bot.add_cog(DuckPond(bot)) diff --git a/bot/cogs/filters/__init__.py b/bot/cogs/filters/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/filters/antimalware.py b/bot/cogs/filters/antimalware.py deleted file mode 100644 index c76bd2c60..000000000 --- a/bot/cogs/filters/antimalware.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -import typing as t -from os.path import splitext - -from discord import Embed, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, STAFF_ROLES, URLs - -log = logging.getLogger(__name__) - -PY_EMBED_DESCRIPTION = ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" -) - -TXT_EMBED_DESCRIPTION = ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - "{cmd_channel_mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" -) - -DISALLOWED_EMBED_DESCRIPTION = ( - "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - "We currently allow the following file types: **{joined_whitelist}**.\n\n" - "Feel free to ask in {meta_channel_mention} if you think this is a mistake." -) - - -class AntiMalware(Cog): - """Delete messages which contain attachments with non-whitelisted file extensions.""" - - def __init__(self, bot: Bot): - self.bot = bot - - def _get_whitelisted_file_formats(self) -> list: - """Get the file formats currently on the whitelist.""" - return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() - - def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: - """Get an iterable containing all the disallowed extensions of attachments.""" - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) - return extensions_blocked - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Identify messages with prohibited attachments.""" - # Return when message don't have attachment and don't moderate DMs - if not message.attachments or not message.guild: - return - - # Check if user is staff, if is, return - # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance - if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): - return - - embed = Embed() - extensions_blocked = self._get_disallowed_extensions(message) - blocked_extensions_str = ', '.join(extensions_blocked) - if ".py" in extensions_blocked: - # Short-circuit on *.py files to provide a pastebin link - embed.description = PY_EMBED_DESCRIPTION - elif ".txt" in extensions_blocked: - # Work around Discord AutoConversion of messages longer than 2000 chars to .txt - cmd_channel = self.bot.get_channel(Channels.bot_commands) - embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) - elif extensions_blocked: - meta_channel = self.bot.get_channel(Channels.meta) - embed.description = DISALLOWED_EMBED_DESCRIPTION.format( - joined_whitelist=', '.join(self._get_whitelisted_file_formats()), - blocked_extensions_str=blocked_extensions_str, - meta_channel_mention=meta_channel.mention, - ) - - if embed.description: - log.info( - f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", - extra={"attachment_list": [attachment.filename for attachment in message.attachments]} - ) - - await message.channel.send(f"Hey {message.author.mention}!", embed=embed) - - # Delete the offending message: - try: - await message.delete() - except NotFound: - log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - - -def setup(bot: Bot) -> None: - """Load the AntiMalware cog.""" - bot.add_cog(AntiMalware(bot)) diff --git a/bot/cogs/filters/antispam.py b/bot/cogs/filters/antispam.py deleted file mode 100644 index d2dccea06..000000000 --- a/bot/cogs/filters/antispam.py +++ /dev/null @@ -1,288 +0,0 @@ -import asyncio -import logging -from collections.abc import Mapping -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from operator import itemgetter -from typing import Dict, Iterable, List, Set - -from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Cog - -from bot import rules -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import ( - AntiSpam as AntiSpamConfig, Channels, - Colours, DEBUG_MODE, Event, Filter, - Guild as GuildConfig, Icons, - STAFF_ROLES, -) -from bot.converters import Duration -from bot.utils.messages import send_attachments - - -log = logging.getLogger(__name__) - -RULE_FUNCTION_MAPPING = { - 'attachments': rules.apply_attachments, - 'burst': rules.apply_burst, - 'burst_shared': rules.apply_burst_shared, - 'chars': rules.apply_chars, - 'discord_emojis': rules.apply_discord_emojis, - 'duplicates': rules.apply_duplicates, - 'links': rules.apply_links, - 'mentions': rules.apply_mentions, - 'newlines': rules.apply_newlines, - 'role_mentions': rules.apply_role_mentions -} - - -@dataclass -class DeletionContext: - """Represents a Deletion Context for a single spam event.""" - - channel: TextChannel - members: Dict[int, Member] = field(default_factory=dict) - rules: Set[str] = field(default_factory=set) - messages: Dict[int, Message] = field(default_factory=dict) - attachments: List[List[str]] = field(default_factory=list) - - async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: - """Adds new rule violation events to the deletion context.""" - self.rules.add(rule_name) - - for member in members: - if member.id not in self.members: - self.members[member.id] = member - - for message in messages: - if message.id not in self.messages: - self.messages[message.id] = message - - # Re-upload attachments - destination = message.guild.get_channel(Channels.attachment_log) - urls = await send_attachments(message, destination, link_large=False) - self.attachments.append(urls) - - async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: - """Method that takes care of uploading the queue and posting modlog alert.""" - triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) - - mod_alert_message = ( - f"**Triggered by:** {triggered_by_users}\n" - f"**Channel:** {self.channel.mention}\n" - f"**Rules:** {', '.join(rule for rule in self.rules)}\n" - ) - - # For multiple messages or those with excessive newlines, use the logs API - if len(self.messages) > 1 or 'newlines' in self.rules: - url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) - mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" - else: - mod_alert_message += "Message:\n" - [message] = self.messages.values() - content = message.clean_content - remaining_chars = 2040 - len(mod_alert_message) - - if len(content) > remaining_chars: - content = content[:remaining_chars] + "..." - - mod_alert_message += f"{content}" - - *_, last_message = self.messages.values() - await modlog.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title="Spam detected!", - text=mod_alert_message, - thumbnail=last_message.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=AntiSpamConfig.ping_everyone - ) - - -class AntiSpam(Cog): - """Cog that controls our anti-spam measures.""" - - def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: - self.bot = bot - self.validation_errors = validation_errors - role_id = AntiSpamConfig.punishment['role_id'] - self.muted_role = Object(role_id) - self.expiration_date_converter = Duration() - - self.message_deletion_queue = dict() - - self.bot.loop.create_task(self.alert_on_validation_error()) - - @property - def mod_log(self) -> ModLog: - """Allows for easy access of the ModLog cog.""" - return self.bot.get_cog("ModLog") - - async def alert_on_validation_error(self) -> None: - """Unloads the cog and alerts admins if configuration validation failed.""" - await self.bot.wait_until_guild_available() - if self.validation_errors: - body = "**The following errors were encountered:**\n" - body += "\n".join(f"- {error}" for error in self.validation_errors.values()) - body += "\n\n**The cog has been unloaded.**" - - await self.mod_log.send_log_message( - title="Error: AntiSpam configuration validation failed!", - text=body, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Colour.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Applies the antispam rules to each received message.""" - if ( - not message.guild - or message.guild.id != GuildConfig.id - or message.author.bot - or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) - or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) - ): - return - - # Fetch the rule configuration with the highest rule interval. - max_interval_config = max( - AntiSpamConfig.rules.values(), - key=itemgetter('interval') - ) - max_interval = max_interval_config['interval'] - - # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. - earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] - - for rule_name in AntiSpamConfig.rules: - rule_config = AntiSpamConfig.rules[rule_name] - rule_function = RULE_FUNCTION_MAPPING[rule_name] - - # Create a list of messages that were sent in the interval that the rule cares about. - latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] - result = await rule_function(message, messages_for_rule, rule_config) - - # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` - # which contains the reason for why the message violated the rule and - # an iterable of all members that violated the rule. - if result is not None: - self.bot.stats.incr(f"mod_alerts.{rule_name}") - reason, members, relevant_messages = result - full_reason = f"`{rule_name}` rule: {reason}" - - # If there's no spam event going on for this channel, start a new Message Deletion Context - channel = message.channel - if channel.id not in self.message_deletion_queue: - log.trace(f"Creating queue for channel `{channel.id}`") - self.message_deletion_queue[message.channel.id] = DeletionContext(channel) - self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) - - # Add the relevant of this trigger to the Deletion Context - await self.message_deletion_queue[message.channel.id].add( - rule_name=rule_name, - members=members, - messages=relevant_messages - ) - - for member in members: - - # Fire it off as a background task to ensure - # that the sleep doesn't block further tasks - self.bot.loop.create_task( - self.punish(message, member, full_reason) - ) - - await self.maybe_delete_messages(channel, relevant_messages) - break - - async def punish(self, msg: Message, member: Member, reason: str) -> None: - """Punishes the given member for triggering an antispam rule.""" - if not any(role.id == self.muted_role.id for role in member.roles): - remove_role_after = AntiSpamConfig.punishment['remove_after'] - - # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes - context = await self.bot.get_context(msg) - context.author = self.bot.user - context.message.author = self.bot.user - - # Since we're going to invoke the tempmute command directly, we need to manually call the converter. - dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") - await context.invoke( - self.bot.get_command('tempmute'), - member, - dt_remove_role_after, - reason=reason - ) - - async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: - """Cleans the messages if cleaning is configured.""" - if AntiSpamConfig.clean_offending: - # If we have more than one message, we can use bulk delete. - if len(messages) > 1: - message_ids = [message.id for message in messages] - self.mod_log.ignore(Event.message_delete, *message_ids) - await channel.delete_messages(messages) - - # Otherwise, the bulk delete endpoint will throw up. - # Delete the message directly instead. - else: - self.mod_log.ignore(Event.message_delete, messages[0].id) - try: - await messages[0].delete() - except NotFound: - log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - - async def _process_deletion_context(self, context_id: int) -> None: - """Processes the Deletion Context queue.""" - log.trace("Sleeping before processing message deletion queue.") - await asyncio.sleep(10) - - if context_id not in self.message_deletion_queue: - log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") - return - - deletion_context = self.message_deletion_queue.pop(context_id) - await deletion_context.upload_messages(self.bot.user.id, self.mod_log) - - -def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: - """Validates the antispam configs.""" - validation_errors = {} - for name, config in rules_.items(): - if name not in RULE_FUNCTION_MAPPING: - log.error( - f"Unrecognized antispam rule `{name}`. " - f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" - ) - validation_errors[name] = f"`{name}` is not recognized as an antispam rule." - continue - for required_key in ('interval', 'max'): - if required_key not in config: - log.error( - f"`{required_key}` is required but was not " - f"set in rule `{name}`'s configuration." - ) - validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" - return validation_errors - - -def setup(bot: Bot) -> None: - """Validate the AntiSpam configs and load the AntiSpam cog.""" - validation_errors = validate_config() - bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/cogs/filters/filter_lists.py b/bot/cogs/filters/filter_lists.py deleted file mode 100644 index c15adc461..000000000 --- a/bot/cogs/filters/filter_lists.py +++ /dev/null @@ -1,273 +0,0 @@ -import logging -from typing import Optional - -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.converters import ValidDiscordServerInvite, ValidFilterListType -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -class FilterLists(Cog): - """Commands for blacklisting and whitelisting things.""" - - methods_with_filterlist_types = [ - "allow_add", - "allow_delete", - "allow_get", - "deny_add", - "deny_delete", - "deny_get", - ] - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.bot.loop.create_task(self._amend_docstrings()) - - async def _amend_docstrings(self) -> None: - """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" - await self.bot.wait_until_guild_available() - - # Add valid filterlist types to the docstrings - valid_types = await ValidFilterListType.get_valid_types(self.bot) - valid_types = [f"`{type_.lower()}`" for type_ in valid_types] - - for method_name in self.methods_with_filterlist_types: - command = getattr(self, method_name) - command.help = ( - f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." - ) - - async def _add_data( - self, - ctx: Context, - allowed: bool, - list_type: ValidFilterListType, - content: str, - comment: Optional[str] = None, - ) -> None: - """Add an item to a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we gotta validate it. - if list_type == "GUILD_INVITE": - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # Unless the user has specified another comment, let's - # use the server name as the comment so that the list - # of guild IDs will be more easily readable when we - # display it. - if not comment: - comment = guild_data.get("name") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Try to add the item to the database - log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") - payload = { - "allowed": allowed, - "type": list_type, - "content": content, - "comment": comment, - } - - try: - item = await self.bot.api_client.post( - "bot/filter-lists", - json=payload - ) - except ResponseCodeError as e: - if e.status == 400: - await ctx.message.add_reaction("❌") - log.debug( - f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " - "probably because the request violated the UniqueConstraint." - ) - raise BadArgument( - f"Unable to add the item to the {allow_type}. " - "The item probably already exists. Keep in mind that a " - "blacklist and a whitelist for the same item cannot co-exist, " - "and we do not permit any duplicates." - ) - raise - - # Insert the item into the cache - self.bot.insert_item_into_filter_list_cache(item) - await ctx.message.add_reaction("✅") - - async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we need to convert it. - if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): - guild_data = await self._validate_guild_invite(ctx, content) - content = guild_data.get("id") - - # If it's a file format, let's make sure it has a leading dot. - elif list_type == "FILE_FORMAT" and not content.startswith("."): - content = f".{content}" - - # Find the content and delete it. - log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) - - if item is not None: - try: - await self.bot.api_client.delete( - f"bot/filter-lists/{item['id']}" - ) - del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to delete an item with the id {item['id']}, but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("❌") - - async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: - """Paginate and display all items in a filterlist.""" - allow_type = "whitelist" if allowed else "blacklist" - result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] - - # Build a list of lines we want to show in the paginator - lines = [] - for content, metadata in result.items(): - line = f"• `{content}`" - - if comment := metadata.get("comment"): - line += f" - {comment}" - - lines.append(line) - lines = sorted(lines) - - # Build the embed - list_type_plural = list_type.lower().replace("_", " ").title() + "s" - embed = Embed( - title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", - colour=Colour.blue() - ) - log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - - if result: - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - await ctx.message.add_reaction("❌") - - async def _sync_data(self, ctx: Context) -> None: - """Syncs the filterlists with the API.""" - try: - log.trace("Attempting to sync FilterList cache with data from the API.") - await self.bot.cache_filter_list_data() - await ctx.message.add_reaction("✅") - except ResponseCodeError as e: - log.debug( - f"{ctx.author} tried to sync FilterList cache data but " - f"the API raised an unexpected error: {e}" - ) - await ctx.message.add_reaction("❌") - - @staticmethod - async def _validate_guild_invite(ctx: Context, invite: str) -> dict: - """ - Validates a guild invite, and returns the guild info as a dict. - - Will raise a BadArgument if the guild invite is invalid. - """ - log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, invite) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's return a dict of guild information. - log.trace(f"{invite} validated as server invite. Converting to ID.") - return guild_data - - @group(aliases=("allowlist", "allow", "al", "wl")) - async def whitelist(self, ctx: Context) -> None: - """Group for whitelisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @group(aliases=("denylist", "deny", "bl", "dl")) - async def blacklist(self, ctx: Context) -> None: - """Group for blacklisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @whitelist.command(name="add", aliases=("a", "set")) - async def allow_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified allowlist.""" - await self._add_data(ctx, True, list_type, content, comment) - - @blacklist.command(name="add", aliases=("a", "set")) - async def deny_add( - self, - ctx: Context, - list_type: ValidFilterListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified denylist.""" - await self._add_data(ctx, False, list_type, content, comment) - - @whitelist.command(name="remove", aliases=("delete", "rm",)) - async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified allowlist.""" - await self._delete_data(ctx, True, list_type, content) - - @blacklist.command(name="remove", aliases=("delete", "rm",)) - async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: - """Remove an item from the specified denylist.""" - await self._delete_data(ctx, False, list_type, content) - - @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified allowlist.""" - await self._list_all_data(ctx, True, list_type) - - @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: - """Get the contents of a specified denylist.""" - await self._list_all_data(ctx, False, list_type) - - @whitelist.command(name="sync", aliases=("s",)) - async def allow_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - @blacklist.command(name="sync", aliases=("s",)) - async def deny_sync(self, ctx: Context) -> None: - """Syncs both allowlists and denylists with the API.""" - await self._sync_data(ctx) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the FilterLists cog.""" - bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filters/filtering.py b/bot/cogs/filters/filtering.py deleted file mode 100644 index 556b466ef..000000000 --- a/bot/cogs/filters/filtering.py +++ /dev/null @@ -1,575 +0,0 @@ -import asyncio -import logging -import re -from datetime import datetime, timedelta -from typing import List, Mapping, Optional, Tuple, Union - -import dateutil -import discord.errors -from dateutil.relativedelta import relativedelta -from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel -from discord.ext.commands import Cog -from discord.utils import escape_markdown - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import ( - Channels, Colours, - Filter, Icons, URLs -) -from bot.utils.redis_cache import RedisCache -from bot.utils.regex import INVITE_RE -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - -# Regular expressions -SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) -URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) -ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") - -# Other constants. -DAYS_BETWEEN_ALERTS = 3 -OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) - - -class Filtering(Cog): - """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" - - # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent - name_alerts = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - self.name_lock = asyncio.Lock() - - staff_mistake_str = "If you believe this was a mistake, please let staff know!" - self.filters = { - "filter_zalgo": { - "enabled": Filter.filter_zalgo, - "function": self._has_zalgo, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_zalgo, - "notification_msg": ( - "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " - f"{staff_mistake_str}" - ), - "schedule_deletion": False - }, - "filter_invites": { - "enabled": Filter.filter_invites, - "function": self._has_invites, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_invites, - "notification_msg": ( - f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" - r"Our server rules can be found here: " - ), - "schedule_deletion": False - }, - "filter_domains": { - "enabled": Filter.filter_domains, - "function": self._has_urls, - "type": "filter", - "content_only": True, - "user_notification": Filter.notify_user_domains, - "notification_msg": ( - f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" - ), - "schedule_deletion": False - }, - "watch_regex": { - "enabled": Filter.watch_regex, - "function": self._has_watch_regex_match, - "type": "watchlist", - "content_only": True, - "schedule_deletion": True - }, - "watch_rich_embeds": { - "enabled": Filter.watch_rich_embeds, - "function": self._has_rich_embed, - "type": "watchlist", - "content_only": False, - "schedule_deletion": False - } - } - - self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: - """Fetch items from the filter_list_cache.""" - return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() - - @staticmethod - def _expand_spoilers(text: str) -> str: - """Return a string containing all interpretations of a spoilered message.""" - split_text = SPOILER_RE.split(text) - return ''.join( - split_text[0::2] + split_text[1::2] + split_text - ) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Invoke message filter for new messages.""" - await self._filter_message(msg) - - # Ignore webhook messages. - if msg.webhook_id is None: - await self.check_bad_words_in_name(msg.author) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Invoke message filter for message edits. - - If there have been multiple edits, calculate the time delta from the previous edit. - """ - if not before.edited_at: - delta = relativedelta(after.edited_at, before.created_at).microseconds - else: - delta = relativedelta(after.edited_at, before.edited_at).microseconds - await self._filter_message(after, delta) - - def get_name_matches(self, name: str) -> List[re.Match]: - """Check bad words from passed string (name). Return list of matches.""" - matches = [] - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - if match := re.search(pattern, name, flags=re.IGNORECASE): - matches.append(match) - return matches - - async def check_send_alert(self, member: Member) -> bool: - """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" - if last_alert := await self.name_alerts.get(member.id): - last_alert = datetime.utcfromtimestamp(last_alert) - if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: - log.trace(f"Last alert was too recent for {member}'s nickname.") - return False - - return True - - async def check_bad_words_in_name(self, member: Member) -> None: - """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" - # Use lock to avoid race conditions - async with self.name_lock: - # Check whether the users display name contains any words in our blacklist - matches = self.get_name_matches(member.display_name) - - if not matches or not await self.check_send_alert(member): - return - - log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") - - log_string = ( - f"**User:** {member.mention} (`{member.id}`)\n" - f"**Display Name:** {member.display_name}\n" - f"**Bad Matches:** {', '.join(match.group() for match in matches)}" - ) - - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colours.soft_red, - title="Username filtering alert", - text=log_string, - channel_id=Channels.mod_alerts, - thumbnail=member.avatar_url - ) - - # Update time when alert sent - await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) - - async def filter_eval(self, result: str, msg: Message) -> bool: - """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. - - Also requires the original message, to check whether to filter and for mod logs. - Returns whether a filter was triggered or not. - """ - filter_triggered = False - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - # We also do not need to worry about filters that take the full message, - # since all we have is an arbitrary string. - if _filter["enabled"] and _filter["content_only"]: - match = await _filter["function"](result) - - if match: - # If this is a filter (not a watchlist), we set the variable so we know - # that it has been triggered - if _filter["type"] == "filter": - filter_triggered = True - - # We do not have to check against DM channels since !eval cannot be used there. - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, result - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} using !eval with " - f"[the following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - return filter_triggered - - async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: - """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" - # Should we filter this message? - if self._check_filter(msg): - for filter_name, _filter in self.filters.items(): - # Is this specific filter enabled in the config? - if _filter["enabled"]: - # Double trigger check for the embeds filter - if filter_name == "watch_rich_embeds": - # If the edit delta is less than 0.001 seconds, then we're probably dealing - # with a double filter trigger. - if delta is not None and delta < 100: - continue - - # Does the filter only need the message content or the full message? - if _filter["content_only"]: - match = await _filter["function"](msg.content) - else: - match = await _filter["function"](msg) - - if match: - is_private = msg.channel.type is discord.ChannelType.private - - # If this is a filter (not a watchlist) and not in a DM, delete the message. - if _filter["type"] == "filter" and not is_private: - try: - # Embeds (can?) trigger both the `on_message` and `on_message_edit` - # event handlers, triggering filtering twice for the same message. - # - # If `on_message`-triggered filtering already deleted the message - # then `on_message_edit`-triggered filtering will raise exception - # since the message no longer exists. - # - # In addition, to avoid sending two notifications to the user, the - # logs, and mod_alert, we return if the message no longer exists. - await msg.delete() - except discord.errors.NotFound: - return - - # Notify the user if the filter specifies - if _filter["user_notification"]: - await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) - - # If the message is classed as offensive, we store it in the site db and - # it will be deleted it after one week. - if _filter["schedule_deletion"] and not is_private: - delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() - data = { - 'id': msg.id, - 'channel_id': msg.channel.id, - 'delete_date': delete_date - } - - await self.bot.api_client.post('bot/offensive-messages', json=data) - self.schedule_msg_delete(data) - log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") - - if is_private: - channel_str = "via DM" - else: - channel_str = f"in {msg.channel.mention}" - - message_content, additional_embeds, additional_embeds_msg = self._add_stats( - filter_name, match, msg.content - ) - - message = ( - f"The {filter_name} {_filter['type']} was triggered " - f"by **{msg.author}** " - f"(`{msg.author.id}`) {channel_str} with [the " - f"following message]({msg.jump_url}):\n\n" - f"{message_content}" - ) - - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.filtering, - colour=Colour(Colours.soft_red), - title=f"{_filter['type'].title()} triggered!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone if not is_private else False, - additional_embeds=additional_embeds, - additional_embeds_msg=additional_embeds_msg - ) - - break # We don't want multiple filters to trigger - - def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ - str, Optional[List[discord.Embed]], Optional[str] - ]: - """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" - # Word and match stats for watch_regex - if name == "watch_regex": - surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] - message_content = ( - f"**Match:** '{match[0]}'\n" - f"**Location:** '...{escape_markdown(surroundings)}...'\n" - f"\n**Original Message:**\n{escape_markdown(content)}" - ) - else: # Use original content - message_content = content - - additional_embeds = None - additional_embeds_msg = None - - self.bot.stats.incr(f"filters.{name}") - - # The function returns True for invalid invites. - # They have no data so additional embeds can't be created for them. - if name == "filter_invites" and match is not True: - additional_embeds = [] - for _, data in match.items(): - embed = discord.Embed(description=( - f"**Members:**\n{data['members']}\n" - f"**Active:**\n{data['active']}" - )) - embed.set_author(name=data["name"]) - embed.set_thumbnail(url=data["icon"]) - embed.set_footer(text=f"Guild ID: {data['id']}") - additional_embeds.append(embed) - additional_embeds_msg = "For the following guild(s):" - - elif name == "watch_rich_embeds": - additional_embeds = match - additional_embeds_msg = "With the following embed(s):" - - return message_content, additional_embeds, additional_embeds_msg - - @staticmethod - def _check_filter(msg: Message) -> bool: - """Check whitelists to see if we should filter this message.""" - role_whitelisted = False - - if type(msg.author) is Member: # Only Member has roles, not User. - for role in msg.author.roles: - if role.id in Filter.role_whitelist: - role_whitelisted = True - - return ( - msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist - and not role_whitelisted # Role not in whitelist - and not msg.author.bot # Author not a bot - ) - - async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: - """ - Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. - - `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is - matched as-is. Spoilers are expanded, if any, and URLs are ignored. - """ - if SPOILER_RE.search(text): - text = self._expand_spoilers(text) - - # Make sure it's not a URL - if URL_RE.search(text): - return False - - watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) - for pattern in watchlist_patterns: - match = re.search(pattern, text, flags=re.IGNORECASE) - if match: - return match - - async def _has_urls(self, text: str) -> bool: - """Returns True if the text contains one of the blacklisted URLs from the config file.""" - if not URL_RE.search(text): - return False - - text = text.lower() - domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) - - for url in domain_blacklist: - if url.lower() in text: - return True - - return False - - @staticmethod - async def _has_zalgo(text: str) -> bool: - """ - Returns True if the text contains zalgo characters. - - Zalgo range is \u0300 – \u036F and \u0489. - """ - return bool(ZALGO_RE.search(text)) - - async def _has_invites(self, text: str) -> Union[dict, bool]: - """ - Checks if there's any invites in the text content that aren't in the guild whitelist. - - If any are detected, a dictionary of invite data is returned, with a key per invite. - If none are detected, False is returned. - - Attempts to catch some of common ways to try to cheat the system. - """ - # Remove backslashes to prevent escape character aroundfuckery like - # discord\.gg/gdudes-pony-farm - text = text.replace("\\", "") - - invites = INVITE_RE.findall(text) - invite_data = dict() - for invite in invites: - if invite in invite_data: - continue - - response = await self.bot.http_session.get( - f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} - ) - response = await response.json() - guild = response.get("guild") - if guild is None: - # Lack of a "guild" key in the JSON response indicates either an group DM invite, an - # expired invite, or an invalid invite. The API does not currently differentiate - # between invalid and expired invites - return True - - guild_id = guild.get("id") - guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) - guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) - - # Is this invite allowed? - guild_partnered_or_verified = ( - 'PARTNERED' in guild.get("features", []) - or 'VERIFIED' in guild.get("features", []) - ) - invite_not_allowed = ( - guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. - or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. - and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. - ) - - if invite_not_allowed: - guild_icon_hash = guild["icon"] - guild_icon = ( - "https://cdn.discordapp.com/icons/" - f"{guild_id}/{guild_icon_hash}.png?size=512" - ) - - invite_data[invite] = { - "name": guild["name"], - "id": guild['id'], - "icon": guild_icon, - "members": response["approximate_member_count"], - "active": response["approximate_presence_count"] - } - - return invite_data if invite_data else False - - @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: - """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" - if msg.embeds: - for embed in msg.embeds: - if embed.type == "rich": - urls = URL_RE.findall(msg.content) - if not embed.url or embed.url not in urls: - # If `embed.url` does not exist or if `embed.url` is not part of the content - # of the message, it's unlikely to be an auto-generated embed by Discord. - return msg.embeds - else: - log.trace( - "Found a rich embed sent by a regular user account, " - "but it was likely just an automatic URL embed." - ) - return False - return False - - async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: - """ - Notify filtered_member about a moderation action with the reason str. - - First attempts to DM the user, fall back to in-channel notification if user has DMs disabled - """ - try: - await filtered_member.send(reason) - except discord.errors.Forbidden: - await channel.send(f"{filtered_member.mention} {reason}") - - def schedule_msg_delete(self, msg: dict) -> None: - """Delete an offensive message once its deletion date is reached.""" - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) - - async def reschedule_offensive_msg_deletion(self) -> None: - """Get all the pending message deletion from the API and reschedule them.""" - await self.bot.wait_until_ready() - response = await self.bot.api_client.get('bot/offensive-messages',) - - now = datetime.utcnow() - - for msg in response: - delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - - if delete_at < now: - await self.delete_offensive_msg(msg) - else: - self.schedule_msg_delete(msg) - - async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: - """Delete an offensive message, and then delete it from the db.""" - try: - channel = self.bot.get_channel(msg['channel_id']) - if channel: - msg_obj = await channel.fetch_message(msg['id']) - await msg_obj.delete() - except NotFound: - log.info( - f"Tried to delete message {msg['id']}, but the message can't be found " - f"(it has been probably already deleted)." - ) - except HTTPException as e: - log.warning(f"Failed to delete message {msg['id']}: status {e.status}") - - await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') - log.info(f"Deleted the offensive message with id {msg['id']}.") - - -def setup(bot: Bot) -> None: - """Load the Filtering cog.""" - bot.add_cog(Filtering(bot)) diff --git a/bot/cogs/filters/security.py b/bot/cogs/filters/security.py deleted file mode 100644 index c680c5e27..000000000 --- a/bot/cogs/filters/security.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging - -from discord.ext.commands import Cog, Context, NoPrivateMessage - -from bot.bot import Bot - -log = logging.getLogger(__name__) - - -class Security(Cog): - """Security-related helpers.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all - self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM - - def check_not_bot(self, ctx: Context) -> bool: - """Check if the context is a bot user.""" - return not ctx.author.bot - - def check_on_guild(self, ctx: Context) -> bool: - """Check if the context is in a guild.""" - if ctx.guild is None: - raise NoPrivateMessage("This command cannot be used in private messages.") - return True - - -def setup(bot: Bot) -> None: - """Load the Security cog.""" - bot.add_cog(Security(bot)) diff --git a/bot/cogs/filters/token_remover.py b/bot/cogs/filters/token_remover.py deleted file mode 100644 index 8eace07b6..000000000 --- a/bot/cogs/filters/token_remover.py +++ /dev/null @@ -1,182 +0,0 @@ -import base64 -import binascii -import logging -import re -import typing as t - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot import utils -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Event, Icons - -log = logging.getLogger(__name__) - -LOG_MESSAGE = ( - "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " - "token was `{user_id}.{timestamp}.{hmac}`" -) -DELETION_MESSAGE_TEMPLATE = ( - "Hey {mention}! I noticed you posted a seemingly valid Discord API " - "token in your message and have removed your message. " - "This means that your token has been **compromised**. " - "Please change your token **immediately** at: " - "\n\n" - "Feel free to re-post it with the token removed. " - "If you believe this was a mistake, please let us know!" -) -DISCORD_EPOCH = 1_420_070_400 -TOKEN_EPOCH = 1_293_840_000 - -# Three parts delimited by dots: user ID, creation timestamp, HMAC. -# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. -# Each part only matches base64 URL-safe characters. -# Padding has never been observed, but the padding character '=' is matched just in case. -TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) - - -class Token(t.NamedTuple): - """A Discord Bot token.""" - - user_id: str - timestamp: str - hmac: str - - -class TokenRemover(Cog): - """Scans messages for potential discord.py bot tokens and removes them.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Check each message for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - found_token = self.find_token_in_message(msg) - if found_token: - await self.take_action(msg, found_token) - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """ - Check each edit for a string that matches Discord's token pattern. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - await self.on_message(after) - - async def take_action(self, msg: Message, found_token: Token) -> None: - """Remove the `msg` containing the `found_token` and send a mod log message.""" - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") - return - - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - log_message = self.format_log_message(msg, found_token) - log.debug(log_message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=log_message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) - - self.bot.stats.incr("tokens.removed_tokens") - - @staticmethod - def format_log_message(msg: Message, token: Token) -> str: - """Return the log message to send for `token` being censored in `msg`.""" - return LOG_MESSAGE.format( - author=msg.author, - author_id=msg.author.id, - channel=msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac='x' * len(token.hmac), - ) - - @classmethod - def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: - """Return a seemingly valid token found in `msg` or `None` if no token is found.""" - # Use finditer rather than search to guard against method calls prematurely returning the - # token check (e.g. `message.channel.send` also matches our token pattern) - for match in TOKEN_RE.finditer(msg.content): - token = Token(*match.groups()) - if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): - # Short-circuit on first match - return token - - # No matching substring - return - - @staticmethod - def is_valid_user_id(b64_content: str) -> bool: - """ - Check potential token to see if it contains a valid Discord user ID. - - See: https://discordapp.com/developers/docs/reference#snowflakes - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - string = decoded_bytes.decode('utf-8') - - # isdigit on its own would match a lot of other Unicode characters, hence the isascii. - return string.isascii() and string.isdigit() - except (binascii.Error, ValueError): - return False - - @staticmethod - def is_valid_timestamp(b64_content: str) -> bool: - """ - Return True if `b64_content` decodes to a valid timestamp. - - If the timestamp is greater than the Discord epoch, it's probably valid. - See: https://i.imgur.com/7WdehGn.png - """ - b64_content = utils.pad_base64(b64_content) - - try: - decoded_bytes = base64.urlsafe_b64decode(b64_content) - timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: - log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") - return False - - # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound - # is not checked. - if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: - return True - else: - log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") - return False - - -def setup(bot: Bot) -> None: - """Load the TokenRemover cog.""" - bot.add_cog(TokenRemover(bot)) diff --git a/bot/cogs/filters/webhook_remover.py b/bot/cogs/filters/webhook_remover.py deleted file mode 100644 index 5812da87c..000000000 --- a/bot/cogs/filters/webhook_remover.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import re - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Event, Icons - -WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) - -ALERT_MESSAGE_TEMPLATE = ( - "{user}, looks like you posted a Discord webhook URL. Therefore, your " - "message has been removed. Your webhook may have been **compromised** so " - "please re-create the webhook **immediately**. If you believe this was " - "mistake, please let us know." -) - -log = logging.getLogger(__name__) - - -class WebhookRemover(Cog): - """Scan messages to detect Discord webhooks links.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get current instance of `ModLog`.""" - return self.bot.get_cog("ModLog") - - async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: - """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" - # Don't log this, due internal delete, not by user. Will make different entry. - self.mod_log.ignore(Event.message_delete, msg.id) - - try: - await msg.delete() - except NotFound: - log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") - return - - await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) - - message = ( - f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " - f"to #{msg.channel}. Webhook URL was `{redacted_url}`" - ) - log.debug(message) - - # Send entry to moderation alerts. - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Discord webhook URL removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts - ) - - self.bot.stats.incr("tokens.removed_webhooks") - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Check if a Discord webhook URL is in `message`.""" - # Ignore DMs; can't delete messages in there anyway. - if not msg.guild or msg.author.bot: - return - - matches = WEBHOOK_URL_RE.search(msg.content) - if matches: - await self.delete_and_respond(msg, matches[1] + "xxx") - - @Cog.listener() - async def on_message_edit(self, before: Message, after: Message) -> None: - """Check if a Discord webhook URL is in the edited message `after`.""" - await self.on_message(after) - - -def setup(bot: Bot) -> None: - """Load `WebhookRemover` cog.""" - bot.add_cog(WebhookRemover(bot)) diff --git a/bot/cogs/help_channels.py b/bot/cogs/help_channels.py deleted file mode 100644 index 57094751e..000000000 --- a/bot/cogs/help_channels.py +++ /dev/null @@ -1,944 +0,0 @@ -import asyncio -import json -import logging -import random -import typing as t -from collections import deque -from datetime import datetime, timedelta, timezone -from pathlib import Path - -import discord -import discord.abc -from discord.ext import commands - -from bot import constants -from bot.bot import Bot -from bot.utils import RedisCache -from bot.utils.checks import with_role_check -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - -ASKING_GUIDE_URL = "https://pythondiscord.com/pages/asking-good-questions/" -MAX_CHANNELS_PER_CATEGORY = 50 -EXCLUDED_CHANNELS = (constants.Channels.how_to_get_help, constants.Channels.cooldown) - -HELP_CHANNEL_TOPIC = """ -This is a Python help channel. You can claim your own help channel in the Python Help: Available category. -""" - -AVAILABLE_MSG = f""" -This help channel is now **available**, which means that you can claim it by simply typing your \ -question into it. Once claimed, the channel will move into the **Python Help: Occupied** category, \ -and will be yours until it has been inactive for {constants.HelpChannels.idle_minutes} minutes or \ -is closed manually with `!close`. When that happens, it will be set to **dormant** and moved into \ -the **Help: Dormant** category. - -Try to write the best question you can by providing a detailed description and telling us what \ -you've tried already. For more information on asking a good question, \ -check out our guide on [asking good questions]({ASKING_GUIDE_URL}). -""" - -DORMANT_MSG = f""" -This help channel has been marked as **dormant**, and has been moved into the **Help: Dormant** \ -category at the bottom of the channel list. It is no longer possible to send messages in this \ -channel until it becomes available again. - -If your question wasn't answered yet, you can claim a new help channel from the \ -**Help: Available** category by simply asking your question again. Consider rephrasing the \ -question to maximize your chance of getting a good answer. If you're not sure how, have a look \ -through our guide for [asking a good question]({ASKING_GUIDE_URL}). -""" - -CoroutineFunc = t.Callable[..., t.Coroutine] - - -class HelpChannels(commands.Cog): - """ - Manage the help channel system of the guild. - - The system is based on a 3-category system: - - Available Category - - * Contains channels which are ready to be occupied by someone who needs help - * Will always contain `constants.HelpChannels.max_available` channels; refilled automatically - from the pool of dormant channels - * Prioritise using the channels which have been dormant for the longest amount of time - * If there are no more dormant channels, the bot will automatically create a new one - * If there are no dormant channels to move, helpers will be notified (see `notify()`) - * When a channel becomes available, the dormant embed will be edited to show `AVAILABLE_MSG` - * User can only claim a channel at an interval `constants.HelpChannels.claim_minutes` - * To keep track of cooldowns, user which claimed a channel will have a temporary role - - In Use Category - - * Contains all channels which are occupied by someone needing help - * Channel moves to dormant category after `constants.HelpChannels.idle_minutes` of being idle - * Command can prematurely mark a channel as dormant - * Channel claimant is allowed to use the command - * Allowed roles for the command are configurable with `constants.HelpChannels.cmd_whitelist` - * When a channel becomes dormant, an embed with `DORMANT_MSG` will be sent - - Dormant Category - - * Contains channels which aren't in use - * Channels are used to refill the Available category - - Help channels are named after the chemical elements in `bot/resources/elements.json`. - """ - - # This cache tracks which channels are claimed by which members. - # RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] - help_channel_claimants = RedisCache() - - # This cache maps a help channel to whether it has had any - # activity other than the original claimant. True being no other - # activity and False being other activity. - # RedisCache[discord.TextChannel.id, bool] - unanswered = RedisCache() - - # This dictionary maps a help channel to the time it was claimed - # RedisCache[discord.TextChannel.id, UtcPosixTimestamp] - claim_times = RedisCache() - - # This cache maps a help channel to original question message in same channel. - # RedisCache[discord.TextChannel.id, discord.Message.id] - question_messages = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - # Categories - self.available_category: discord.CategoryChannel = None - self.in_use_category: discord.CategoryChannel = None - self.dormant_category: discord.CategoryChannel = None - - # Queues - self.channel_queue: asyncio.Queue[discord.TextChannel] = None - self.name_queue: t.Deque[str] = None - - self.name_positions = self.get_names() - self.last_notification: t.Optional[datetime] = None - - # Asyncio stuff - self.queue_tasks: t.List[asyncio.Task] = [] - self.ready = asyncio.Event() - self.on_message_lock = asyncio.Lock() - self.init_task = self.bot.loop.create_task(self.init_cog()) - - def cog_unload(self) -> None: - """Cancel the init task and scheduled tasks when the cog unloads.""" - log.trace("Cog unload: cancelling the init_cog task") - self.init_task.cancel() - - log.trace("Cog unload: cancelling the channel queue tasks") - for task in self.queue_tasks: - task.cancel() - - self.scheduler.cancel_all() - - def create_channel_queue(self) -> asyncio.Queue: - """ - Return a queue of dormant channels to use for getting the next available channel. - - The channels are added to the queue in a random order. - """ - log.trace("Creating the channel queue.") - - channels = list(self.get_category_channels(self.dormant_category)) - random.shuffle(channels) - - log.trace("Populating the channel queue with channels.") - queue = asyncio.Queue() - for channel in channels: - queue.put_nowait(channel) - - return queue - - async def create_dormant(self) -> t.Optional[discord.TextChannel]: - """ - Create and return a new channel in the Dormant category. - - The new channel will sync its permission overwrites with the category. - - Return None if no more channel names are available. - """ - log.trace("Getting a name for a new dormant channel.") - - try: - name = self.name_queue.popleft() - except IndexError: - log.debug("No more names available for new dormant channels.") - return None - - log.debug(f"Creating a new dormant channel named {name}.") - return await self.dormant_category.create_text_channel(name, topic=HELP_CHANNEL_TOPIC) - - def create_name_queue(self) -> deque: - """Return a queue of element names to use for creating new channels.""" - log.trace("Creating the chemical element name queue.") - - used_names = self.get_used_names() - - log.trace("Determining the available names.") - available_names = (name for name in self.name_positions if name not in used_names) - - log.trace("Populating the name queue with names.") - return deque(available_names) - - async def dormant_check(self, ctx: commands.Context) -> bool: - """Return True if the user is the help channel claimant or passes the role check.""" - if await self.help_channel_claimants.get(ctx.channel.id) == ctx.author.id: - log.trace(f"{ctx.author} is the help channel claimant, passing the check for dormant.") - self.bot.stats.incr("help.dormant_invoke.claimant") - return True - - log.trace(f"{ctx.author} is not the help channel claimant, checking roles.") - role_check = with_role_check(ctx, *constants.HelpChannels.cmd_whitelist) - - if role_check: - self.bot.stats.incr("help.dormant_invoke.staff") - - return role_check - - @commands.command(name="close", aliases=["dormant", "solved"], enabled=False) - async def close_command(self, ctx: commands.Context) -> None: - """ - Make the current in-use help channel dormant. - - Make the channel dormant if the user passes the `dormant_check`, - delete the message that invoked this, - and reset the send permissions cooldown for the user who started the session. - """ - log.trace("close command invoked; checking if the channel is in-use.") - if ctx.channel.category == self.in_use_category: - if await self.dormant_check(ctx): - await self.remove_cooldown_role(ctx.author) - - # Ignore missing task when cooldown has passed but the channel still isn't dormant. - if ctx.author.id in self.scheduler: - self.scheduler.cancel(ctx.author.id) - - await self.move_to_dormant(ctx.channel, "command") - self.scheduler.cancel(ctx.channel.id) - else: - log.debug(f"{ctx.author} invoked command 'dormant' outside an in-use help channel") - - async def get_available_candidate(self) -> discord.TextChannel: - """ - Return a dormant channel to turn into an available channel. - - If no channel is available, wait indefinitely until one becomes available. - """ - log.trace("Getting an available channel candidate.") - - try: - channel = self.channel_queue.get_nowait() - except asyncio.QueueEmpty: - log.info("No candidate channels in the queue; creating a new channel.") - channel = await self.create_dormant() - - if not channel: - log.info("Couldn't create a candidate channel; waiting to get one from the queue.") - await self.notify() - channel = await self.wait_for_dormant_channel() - - return channel - - @staticmethod - def get_clean_channel_name(channel: discord.TextChannel) -> str: - """Return a clean channel name without status emojis prefix.""" - prefix = constants.HelpChannels.name_prefix - try: - # Try to remove the status prefix using the index of the channel prefix - name = channel.name[channel.name.index(prefix):] - log.trace(f"The clean name for `{channel}` is `{name}`") - except ValueError: - # If, for some reason, the channel name does not contain "help-" fall back gracefully - log.info(f"Can't get clean name because `{channel}` isn't prefixed by `{prefix}`.") - name = channel.name - - return name - - @staticmethod - def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: - """Check if a channel should be excluded from the help channel system.""" - return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS - - def get_category_channels(self, category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: - """Yield the text channels of the `category` in an unsorted manner.""" - log.trace(f"Getting text channels in the category '{category}' ({category.id}).") - - # This is faster than using category.channels because the latter sorts them. - for channel in self.bot.get_guild(constants.Guild.id).channels: - if channel.category_id == category.id and not self.is_excluded_channel(channel): - yield channel - - async def get_in_use_time(self, channel_id: int) -> t.Optional[timedelta]: - """Return the duration `channel_id` has been in use. Return None if it's not in use.""" - log.trace(f"Calculating in use time for channel {channel_id}.") - - claimed_timestamp = await self.claim_times.get(channel_id) - if claimed_timestamp: - claimed = datetime.utcfromtimestamp(claimed_timestamp) - return datetime.utcnow() - claimed - - @staticmethod - def get_names() -> t.List[str]: - """ - Return a truncated list of prefixed element names. - - The amount of names is configured with `HelpChannels.max_total_channels`. - The prefix is configured with `HelpChannels.name_prefix`. - """ - count = constants.HelpChannels.max_total_channels - prefix = constants.HelpChannels.name_prefix - - log.trace(f"Getting the first {count} element names from JSON.") - - with Path("bot/resources/elements.json").open(encoding="utf-8") as elements_file: - all_names = json.load(elements_file) - - if prefix: - return [prefix + name for name in all_names[:count]] - else: - return all_names[:count] - - def get_used_names(self) -> t.Set[str]: - """Return channel names which are already being used.""" - log.trace("Getting channel names which are already being used.") - - names = set() - for cat in (self.available_category, self.in_use_category, self.dormant_category): - for channel in self.get_category_channels(cat): - names.add(self.get_clean_channel_name(channel)) - - if len(names) > MAX_CHANNELS_PER_CATEGORY: - log.warning( - f"Too many help channels ({len(names)}) already exist! " - f"Discord only supports {MAX_CHANNELS_PER_CATEGORY} in a category." - ) - - log.trace(f"Got {len(names)} used names: {names}") - return names - - @classmethod - async def get_idle_time(cls, channel: discord.TextChannel) -> t.Optional[int]: - """ - Return the time elapsed, in seconds, since the last message sent in the `channel`. - - Return None if the channel has no messages. - """ - log.trace(f"Getting the idle time for #{channel} ({channel.id}).") - - msg = await cls.get_last_message(channel) - if not msg: - log.debug(f"No idle time available; #{channel} ({channel.id}) has no messages.") - return None - - idle_time = (datetime.utcnow() - msg.created_at).seconds - - log.trace(f"#{channel} ({channel.id}) has been idle for {idle_time} seconds.") - return idle_time - - @staticmethod - async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: - """Return the last message sent in the channel or None if no messages exist.""" - log.trace(f"Getting the last message in #{channel} ({channel.id}).") - - try: - return await channel.history(limit=1).next() # noqa: B305 - except discord.NoMoreItems: - log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") - return None - - async def init_available(self) -> None: - """Initialise the Available category with channels.""" - log.trace("Initialising the Available category with channels.") - - channels = list(self.get_category_channels(self.available_category)) - missing = constants.HelpChannels.max_available - len(channels) - - # If we've got less than `max_available` channel available, we should add some. - if missing > 0: - log.trace(f"Moving {missing} missing channels to the Available category.") - for _ in range(missing): - await self.move_to_available() - - # If for some reason we have more than `max_available` channels available, - # we should move the superfluous ones over to dormant. - elif missing < 0: - log.trace(f"Moving {abs(missing)} superfluous available channels over to the Dormant category.") - for channel in channels[:abs(missing)]: - await self.move_to_dormant(channel, "auto") - - async def init_categories(self) -> None: - """Get the help category objects. Remove the cog if retrieval fails.""" - log.trace("Getting the CategoryChannel objects for the help categories.") - - try: - self.available_category = await self.try_get_channel( - constants.Categories.help_available - ) - self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use) - self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant) - except discord.HTTPException: - log.exception("Failed to get a category; cog will be removed") - self.bot.remove_cog(self.qualified_name) - - async def init_cog(self) -> None: - """Initialise the help channel system.""" - log.trace("Waiting for the guild to be available before initialisation.") - await self.bot.wait_until_guild_available() - - log.trace("Initialising the cog.") - await self.init_categories() - await self.check_cooldowns() - - self.channel_queue = self.create_channel_queue() - self.name_queue = self.create_name_queue() - - log.trace("Moving or rescheduling in-use channels.") - for channel in self.get_category_channels(self.in_use_category): - await self.move_idle_channel(channel, has_task=False) - - # Prevent the command from being used until ready. - # The ready event wasn't used because channels could change categories between the time - # the command is invoked and the cog is ready (e.g. if move_idle_channel wasn't called yet). - # This may confuse users. So would potentially long delays for the cog to become ready. - self.close_command.enabled = True - - await self.init_available() - - log.info("Cog is ready!") - self.ready.set() - - self.report_stats() - - def report_stats(self) -> None: - """Report the channel count stats.""" - total_in_use = sum(1 for _ in self.get_category_channels(self.in_use_category)) - total_available = sum(1 for _ in self.get_category_channels(self.available_category)) - total_dormant = sum(1 for _ in self.get_category_channels(self.dormant_category)) - - self.bot.stats.gauge("help.total.in_use", total_in_use) - self.bot.stats.gauge("help.total.available", total_available) - self.bot.stats.gauge("help.total.dormant", total_dormant) - - @staticmethod - def is_claimant(member: discord.Member) -> bool: - """Return True if `member` has the 'Help Cooldown' role.""" - return any(constants.Roles.help_cooldown == role.id for role in member.roles) - - def match_bot_embed(self, message: t.Optional[discord.Message], description: str) -> bool: - """Return `True` if the bot's `message`'s embed description matches `description`.""" - if not message or not message.embeds: - return False - - bot_msg_desc = message.embeds[0].description - if bot_msg_desc is discord.Embed.Empty: - log.trace("Last message was a bot embed but it was empty.") - return False - return message.author == self.bot.user and bot_msg_desc.strip() == description.strip() - - @staticmethod - def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: - """Return True if `channel` is within a category with `category_id`.""" - actual_category = getattr(channel, "category", None) - return actual_category is not None and actual_category.id == category_id - - async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: - """ - Make the `channel` dormant if idle or schedule the move if still active. - - If `has_task` is True and rescheduling is required, the extant task to make the channel - dormant will first be cancelled. - """ - log.trace(f"Handling in-use channel #{channel} ({channel.id}).") - - if not await self.is_empty(channel): - idle_seconds = constants.HelpChannels.idle_minutes * 60 - else: - idle_seconds = constants.HelpChannels.deleted_idle_minutes * 60 - - time_elapsed = await self.get_idle_time(channel) - - if time_elapsed is None or time_elapsed >= idle_seconds: - log.info( - f"#{channel} ({channel.id}) is idle longer than {idle_seconds} seconds " - f"and will be made dormant." - ) - - await self.move_to_dormant(channel, "auto") - else: - # Cancel the existing task, if any. - if has_task: - self.scheduler.cancel(channel.id) - - delay = idle_seconds - time_elapsed - log.info( - f"#{channel} ({channel.id}) is still active; " - f"scheduling it to be moved after {delay} seconds." - ) - - self.scheduler.schedule_later(delay, channel.id, self.move_idle_channel(channel)) - - async def move_to_bottom_position(self, channel: discord.TextChannel, category_id: int, **options) -> None: - """ - Move the `channel` to the bottom position of `category` and edit channel attributes. - - To ensure "stable sorting", we use the `bulk_channel_update` endpoint and provide the current - positions of the other channels in the category as-is. This should make sure that the channel - really ends up at the bottom of the category. - - If `options` are provided, the channel will be edited after the move is completed. This is the - same order of operations that `discord.TextChannel.edit` uses. For information on available - options, see the documention on `discord.TextChannel.edit`. While possible, position-related - options should be avoided, as it may interfere with the category move we perform. - """ - # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. - category = await self.try_get_channel(category_id) - - payload = [{"id": c.id, "position": c.position} for c in category.channels] - - # Calculate the bottom position based on the current highest position in the category. If the - # category is currently empty, we simply use the current position of the channel to avoid making - # unnecessary changes to positions in the guild. - bottom_position = payload[-1]["position"] + 1 if payload else channel.position - - payload.append( - { - "id": channel.id, - "position": bottom_position, - "parent_id": category.id, - "lock_permissions": True, - } - ) - - # We use d.py's method to ensure our request is processed by d.py's rate limit manager - await self.bot.http.bulk_channel_update(category.guild.id, payload) - - # Now that the channel is moved, we can edit the other attributes - if options: - await channel.edit(**options) - - async def move_to_available(self) -> None: - """Make a channel available.""" - log.trace("Making a channel available.") - - channel = await self.get_available_candidate() - log.info(f"Making #{channel} ({channel.id}) available.") - - await self.send_available_message(channel) - - log.trace(f"Moving #{channel} ({channel.id}) to the Available category.") - - await self.move_to_bottom_position( - channel=channel, - category_id=constants.Categories.help_available, - ) - - self.report_stats() - - async def move_to_dormant(self, channel: discord.TextChannel, caller: str) -> None: - """ - Make the `channel` dormant. - - A caller argument is provided for metrics. - """ - log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") - - await self.help_channel_claimants.delete(channel.id) - await self.move_to_bottom_position( - channel=channel, - category_id=constants.Categories.help_dormant, - ) - - self.bot.stats.incr(f"help.dormant_calls.{caller}") - - in_use_time = await self.get_in_use_time(channel.id) - if in_use_time: - self.bot.stats.timing("help.in_use_time", in_use_time) - - unanswered = await self.unanswered.get(channel.id) - if unanswered: - self.bot.stats.incr("help.sessions.unanswered") - elif unanswered is not None: - self.bot.stats.incr("help.sessions.answered") - - log.trace(f"Position of #{channel} ({channel.id}) is actually {channel.position}.") - log.trace(f"Sending dormant message for #{channel} ({channel.id}).") - embed = discord.Embed(description=DORMANT_MSG) - await channel.send(embed=embed) - - await self.unpin(channel) - - log.trace(f"Pushing #{channel} ({channel.id}) into the channel queue.") - self.channel_queue.put_nowait(channel) - self.report_stats() - - async def move_to_in_use(self, channel: discord.TextChannel) -> None: - """Make a channel in-use and schedule it to be made dormant.""" - log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") - - await self.move_to_bottom_position( - channel=channel, - category_id=constants.Categories.help_in_use, - ) - - timeout = constants.HelpChannels.idle_minutes * 60 - - log.trace(f"Scheduling #{channel} ({channel.id}) to become dormant in {timeout} sec.") - self.scheduler.schedule_later(timeout, channel.id, self.move_idle_channel(channel)) - self.report_stats() - - async def notify(self) -> None: - """ - Send a message notifying about a lack of available help channels. - - Configuration: - - * `HelpChannels.notify` - toggle notifications - * `HelpChannels.notify_channel` - destination channel for notifications - * `HelpChannels.notify_minutes` - minimum interval between notifications - * `HelpChannels.notify_roles` - roles mentioned in notifications - """ - if not constants.HelpChannels.notify: - return - - log.trace("Notifying about lack of channels.") - - if self.last_notification: - elapsed = (datetime.utcnow() - self.last_notification).seconds - minimum_interval = constants.HelpChannels.notify_minutes * 60 - should_send = elapsed >= minimum_interval - else: - should_send = True - - if not should_send: - log.trace("Notification not sent because it's too recent since the previous one.") - return - - try: - log.trace("Sending notification message.") - - channel = self.bot.get_channel(constants.HelpChannels.notify_channel) - mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_roles) - allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_roles] - - message = await channel.send( - f"{mentions} A new available help channel is needed but there " - f"are no more dormant ones. Consider freeing up some in-use channels manually by " - f"using the `{constants.Bot.prefix}dormant` command within the channels.", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) - ) - - self.bot.stats.incr("help.out_of_channel_alerts") - - self.last_notification = message.created_at - except Exception: - # Handle it here cause this feature isn't critical for the functionality of the system. - log.exception("Failed to send notification about lack of dormant channels!") - - async def check_for_answer(self, message: discord.Message) -> None: - """Checks for whether new content in a help channel comes from non-claimants.""" - channel = message.channel - - # Confirm the channel is an in use help channel - if self.is_in_category(channel, constants.Categories.help_in_use): - log.trace(f"Checking if #{channel} ({channel.id}) has been answered.") - - # Check if there is an entry in unanswered - if await self.unanswered.contains(channel.id): - claimant_id = await self.help_channel_claimants.get(channel.id) - if not claimant_id: - # The mapping for this channel doesn't exist, we can't do anything. - return - - # Check the message did not come from the claimant - if claimant_id != message.author.id: - # Mark the channel as answered - await self.unanswered.set(channel.id, False) - - @commands.Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Move an available channel to the In Use category and replace it with a dormant one.""" - if message.author.bot: - return # Ignore messages sent by bots. - - channel = message.channel - - await self.check_for_answer(message) - - if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel): - return # Ignore messages outside the Available category or in excluded channels. - - log.trace("Waiting for the cog to be ready before processing messages.") - await self.ready.wait() - - log.trace("Acquiring lock to prevent a channel from being processed twice...") - async with self.on_message_lock: - log.trace(f"on_message lock acquired for {message.id}.") - - if not self.is_in_category(channel, constants.Categories.help_available): - log.debug( - f"Message {message.id} will not make #{channel} ({channel.id}) in-use " - f"because another message in the channel already triggered that." - ) - return - - log.info(f"Channel #{channel} was claimed by `{message.author.id}`.") - await self.move_to_in_use(channel) - await self.revoke_send_permissions(message.author) - - await self.pin(message) - - # Add user with channel for dormant check. - await self.help_channel_claimants.set(channel.id, message.author.id) - - self.bot.stats.incr("help.claimed") - - # Must use a timezone-aware datetime to ensure a correct POSIX timestamp. - timestamp = datetime.now(timezone.utc).timestamp() - await self.claim_times.set(channel.id, timestamp) - - await self.unanswered.set(channel.id, True) - - log.trace(f"Releasing on_message lock for {message.id}.") - - # Move a dormant channel to the Available category to fill in the gap. - # This is done last and outside the lock because it may wait indefinitely for a channel to - # be put in the queue. - await self.move_to_available() - - @commands.Cog.listener() - async def on_message_delete(self, msg: discord.Message) -> None: - """ - Reschedule an in-use channel to become dormant sooner if the channel is empty. - - The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`. - """ - if not self.is_in_category(msg.channel, constants.Categories.help_in_use): - return - - if not await self.is_empty(msg.channel): - return - - log.info(f"Claimant of #{msg.channel} ({msg.author}) deleted message, channel is empty now. Rescheduling task.") - - # Cancel existing dormant task before scheduling new. - self.scheduler.cancel(msg.channel.id) - - delay = constants.HelpChannels.deleted_idle_minutes * 60 - self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) - - async def is_empty(self, channel: discord.TextChannel) -> bool: - """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" - log.trace(f"Checking if #{channel} ({channel.id}) is empty.") - - # A limit of 100 results in a single API call. - # If AVAILABLE_MSG isn't found within 100 messages, then assume the channel is not empty. - # Not gonna do an extensive search for it cause it's too expensive. - async for msg in channel.history(limit=100): - if not msg.author.bot: - log.trace(f"#{channel} ({channel.id}) has a non-bot message.") - return False - - if self.match_bot_embed(msg, AVAILABLE_MSG): - log.trace(f"#{channel} ({channel.id}) has the available message embed.") - return True - - return False - - async def check_cooldowns(self) -> None: - """Remove expired cooldowns and re-schedule active ones.""" - log.trace("Checking all cooldowns to remove or re-schedule them.") - guild = self.bot.get_guild(constants.Guild.id) - cooldown = constants.HelpChannels.claim_minutes * 60 - - for channel_id, member_id in await self.help_channel_claimants.items(): - member = guild.get_member(member_id) - if not member: - continue # Member probably left the guild. - - in_use_time = await self.get_in_use_time(channel_id) - - if not in_use_time or in_use_time.seconds > cooldown: - # Remove the role if no claim time could be retrieved or if the cooldown expired. - # Since the channel is in the claimants cache, it is definitely strange for a time - # to not exist. However, it isn't a reason to keep the user stuck with a cooldown. - await self.remove_cooldown_role(member) - else: - # The member is still on a cooldown; re-schedule it for the remaining time. - delay = cooldown - in_use_time.seconds - self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) - - async def add_cooldown_role(self, member: discord.Member) -> None: - """Add the help cooldown role to `member`.""" - log.trace(f"Adding cooldown role for {member} ({member.id}).") - await self._change_cooldown_role(member, member.add_roles) - - async def remove_cooldown_role(self, member: discord.Member) -> None: - """Remove the help cooldown role from `member`.""" - log.trace(f"Removing cooldown role for {member} ({member.id}).") - await self._change_cooldown_role(member, member.remove_roles) - - async def _change_cooldown_role(self, member: discord.Member, coro_func: CoroutineFunc) -> None: - """ - Change `member`'s cooldown role via awaiting `coro_func` and handle errors. - - `coro_func` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. - """ - guild = self.bot.get_guild(constants.Guild.id) - role = guild.get_role(constants.Roles.help_cooldown) - if role is None: - log.warning(f"Help cooldown role ({constants.Roles.help_cooldown}) could not be found!") - return - - try: - await coro_func(role) - except discord.NotFound: - log.debug(f"Failed to change role for {member} ({member.id}): member not found") - except discord.Forbidden: - log.debug( - f"Forbidden to change role for {member} ({member.id}); " - f"possibly due to role hierarchy" - ) - except discord.HTTPException as e: - log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") - - async def revoke_send_permissions(self, member: discord.Member) -> None: - """ - Disallow `member` to send messages in the Available category for a certain time. - - The time until permissions are reinstated can be configured with - `HelpChannels.claim_minutes`. - """ - log.trace( - f"Revoking {member}'s ({member.id}) send message permissions in the Available category." - ) - - await self.add_cooldown_role(member) - - # Cancel the existing task, if any. - # Would mean the user somehow bypassed the lack of permissions (e.g. user is guild owner). - if member.id in self.scheduler: - self.scheduler.cancel(member.id) - - delay = constants.HelpChannels.claim_minutes * 60 - self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) - - async def send_available_message(self, channel: discord.TextChannel) -> None: - """Send the available message by editing a dormant message or sending a new message.""" - channel_info = f"#{channel} ({channel.id})" - log.trace(f"Sending available message in {channel_info}.") - - embed = discord.Embed(description=AVAILABLE_MSG) - - msg = await self.get_last_message(channel) - if self.match_bot_embed(msg, DORMANT_MSG): - log.trace(f"Found dormant message {msg.id} in {channel_info}; editing it.") - await msg.edit(embed=embed) - else: - log.trace(f"Dormant message not found in {channel_info}; sending a new message.") - await channel.send(embed=embed) - - async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel: - """Attempt to get or fetch a channel and return it.""" - log.trace(f"Getting the channel {channel_id}.") - - channel = self.bot.get_channel(channel_id) - if not channel: - log.debug(f"Channel {channel_id} is not in cache; fetching from API.") - channel = await self.bot.fetch_channel(channel_id) - - log.trace(f"Channel #{channel} ({channel_id}) retrieved.") - return channel - - async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: - """ - Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. - - Return True if successful and False otherwise. - """ - channel_str = f"#{channel} ({channel.id})" - if pin: - func = self.bot.http.pin_message - verb = "pin" - else: - func = self.bot.http.unpin_message - verb = "unpin" - - try: - await func(channel.id, msg_id) - except discord.HTTPException as e: - if e.code == 10008: - log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") - else: - log.exception( - f"Error {verb}ning message {msg_id} in {channel_str}: {e.status} ({e.code})" - ) - return False - else: - log.trace(f"{verb.capitalize()}ned message {msg_id} in {channel_str}.") - return True - - async def pin(self, message: discord.Message) -> None: - """Pin an initial question `message` and store it in a cache.""" - if await self.pin_wrapper(message.id, message.channel, pin=True): - await self.question_messages.set(message.channel.id, message.id) - - async def unpin(self, channel: discord.TextChannel) -> None: - """Unpin the initial question message sent in `channel`.""" - msg_id = await self.question_messages.pop(channel.id) - if msg_id is None: - log.debug(f"#{channel} ({channel.id}) doesn't have a message pinned.") - else: - await self.pin_wrapper(msg_id, channel, pin=False) - - async def wait_for_dormant_channel(self) -> discord.TextChannel: - """Wait for a dormant channel to become available in the queue and return it.""" - log.trace("Waiting for a dormant channel.") - - task = asyncio.create_task(self.channel_queue.get()) - self.queue_tasks.append(task) - channel = await task - - log.trace(f"Channel #{channel} ({channel.id}) finally retrieved from the queue.") - self.queue_tasks.remove(task) - - return channel - - -def validate_config() -> None: - """Raise a ValueError if the cog's config is invalid.""" - log.trace("Validating config.") - total = constants.HelpChannels.max_total_channels - available = constants.HelpChannels.max_available - - if total == 0 or available == 0: - raise ValueError("max_total_channels and max_available and must be greater than 0.") - - if total < available: - raise ValueError( - f"max_total_channels ({total}) must be greater than or equal to max_available " - f"({available})." - ) - - if total > MAX_CHANNELS_PER_CATEGORY: - raise ValueError( - f"max_total_channels ({total}) must be less than or equal to " - f"{MAX_CHANNELS_PER_CATEGORY} due to Discord's limit on channels per category." - ) - - -def setup(bot: Bot) -> None: - """Load the HelpChannels cog.""" - try: - validate_config() - except ValueError as e: - log.error(f"HelpChannels cog will not be loaded due to misconfiguration: {e}") - else: - bot.add_cog(HelpChannels(bot)) diff --git a/bot/cogs/info/__init__.py b/bot/cogs/info/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/info/doc.py b/bot/cogs/info/doc.py deleted file mode 100644 index 204cffb37..000000000 --- a/bot/cogs/info/doc.py +++ /dev/null @@ -1,511 +0,0 @@ -import asyncio -import functools -import logging -import re -import textwrap -from collections import OrderedDict -from contextlib import suppress -from types import SimpleNamespace -from typing import Any, Callable, Optional, Tuple - -import discord -from bs4 import BeautifulSoup -from bs4.element import PageElement, Tag -from discord.errors import NotFound -from discord.ext import commands -from markdownify import MarkdownConverter -from requests import ConnectTimeout, ConnectionError, HTTPError -from sphinx.ext import intersphinx -from urllib3.exceptions import ProtocolError - -from bot.bot import Bot -from bot.constants import MODERATION_ROLES, RedirectOutput -from bot.converters import ValidPythonIdentifier, ValidURL -from bot.decorators import with_role -from bot.pagination import LinePaginator - - -log = logging.getLogger(__name__) -logging.getLogger('urllib3').setLevel(logging.WARNING) - -# Since Intersphinx is intended to be used with Sphinx, -# we need to mock its configuration. -SPHINX_MOCK_APP = SimpleNamespace( - config=SimpleNamespace( - intersphinx_timeout=3, - tls_verify=True, - user_agent="python3:python-discord/bot:1.0.0" - ) -) - -NO_OVERRIDE_GROUPS = ( - "2to3fixer", - "token", - "label", - "pdbcommand", - "term", -) -NO_OVERRIDE_PACKAGES = ( - "python", -) - -SEARCH_END_TAG_ATTRS = ( - "data", - "function", - "class", - "exception", - "seealso", - "section", - "rubric", - "sphinxsidebar", -) -UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") -WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") - -FAILED_REQUEST_RETRY_AMOUNT = 3 -NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay - - -def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - # Assign the cache to the function itself so we can clear it from outside. - async_cache.cache = OrderedDict() - - def decorator(function: Callable) -> Callable: - """Define the async_cache decorator.""" - @functools.wraps(function) - async def wrapper(*args) -> Any: - """Decorator wrapper for the caching logic.""" - key = ':'.join(args[arg_offset:]) - - value = async_cache.cache.get(key) - if value is None: - if len(async_cache.cache) > max_size: - async_cache.cache.popitem(last=False) - - async_cache.cache[key] = await function(*args) - return async_cache.cache[key] - return wrapper - return decorator - - -class DocMarkdownConverter(MarkdownConverter): - """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" - - def convert_code(self, el: PageElement, text: str) -> str: - """Undo `markdownify`s underscore escaping.""" - return f"`{text}`".replace('\\', '') - - def convert_pre(self, el: PageElement, text: str) -> str: - """Wrap any codeblocks in `py` for syntax highlighting.""" - code = ''.join(el.strings) - return f"```py\n{code}```" - - -def markdownify(html: str) -> DocMarkdownConverter: - """Create a DocMarkdownConverter object from the input html.""" - return DocMarkdownConverter(bullets='•').convert(html) - - -class InventoryURL(commands.Converter): - """ - Represents an Intersphinx inventory URL. - - This converter checks whether intersphinx accepts the given inventory URL, and raises - `BadArgument` if that is not the case. - - Otherwise, it simply passes through the given URL. - """ - - @staticmethod - async def convert(ctx: commands.Context, url: str) -> str: - """Convert url to Intersphinx inventory URL.""" - try: - intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) - except AttributeError: - raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") - except ConnectionError: - if url.startswith('https'): - raise commands.BadArgument( - f"Cannot establish a connection to `{url}`. Does it support HTTPS?" - ) - raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") - except ValueError: - raise commands.BadArgument( - f"Failed to read Intersphinx inventory from URL `{url}`. " - "Are you sure that it's a valid inventory file?" - ) - return url - - -class Doc(commands.Cog): - """A set of commands for querying & displaying documentation.""" - - def __init__(self, bot: Bot): - self.base_urls = {} - self.bot = bot - self.inventories = {} - self.renamed_symbols = set() - - self.bot.loop.create_task(self.init_refresh_inventory()) - - async def init_refresh_inventory(self) -> None: - """Refresh documentation inventory on cog initialization.""" - await self.bot.wait_until_guild_available() - await self.refresh_inventory() - - async def update_single( - self, package_name: str, base_url: str, inventory_url: str - ) -> None: - """ - Rebuild the inventory for a single package. - - Where: - * `package_name` is the package name to use, appears in the log - * `base_url` is the root documentation URL for the specified package, used to build - absolute paths that link to specific symbols - * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running - `intersphinx.fetch_inventory` in an executor on the bot's event loop - """ - self.base_urls[package_name] = base_url - - package = await self._fetch_inventory(inventory_url) - if not package: - return None - - for group, value in package.items(): - for symbol, (package_name, _version, relative_doc_url, _) in value.items(): - absolute_doc_url = base_url + relative_doc_url - - if symbol in self.inventories: - group_name = group.split(":")[1] - symbol_base_url = self.inventories[symbol].split("/", 3)[2] - if ( - group_name in NO_OVERRIDE_GROUPS - or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) - ): - - symbol = f"{group_name}.{symbol}" - # If renamed `symbol` already exists, add library name in front to differentiate between them. - if symbol in self.renamed_symbols: - # Split `package_name` because of packages like Pillow that have spaces in them. - symbol = f"{package_name.split()[0]}.{symbol}" - - self.inventories[symbol] = absolute_doc_url - self.renamed_symbols.add(symbol) - continue - - self.inventories[symbol] = absolute_doc_url - - log.trace(f"Fetched inventory for {package_name}.") - - async def refresh_inventory(self) -> None: - """Refresh internal documentation inventory.""" - log.debug("Refreshing documentation inventory...") - - # Clear the old base URLS and inventories to ensure - # that we start from a fresh local dataset. - # Also, reset the cache used for fetching documentation. - self.base_urls.clear() - self.inventories.clear() - self.renamed_symbols.clear() - async_cache.cache = OrderedDict() - - # Run all coroutines concurrently - since each of them performs a HTTP - # request, this speeds up fetching the inventory data heavily. - coros = [ - self.update_single( - package["package"], package["base_url"], package["inventory_url"] - ) for package in await self.bot.api_client.get('bot/documentation-links') - ] - await asyncio.gather(*coros) - - async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: - """ - Given a Python symbol, return its signature and description. - - The first tuple element is the signature of the given symbol as a markup-free string, and - the second tuple element is the description of the given symbol with HTML markup included. - - If the given symbol is a module, returns a tuple `(None, str)` - else if the symbol could not be found, returns `None`. - """ - url = self.inventories.get(symbol) - if url is None: - return None - - async with self.bot.http_session.get(url) as response: - html = await response.text(encoding='utf-8') - - # Find the signature header and parse the relevant parts. - symbol_id = url.split('#')[-1] - soup = BeautifulSoup(html, 'lxml') - symbol_heading = soup.find(id=symbol_id) - search_html = str(soup) - - if symbol_heading is None: - return None - - if symbol_id == f"module-{symbol}": - # Get page content from the module headerlink to the - # first tag that has its class in `SEARCH_END_TAG_ATTRS` - start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) - if start_tag is None: - return [], "" - - end_tag = start_tag.find_next(self._match_end_tag) - if end_tag is None: - return [], "" - - description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) - description_end_index = search_html.find(str(end_tag)) - description = search_html[description_start_index:description_end_index] - signatures = None - - else: - signatures = [] - description = str(symbol_heading.find_next_sibling("dd")) - description_pos = search_html.find(description) - # Get text of up to 3 signatures, remove unwanted symbols - for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): - signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) - if signature and search_html.find(str(element)) < description_pos: - signatures.append(signature) - - return signatures, description.replace('¶', '') - - @async_cache(arg_offset=1) - async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: - """ - Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. - - If the symbol is known, an Embed with documentation about it is returned. - """ - scraped_html = await self.get_symbol_html(symbol) - if scraped_html is None: - return None - - signatures = scraped_html[0] - permalink = self.inventories[symbol] - description = markdownify(scraped_html[1]) - - # Truncate the description of the embed to the last occurrence - # of a double newline (interpreted as a paragraph) before index 1000. - if len(description) > 1000: - shortened = description[:1000] - description_cutoff = shortened.rfind('\n\n', 100) - if description_cutoff == -1: - # Search the shortened version for cutoff points in decreasing desirability, - # cutoff at 1000 if none are found. - for string in (". ", ", ", ",", " "): - description_cutoff = shortened.rfind(string) - if description_cutoff != -1: - break - else: - description_cutoff = 1000 - description = description[:description_cutoff] - - # If there is an incomplete code block, cut it out - if description.count("```") % 2: - codeblock_start = description.rfind('```py') - description = description[:codeblock_start].rstrip() - description += f"... [read more]({permalink})" - - description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if signatures is None: - # If symbol is a module, don't show signature. - embed_description = description - - elif not signatures: - # It's some "meta-page", for example: - # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - embed_description = "This appears to be a generic page not tied to a specific symbol." - - else: - embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) - embed_description += f"\n{description}" - - embed = discord.Embed( - title=f'`{symbol}`', - url=permalink, - description=embed_description - ) - # Show all symbols with the same name that were renamed in the footer. - embed.set_footer( - text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) - ) - return embed - - @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) - async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """Lookup documentation for Python symbols.""" - await ctx.invoke(self.get_command, symbol) - - @docs_group.command(name='get', aliases=('g',)) - async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: - """ - Return a documentation embed for a given symbol. - - If no symbol is given, return a list of all available inventories. - - Examples: - !docs - !docs aiohttp - !docs aiohttp.ClientSession - !docs get aiohttp.ClientSession - """ - if symbol is None: - inventory_embed = discord.Embed( - title=f"All inventories (`{len(self.base_urls)}` total)", - colour=discord.Colour.blue() - ) - - lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) - if self.base_urls: - await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) - - else: - inventory_embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=inventory_embed) - - else: - # Fetching documentation for a symbol (at least for the first time, since - # caching is used) takes quite some time, so let's send typing to indicate - # that we got the command, but are still working on it. - async with ctx.typing(): - doc_embed = await self.get_symbol_embed(symbol) - - if doc_embed is None: - error_embed = discord.Embed( - description=f"Sorry, I could not find any documentation for `{symbol}`.", - colour=discord.Colour.red() - ) - error_message = await ctx.send(embed=error_embed) - with suppress(NotFound): - await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) - await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) - else: - await ctx.send(embed=doc_embed) - - @docs_group.command(name='set', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def set_command( - self, ctx: commands.Context, package_name: ValidPythonIdentifier, - base_url: ValidURL, inventory_url: InventoryURL - ) -> None: - """ - Adds a new documentation metadata object to the site's database. - - The database will update the object, should an existing item with the specified `package_name` already exist. - - Example: - !docs set \ - python \ - https://docs.python.org/3/ \ - https://docs.python.org/3/objects.inv - """ - body = { - 'package': package_name, - 'base_url': base_url, - 'inventory_url': inventory_url - } - await self.bot.api_client.post('bot/documentation-links', json=body) - - log.info( - f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" - f"Package name: {package_name}\n" - f"Base url: {base_url}\n" - f"Inventory URL: {inventory_url}" - ) - - # Rebuilding the inventory can take some time, so lets send out a - # typing event to show that the Bot is still working. - async with ctx.typing(): - await self.refresh_inventory() - await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") - - @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: - """ - Removes the specified package from the database. - - Examples: - !docs delete aiohttp - """ - await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') - - async with ctx.typing(): - # Rebuild the inventory to ensure that everything - # that was from this package is properly deleted. - await self.refresh_inventory() - await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") - - @docs_group.command(name="refresh", aliases=("rfsh", "r")) - @with_role(*MODERATION_ROLES) - async def refresh_command(self, ctx: commands.Context) -> None: - """Refresh inventories and send differences to channel.""" - old_inventories = set(self.base_urls) - with ctx.typing(): - await self.refresh_inventory() - # Get differences of added and removed inventories - added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) - if added: - added = f"+ {added}" - - removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) - if removed: - removed = f"- {removed}" - - embed = discord.Embed( - title="Inventories refreshed", - description=f"```diff\n{added}\n{removed}```" if added or removed else "" - ) - await ctx.send(embed=embed) - - async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: - """Get and return inventory from `inventory_url`. If fetching fails, return None.""" - fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) - for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): - try: - package = await self.bot.loop.run_in_executor(None, fetch_func) - except ConnectTimeout: - log.error( - f"Fetching of inventory {inventory_url} timed out," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except ProtocolError: - log.error( - f"Connection lost while fetching inventory {inventory_url}," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" - ) - except HTTPError as e: - log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") - return None - except ConnectionError: - log.error(f"Couldn't establish connection to inventory {inventory_url}.") - return None - else: - return package - log.error(f"Fetching of inventory {inventory_url} failed.") - return None - - @staticmethod - def _match_end_tag(tag: Tag) -> bool: - """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" - for attr in SEARCH_END_TAG_ATTRS: - if attr in tag.get("class", ()): - return True - - return tag.name == "table" - - -def setup(bot: Bot) -> None: - """Load the Doc cog.""" - bot.add_cog(Doc(bot)) diff --git a/bot/cogs/info/help.py b/bot/cogs/info/help.py deleted file mode 100644 index 3d1d6fd10..000000000 --- a/bot/cogs/info/help.py +++ /dev/null @@ -1,375 +0,0 @@ -import itertools -import logging -from asyncio import TimeoutError -from collections import namedtuple -from contextlib import suppress -from typing import List, Union - -from discord import Colour, Embed, Member, Message, NotFound, Reaction, User -from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand -from fuzzywuzzy import fuzz, process -from fuzzywuzzy.utils import full_process - -from bot import constants -from bot.constants import Channels, Emojis, STAFF_ROLES -from bot.decorators import redirect_output -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -COMMANDS_PER_PAGE = 8 -DELETE_EMOJI = Emojis.trashcan -PREFIX = constants.Bot.prefix - -Category = namedtuple("Category", ["name", "description", "cogs"]) - - -async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: - """ - Runs the cleanup for the help command. - - Adds the :trashcan: reaction that, when clicked, will delete the help message. - After a 300 second timeout, the reaction will be removed. - """ - def check(reaction: Reaction, user: User) -> bool: - """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" - return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id - - await message.add_reaction(DELETE_EMOJI) - - with suppress(NotFound): - try: - await bot.wait_for("reaction_add", check=check, timeout=300) - await message.delete() - except TimeoutError: - await message.remove_reaction(DELETE_EMOJI, bot.user) - - -class HelpQueryNotFound(ValueError): - """ - Raised when a HelpSession Query doesn't match a command or cog. - - Contains the custom attribute of ``possible_matches``. - - Instances of this object contain a dictionary of any command(s) that were close to matching the - query, where keys are the possible matched command names and values are the likeness match scores. - """ - - def __init__(self, arg: str, possible_matches: dict = None): - super().__init__(arg) - self.possible_matches = possible_matches - - -class CustomHelpCommand(HelpCommand): - """ - An interactive instance for the bot help command. - - Cogs can be grouped into custom categories. All cogs with the same category will be displayed - under a single category name in the help output. Custom categories are defined inside the cogs - as a class attribute named `category`. A description can also be specified with the attribute - `category_description`. If a description is not found in at least one cog, the default will be - the regular description (class docstring) of the first cog found in the category. - """ - - def __init__(self): - super().__init__(command_attrs={"help": "Shows help for bot commands"}) - - @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) - async def command_callback(self, ctx: Context, *, command: str = None) -> None: - """Attempts to match the provided query with a valid command or cog.""" - # the only reason we need to tamper with this is because d.py does not support "categories", - # so we need to deal with them ourselves. - - bot = ctx.bot - - if command is None: - # quick and easy, send bot help if command is none - mapping = self.get_bot_mapping() - await self.send_bot_help(mapping) - return - - cog_matches = [] - description = None - for cog in bot.cogs.values(): - if hasattr(cog, "category") and cog.category == command: - cog_matches.append(cog) - if hasattr(cog, "category_description"): - description = cog.category_description - - if cog_matches: - category = Category(name=command, description=description, cogs=cog_matches) - await self.send_category_help(category) - return - - # it's either a cog, group, command or subcommand; let the parent class deal with it - await super().command_callback(ctx, command=command) - - async def get_all_help_choices(self) -> set: - """ - Get all the possible options for getting help in the bot. - - This will only display commands the author has permission to run. - - These include: - - Category names - - Cog names - - Group command names (and aliases) - - Command names (and aliases) - - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) - - Options and choices are case sensitive. - """ - # first get all commands including subcommands and full command name aliases - choices = set() - for command in await self.filter_commands(self.context.bot.walk_commands()): - # the the command or group name - choices.add(str(command)) - - if isinstance(command, Command): - # all aliases if it's just a command - choices.update(command.aliases) - else: - # otherwise we need to add the parent name in - choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) - - # all cog names - choices.update(self.context.bot.cogs) - - # all category names - choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) - return choices - - async def command_not_found(self, string: str) -> "HelpQueryNotFound": - """ - Handles when a query does not match a valid command, group, cog or category. - - Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. - """ - choices = await self.get_all_help_choices() - - # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty - # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters - if (processed := full_process(string)): - result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) - else: - result = [] - - return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) - - async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": - """ - Redirects the error to `command_not_found`. - - `command_not_found` deals with searching and getting best choices for both commands and subcommands. - """ - return await self.command_not_found(f"{command.qualified_name} {string}") - - async def send_error_message(self, error: HelpQueryNotFound) -> None: - """Send the error message to the channel.""" - embed = Embed(colour=Colour.red(), title=str(error)) - - if getattr(error, "possible_matches", None): - matches = "\n".join(f"`{match}`" for match in error.possible_matches) - embed.description = f"**Did you mean:**\n{matches}" - - await self.context.send(embed=embed) - - async def command_formatting(self, command: Command) -> Embed: - """ - Takes a command and turns it into an embed. - - It will add an author, command signature + help, aliases and a note if the user can't run the command. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - parent = command.full_parent_name - - name = str(command) if not parent else f"{parent} {command.name}" - command_details = f"**```{PREFIX}{name} {command.signature}```**\n" - - # show command aliases - aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) - if aliases: - command_details += f"**Can also use:** {aliases}\n\n" - - # check if the user is allowed to run this command - if not await command.can_run(self.context): - command_details += "***You cannot run this command.***\n\n" - - command_details += f"*{command.help or 'No details provided.'}*\n" - embed.description = command_details - - return embed - - async def send_command_help(self, command: Command) -> None: - """Send help for a single command.""" - embed = await self.command_formatting(command) - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: - """ - Formats the prefix, command name and signature, and short doc for an iterable of commands. - - return_as_list is helpful for passing these command details into the paginator as a list of command details. - """ - details = [] - for command in commands_: - signature = f" {command.signature}" if command.signature else "" - details.append( - f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" - ) - if return_as_list: - return details - else: - return "".join(details) - - async def send_group_help(self, group: Group) -> None: - """Sends help for a group command.""" - subcommands = group.commands - - if len(subcommands) == 0: - # no subcommands, just treat it like a regular command - await self.send_command_help(group) - return - - # remove commands that the user can't run and are hidden, and sort by name - commands_ = await self.filter_commands(subcommands, sort=True) - - embed = await self.command_formatting(group) - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n**Subcommands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - async def send_cog_help(self, cog: Cog) -> None: - """Send help for a cog.""" - # sort commands by name, and remove any the user cant run or are hidden. - commands_ = await self.filter_commands(cog.get_commands(), sort=True) - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" - - command_details = self.get_commands_brief_details(commands_) - if command_details: - embed.description += f"\n\n**Commands:**\n{command_details}" - - message = await self.context.send(embed=embed) - await help_cleanup(self.context.bot, self.context.author, message) - - @staticmethod - def _category_key(command: Command) -> str: - """ - Returns a cog name of a given command for use as a key for `sorted` and `groupby`. - - A zero width space is used as a prefix for results with no cogs to force them last in ordering. - """ - if command.cog: - with suppress(AttributeError): - if command.cog.category: - return f"**{command.cog.category}**" - return f"**{command.cog_name}**" - else: - return "**\u200bNo Category:**" - - async def send_category_help(self, category: Category) -> None: - """ - Sends help for a bot category. - - This sends a brief help for all commands in all cogs registered to the category. - """ - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - all_commands = [] - for cog in category.cogs: - all_commands.extend(cog.get_commands()) - - filtered_commands = await self.filter_commands(all_commands, sort=True) - - command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) - description = f"**{category.name}**\n*{category.description}*" - - if command_detail_lines: - description += "\n\n**Commands:**" - - await LinePaginator.paginate( - command_detail_lines, - self.context, - embed, - prefix=description, - max_lines=COMMANDS_PER_PAGE, - max_size=2000, - ) - - async def send_bot_help(self, mapping: dict) -> None: - """Sends help for all bot commands and cogs.""" - bot = self.context.bot - - embed = Embed() - embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) - - filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) - - cog_or_category_pages = [] - - for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): - sorted_commands = sorted(_commands, key=lambda c: c.name) - - if len(sorted_commands) == 0: - continue - - command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) - - # Split cogs or categories which have too many commands to fit in one page. - # The length of commands is included for later use when aggregating into pages for the paginator. - for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): - truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] - joined_lines = "".join(truncated_lines) - cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) - - pages = [] - counter = 0 - page = "" - for page_details, length in cog_or_category_pages: - counter += length - if counter > COMMANDS_PER_PAGE: - # force a new page on paginator even if it falls short of the max pages - # since we still want to group categories/cogs. - counter = length - pages.append(page) - page = f"{page_details}\n\n" - else: - page += f"{page_details}\n\n" - - if page: - # add any remaining command help that didn't get added in the last iteration above. - pages.append(page) - - await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) - - -class Help(Cog): - """Custom Embed Pagination Help feature.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - self.old_help_command = bot.help_command - bot.help_command = CustomHelpCommand() - bot.help_command.cog = self - - def cog_unload(self) -> None: - """Reset the help command when the cog is unloaded.""" - self.bot.help_command = self.old_help_command - - -def setup(bot: Bot) -> None: - """Load the Help cog.""" - bot.add_cog(Help(bot)) - log.info("Cog loaded: Help") diff --git a/bot/cogs/info/information.py b/bot/cogs/info/information.py deleted file mode 100644 index 8982196d1..000000000 --- a/bot/cogs/info/information.py +++ /dev/null @@ -1,422 +0,0 @@ -import colorsys -import logging -import pprint -import textwrap -from collections import Counter, defaultdict -from string import Template -from typing import Any, Mapping, Optional, Union - -from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils -from discord.abc import GuildChannel -from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group -from discord.utils import escape_markdown - -from bot import constants -from bot.bot import Bot -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - - -class Information(Cog): - """A cog with commands for generating embeds with server info, such as server stats and user info.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @staticmethod - def role_can_read(channel: GuildChannel, role: Role) -> bool: - """Return True if `role` can read messages in `channel`.""" - overwrites = channel.overwrites_for(role) - return overwrites.read_messages is True - - def get_staff_channel_count(self, guild: Guild) -> int: - """ - Get the number of channels that are staff-only. - - We need to know two things about a channel: - - Does the @everyone role have explicit read deny permissions? - - Do staff roles have explicit read allow permissions? - - If the answer to both of these questions is yes, it's a staff channel. - """ - channel_ids = set() - for channel in guild.channels: - if channel.type is ChannelType.category: - continue - - everyone_can_read = self.role_can_read(channel, guild.default_role) - - for role in constants.STAFF_ROLES: - role_can_read = self.role_can_read(channel, guild.get_role(role)) - if role_can_read and not everyone_can_read: - channel_ids.add(channel.id) - break - - return len(channel_ids) - - @staticmethod - def get_channel_type_counts(guild: Guild) -> str: - """Return the total amounts of the various types of channels in `guild`.""" - channel_counter = Counter(c.type for c in guild.channels) - channel_type_list = [] - for channel, count in channel_counter.items(): - channel_type = str(channel).title() - channel_type_list.append(f"{channel_type} channels: {count}") - - channel_type_list = sorted(channel_type_list) - return "\n".join(channel_type_list) - - @with_role(*constants.MODERATION_ROLES) - @command(name="roles") - async def roles_info(self, ctx: Context) -> None: - """Returns a list of all roles and their corresponding IDs.""" - # Sort the roles alphabetically and remove the @everyone role - roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) - - # Build a list - role_list = [] - for role in roles: - role_list.append(f"`{role.id}` - {role.mention}") - - # Build an embed - embed = Embed( - title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", - colour=Colour.blurple() - ) - - await LinePaginator.paginate(role_list, ctx, embed, empty=False) - - @with_role(*constants.MODERATION_ROLES) - @command(name="role") - async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: - """ - Return information on a role or list of roles. - - To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. - """ - parsed_roles = [] - failed_roles = [] - - for role_name in roles: - if isinstance(role_name, Role): - # Role conversion has already succeeded - parsed_roles.append(role_name) - continue - - role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) - - if not role: - failed_roles.append(role_name) - continue - - parsed_roles.append(role) - - if failed_roles: - await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") - - for role in parsed_roles: - h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) - - embed = Embed( - title=f"{role.name} info", - colour=role.colour, - ) - embed.add_field(name="ID", value=role.id, inline=True) - embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) - embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) - embed.add_field(name="Member count", value=len(role.members), inline=True) - embed.add_field(name="Position", value=role.position) - embed.add_field(name="Permission code", value=role.permissions.value, inline=True) - - await ctx.send(embed=embed) - - @command(name="server", aliases=["server_info", "guild", "guild_info"]) - async def server_info(self, ctx: Context) -> None: - """Returns an embed full of server information.""" - created = time_since(ctx.guild.created_at, precision="days") - features = ", ".join(ctx.guild.features) - region = ctx.guild.region - - roles = len(ctx.guild.roles) - member_count = ctx.guild.member_count - channel_counts = self.get_channel_type_counts(ctx.guild) - - # How many of each user status? - statuses = Counter(member.status for member in ctx.guild.members) - embed = Embed(colour=Colour.blurple()) - - # How many staff members and staff channels do we have? - staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) - staff_channel_count = self.get_staff_channel_count(ctx.guild) - - # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the - # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting - # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts - # after the dedent is made. - embed.description = Template( - textwrap.dedent(f""" - **Server information** - Created: {created} - Voice region: {region} - Features: {features} - - **Channel counts** - $channel_counts - Staff channels: {staff_channel_count} - - **Member counts** - Members: {member_count:,} - Staff members: {staff_member_count} - Roles: {roles} - - **Member statuses** - {constants.Emojis.status_online} {statuses[Status.online]:,} - {constants.Emojis.status_idle} {statuses[Status.idle]:,} - {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} - {constants.Emojis.status_offline} {statuses[Status.offline]:,} - """) - ).substitute({"channel_counts": channel_counts}) - embed.set_thumbnail(url=ctx.guild.icon_url) - - await ctx.send(embed=embed) - - @command(name="user", aliases=["user_info", "member", "member_info"]) - async def user_info(self, ctx: Context, user: Member = None) -> None: - """Returns info about a user.""" - if user is None: - user = ctx.author - - # Do a role check if this is being executed on someone other than the caller - elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): - await ctx.send("You may not use this command on users other than yourself.") - return - - # Non-staff may only do this in #bot-commands - if not with_role_check(ctx, *constants.STAFF_ROLES): - if not ctx.channel.id == constants.Channels.bot_commands: - raise InWhitelistCheckFailure(constants.Channels.bot_commands) - - embed = await self.create_user_embed(ctx, user) - - await ctx.send(embed=embed) - - async def create_user_embed(self, ctx: Context, user: Member) -> Embed: - """Creates an embed containing information on the `user`.""" - created = time_since(user.created_at, max_units=3) - - # Custom status - custom_status = '' - for activity in user.activities: - # Check activity.state for None value if user has a custom status set - # This guards against a custom status with an emoji but no text, which will cause - # escape_markdown to raise an exception - # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class - if activity.name == 'Custom Status' and activity.state: - state = escape_markdown(activity.state) - custom_status = f'Status: {state}\n' - - name = str(user) - if user.nick: - name = f"{user.nick} ({name})" - - joined = time_since(user.joined_at, max_units=3) - roles = ", ".join(role.mention for role in user.roles[1:]) - - description = [ - textwrap.dedent(f""" - **User Information** - Created: {created} - Profile: {user.mention} - ID: {user.id} - {custom_status} - **Member Information** - Joined: {joined} - Roles: {roles or None} - """).strip() - ] - - # Show more verbose output in moderation channels for infractions and nominations - if ctx.channel.id in constants.MODERATION_CHANNELS: - description.append(await self.expanded_user_infraction_counts(user)) - description.append(await self.user_nomination_counts(user)) - else: - description.append(await self.basic_user_infraction_counts(user)) - - # Let's build the embed now - embed = Embed( - title=name, - description="\n\n".join(description) - ) - - embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) - embed.colour = user.top_role.colour if roles else Colour.blurple() - - return embed - - async def basic_user_infraction_counts(self, member: Member) -> str: - """Gets the total and active infraction counts for the given `member`.""" - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'hidden': 'False', - 'user__id': str(member.id) - } - ) - - total_infractions = len(infractions) - active_infractions = sum(infraction['active'] for infraction in infractions) - - infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" - - return infraction_output - - async def expanded_user_infraction_counts(self, member: Member) -> str: - """ - Gets expanded infraction counts for the given `member`. - - The counts will be split by infraction type and the number of active infractions for each type will indicated - in the output as well. - """ - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'user__id': str(member.id) - } - ) - - infraction_output = ["**Infractions**"] - if not infractions: - infraction_output.append("This user has never received an infraction.") - else: - # Count infractions split by `type` and `active` status for this user - infraction_types = set() - infraction_counter = defaultdict(int) - for infraction in infractions: - infraction_type = infraction["type"] - infraction_active = 'active' if infraction["active"] else 'inactive' - - infraction_types.add(infraction_type) - infraction_counter[f"{infraction_active} {infraction_type}"] += 1 - - # Format the output of the infraction counts - for infraction_type in sorted(infraction_types): - active_count = infraction_counter[f"active {infraction_type}"] - total_count = active_count + infraction_counter[f"inactive {infraction_type}"] - - line = f"{infraction_type.capitalize()}s: {total_count}" - if active_count: - line += f" ({active_count} active)" - - infraction_output.append(line) - - return "\n".join(infraction_output) - - async def user_nomination_counts(self, member: Member) -> str: - """Gets the active and historical nomination counts for the given `member`.""" - nominations = await self.bot.api_client.get( - 'bot/nominations', - params={ - 'user__id': str(member.id) - } - ) - - output = ["**Nominations**"] - - if not nominations: - output.append("This user has never been nominated.") - else: - count = len(nominations) - is_currently_nominated = any(nomination["active"] for nomination in nominations) - nomination_noun = "nomination" if count == 1 else "nominations" - - if is_currently_nominated: - output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") - else: - output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") - - return "\n".join(output) - - def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: - """Format a mapping to be readable to a human.""" - # sorting is technically superfluous but nice if you want to look for a specific field - fields = sorted(mapping.items(), key=lambda item: item[0]) - - if field_width is None: - field_width = len(max(mapping.keys(), key=len)) - - out = '' - - for key, val in fields: - if isinstance(val, dict): - # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries - inner_width = int(field_width * 1.6) - val = '\n' + self.format_fields(val, field_width=inner_width) - - elif isinstance(val, str): - # split up text since it might be long - text = textwrap.fill(val, width=100, replace_whitespace=False) - - # indent it, I guess you could do this with `wrap` and `join` but this is nicer - val = textwrap.indent(text, ' ' * (field_width + len(': '))) - - # the first line is already indented so we `str.lstrip` it - val = val.lstrip() - - if key == 'color': - # makes the base 10 representation of a hex number readable to humans - val = hex(val) - - out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) - - # remove trailing whitespace - return out.rstrip() - - @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) - @group(invoke_without_command=True) - @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) - async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: - """Shows information about the raw API response.""" - # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling - # doing this extra request is also much easier than trying to convert everything back into a dictionary again - raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) - - paginator = Paginator() - - def add_content(title: str, content: str) -> None: - paginator.add_line(f'== {title} ==\n') - # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. - # we hope it's not close to 2000 - paginator.add_line(content.replace('```', '`` `')) - paginator.close_page() - - if message.content: - add_content('Raw message', message.content) - - transformer = pprint.pformat if json else self.format_fields - for field_name in ('embeds', 'attachments'): - data = raw_data[field_name] - - if not data: - continue - - total = len(data) - for current, item in enumerate(data, start=1): - title = f'Raw {field_name} ({current}/{total})' - add_content(title, transformer(item)) - - for page in paginator.pages: - await ctx.send(page) - - @raw.command() - async def json(self, ctx: Context, message: Message) -> None: - """Shows information about the raw API response in a copy-pasteable Python format.""" - await ctx.invoke(self.raw, message=message, json=True) - - -def setup(bot: Bot) -> None: - """Load the Information cog.""" - bot.add_cog(Information(bot)) diff --git a/bot/cogs/info/python_news.py b/bot/cogs/info/python_news.py deleted file mode 100644 index 0ab5738a4..000000000 --- a/bot/cogs/info/python_news.py +++ /dev/null @@ -1,232 +0,0 @@ -import logging -import typing as t -from datetime import date, datetime - -import discord -import feedparser -from bs4 import BeautifulSoup -from discord.ext.commands import Cog -from discord.ext.tasks import loop - -from bot import constants -from bot.bot import Bot -from bot.utils.webhooks import send_webhook - -PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" - -RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" -THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" -MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" -THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" - -AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - -log = logging.getLogger(__name__) - - -class PythonNews(Cog): - """Post new PEPs and Python News to `#python-news`.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_names = {} - self.webhook: t.Optional[discord.Webhook] = None - - self.bot.loop.create_task(self.get_webhook_names()) - self.bot.loop.create_task(self.get_webhook_and_channel()) - - async def start_tasks(self) -> None: - """Start the tasks for fetching new PEPs and mailing list messages.""" - self.fetch_new_media.start() - - @loop(minutes=20) - async def fetch_new_media(self) -> None: - """Fetch new mailing list messages and then new PEPs.""" - await self.post_maillist_news() - await self.post_pep_news() - - async def sync_maillists(self) -> None: - """Sync currently in-use maillists with API.""" - # Wait until guild is available to avoid running before everything is ready - await self.bot.wait_until_guild_available() - - response = await self.bot.api_client.get("bot/bot-settings/news") - for mail in constants.PythonNews.mail_lists: - if mail not in response["data"]: - response["data"][mail] = [] - - # Because we are handling PEPs differently, we don't include it to mail lists - if "pep" not in response["data"]: - response["data"]["pep"] = [] - - await self.bot.api_client.put("bot/bot-settings/news", json=response) - - async def get_webhook_names(self) -> None: - """Get webhook author names from maillist API.""" - await self.bot.wait_until_guild_available() - - async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: - lists = await resp.json() - - for mail in lists: - if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: - self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] - - async def post_pep_news(self) -> None: - """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" - # Wait until everything is ready and http_session available - await self.bot.wait_until_guild_available() - await self.sync_maillists() - - async with self.bot.http_session.get(PEPS_RSS_URL) as resp: - data = feedparser.parse(await resp.text("utf-8")) - - news_listing = await self.bot.api_client.get("bot/bot-settings/news") - payload = news_listing.copy() - pep_numbers = news_listing["data"]["pep"] - - # Reverse entries to send oldest first - data["entries"].reverse() - for new in data["entries"]: - try: - new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") - except ValueError: - log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") - continue - pep_nr = new["title"].split(":")[0].split()[1] - if ( - pep_nr in pep_numbers - or new_datetime.date() < date.today() - ): - continue - - # Build an embed and send a webhook - embed = discord.Embed( - title=new["title"], - description=new["summary"], - timestamp=new_datetime, - url=new["link"], - colour=constants.Colours.soft_green - ) - embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) - msg = await send_webhook( - webhook=self.webhook, - username=data["feed"]["title"], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"]["pep"].append(pep_nr) - - # Increase overall PEP new stat - self.bot.stats.incr("python_news.posted.pep") - - if msg.channel.is_news(): - log.trace("Publishing PEP annnouncement because it was in a news channel") - await msg.publish() - - # Apply new sent news to DB to avoid duplicate sending - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def post_maillist_news(self) -> None: - """Send new maillist threads to #python-news that is listed in configuration.""" - await self.bot.wait_until_guild_available() - await self.sync_maillists() - existing_news = await self.bot.api_client.get("bot/bot-settings/news") - payload = existing_news.copy() - - for maillist in constants.PythonNews.mail_lists: - async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: - recents = BeautifulSoup(await resp.text(), features="lxml") - - # When a

element is present in the response then the mailing list - # has not had any activity during the current month, so therefore it - # can be ignored. - if recents.p: - continue - - for thread in recents.html.body.div.find_all("a", href=True): - # We want only these threads that have identifiers - if "latest" in thread["href"]: - continue - - thread_information, email_information = await self.get_thread_and_first_mail( - maillist, thread["href"].split("/")[-2] - ) - - try: - new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") - except ValueError: - log.warning(f"Invalid datetime from Thread email: {email_information['date']}") - continue - - if ( - thread_information["thread_id"] in existing_news["data"][maillist] - or 'Re: ' in thread_information["subject"] - or new_date.date() < date.today() - ): - continue - - content = email_information["content"] - link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) - - # Build an embed and send a message to the webhook - embed = discord.Embed( - title=thread_information["subject"], - description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, - timestamp=new_date, - url=link, - colour=constants.Colours.soft_green - ) - embed.set_author( - name=f"{email_information['sender_name']} ({email_information['sender']['address']})", - url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), - ) - embed.set_footer( - text=f"Posted to {self.webhook_names[maillist]}", - icon_url=AVATAR_URL, - ) - msg = await send_webhook( - webhook=self.webhook, - username=self.webhook_names[maillist], - embed=embed, - avatar_url=AVATAR_URL, - wait=True, - ) - payload["data"][maillist].append(thread_information["thread_id"]) - - # Increase this specific maillist counter in stats - self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") - - if msg.channel.is_news(): - log.trace("Publishing mailing list message because it was in a news channel") - await msg.publish() - - await self.bot.api_client.put("bot/bot-settings/news", json=payload) - - async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: - """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" - async with self.bot.http_session.get( - THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) - ) as resp: - thread_information = await resp.json() - - async with self.bot.http_session.get(thread_information["starting_email"]) as resp: - email_information = await resp.json() - return thread_information, email_information - - async def get_webhook_and_channel(self) -> None: - """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" - await self.bot.wait_until_guild_available() - self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) - - await self.start_tasks() - - def cog_unload(self) -> None: - """Stop news posting tasks on cog unload.""" - self.fetch_new_media.cancel() - - -def setup(bot: Bot) -> None: - """Add `News` cog.""" - bot.add_cog(PythonNews(bot)) diff --git a/bot/cogs/info/reddit.py b/bot/cogs/info/reddit.py deleted file mode 100644 index d853ab2ea..000000000 --- a/bot/cogs/info/reddit.py +++ /dev/null @@ -1,304 +0,0 @@ -import asyncio -import logging -import random -import textwrap -from collections import namedtuple -from datetime import datetime, timedelta -from typing import List - -from aiohttp import BasicAuth, ClientError -from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks -from bot.converters import Subreddit -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - - -class Reddit(Cog): - """Track subreddit posts and show detailed statistics about them.""" - - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} - URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 - - def __init__(self, bot: Bot): - self.bot = bot - - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - - bot.loop.create_task(self.init_reddit_ready()) - self.auto_poster_loop.start() - - def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" - self.auto_poster_loop.cancel() - if self.access_token and self.access_token.expires_at > datetime.utcnow(): - asyncio.create_task(self.revoke_access_token()) - - async def init_reddit_ready(self) -> None: - """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_guild_available() - if not self.webhook: - self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) - - @property - def channel(self) -> TextChannel: - """Get the #reddit channel object from the bot's cache.""" - return self.bot.get_channel(Channels.reddit) - - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: - """A helper method to fetch a certain amount of Reddit posts at a given route.""" - # Reddit's JSON responses only provide 25 posts at most. - if not 25 >= amount > 0: - raise ValueError("Invalid amount of subreddit posts requested.") - - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): - response = await self.bot.http_session.get( - url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, - params=params - ) - if response.status == 200 and response.content_type == 'application/json': - # Got appropriate response - process and return. - content = await response.json() - posts = content["data"]["children"] - return posts[:amount] - - await asyncio.sleep(3) - - log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") - return list() # Failed to get appropriate response within allowed number of retries. - - async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: - """ - Get the top amount of posts for a given subreddit within a specified timeframe. - - A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top - weekly posts. - - The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. - """ - embed = Embed(description="") - - posts = await self.fetch_posts( - route=f"{subreddit}/top", - amount=amount, - params={"t": time} - ) - - if not posts: - embed.title = random.choice(ERROR_REPLIES) - embed.colour = Colour.red() - embed.description = ( - "Sorry! We couldn't find any posts from that subreddit. " - "If this problem persists, please let us know." - ) - - return embed - - for post in posts: - data = post["data"] - - text = data["selftext"] - if text: - text = textwrap.shorten(text, width=128, placeholder="...") - text += "\n" # Add newline to separate embed info - - ups = data["ups"] - comments = data["num_comments"] - author = data["author"] - - title = textwrap.shorten(data["title"], width=64, placeholder="...") - link = self.URL + data["permalink"] - - embed.description += ( - f"**[{title}]({link})**\n" - f"{text}" - f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" - ) - - embed.colour = Colour.blurple() - return embed - - @loop() - async def auto_poster_loop(self) -> None: - """Post the top 5 posts daily, and the top 5 posts weekly.""" - # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter - now = datetime.utcnow() - tomorrow = now + timedelta(days=1) - midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) - seconds_until = (midnight_tomorrow - now).total_seconds() - - await asyncio.sleep(seconds_until) - - await self.bot.wait_until_guild_available() - if not self.webhook: - await self.bot.fetch_webhook(Webhooks.reddit) - - if datetime.utcnow().weekday() == 0: - await self.top_weekly_posts() - # if it's a monday send the top weekly posts - - for subreddit in RedditConfig.subreddits: - top_posts = await self.get_top_posts(subreddit=subreddit, time="day") - username = sub_clyde(f"{subreddit} Top Daily Posts") - message = await self.webhook.send(username=username, embed=top_posts, wait=True) - - if message.channel.is_news(): - await message.publish() - - async def top_weekly_posts(self) -> None: - """Post a summary of the top posts.""" - for subreddit in RedditConfig.subreddits: - # Send and pin the new weekly posts. - top_posts = await self.get_top_posts(subreddit=subreddit, time="week") - username = sub_clyde(f"{subreddit} Top Weekly Posts") - message = await self.webhook.send(wait=True, username=username, embed=top_posts) - - if subreddit.lower() == "r/python": - if not self.channel: - log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") - return - - # Remove the oldest pins so that only 12 remain at most. - pins = await self.channel.pins() - - while len(pins) >= 12: - await pins[-1].unpin() - del pins[-1] - - await message.pin() - - if message.channel.is_news(): - await message.publish() - - @group(name="reddit", invoke_without_command=True) - async def reddit_group(self, ctx: Context) -> None: - """View the top posts from various subreddits.""" - await ctx.send_help(ctx.command) - - @reddit_group.command(name="top") - async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of all time from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="all") - - await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - - @reddit_group.command(name="daily") - async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of today from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="day") - - await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - - @reddit_group.command(name="weekly") - async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of this week from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="week") - - await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - - @with_role(*STAFF_ROLES) - @reddit_group.command(name="subreddits", aliases=("subs",)) - async def subreddits_command(self, ctx: Context) -> None: - """Send a paginated embed of all the subreddits we're relaying.""" - embed = Embed() - embed.title = "Relayed subreddits." - embed.colour = Colour.blurple() - - await LinePaginator.paginate( - RedditConfig.subreddits, - ctx, embed, - footer_text="Use the reddit commands along with these to view their posts.", - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Reddit cog.""" - if not RedditConfig.secret or not RedditConfig.client_id: - log.error("Credentials not provided, cog not loaded.") - return - bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/info/site.py b/bot/cogs/info/site.py deleted file mode 100644 index ac29daa1d..000000000 --- a/bot/cogs/info/site.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import URLs -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" - - -class Site(Cog): - """Commands for linking to different parts of the site.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="site", aliases=("s",), invoke_without_command=True) - async def site_group(self, ctx: Context) -> None: - """Commands for getting info about our website.""" - await ctx.send_help(ctx.command) - - @site_group.command(name="home", aliases=("about",)) - async def site_main(self, ctx: Context) -> None: - """Info about the website itself.""" - url = f"{URLs.site_schema}{URLs.site}/" - - embed = Embed(title="Python Discord website") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - f"[Our official website]({url}) is an open-source community project " - "created with Python and Django. It contains information about the server " - "itself, lets you sign up for upcoming events, has its own wiki, contains " - "a list of valuable learning resources, and much more." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="resources") - async def site_resources(self, ctx: Context) -> None: - """Info about the site's Resources page.""" - learning_url = f"{PAGES_URL}/resources" - - embed = Embed(title="Resources") - embed.set_footer(text=f"{learning_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Resources page]({learning_url}) on our website contains a " - "list of hand-selected learning resources that we regularly recommend " - f"to both beginners and experts." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="tools") - async def site_tools(self, ctx: Context) -> None: - """Info about the site's Tools page.""" - tools_url = f"{PAGES_URL}/resources/tools" - - embed = Embed(title="Tools") - embed.set_footer(text=f"{tools_url}") - embed.colour = Colour.blurple() - embed.description = ( - f"The [Tools page]({tools_url}) on our website contains a " - f"couple of the most popular tools for programming in Python." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="help") - async def site_help(self, ctx: Context) -> None: - """Info about the site's Getting Help page.""" - url = f"{PAGES_URL}/resources/guides/asking-good-questions" - - embed = Embed(title="Asking Good Questions") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "Asking the right question about something that's new to you can sometimes be tricky. " - f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " - "It contains everything you need to get the very best help from our community." - ) - - await ctx.send(embed=embed) - - @site_group.command(name="faq") - async def site_faq(self, ctx: Context) -> None: - """Info about the site's FAQ page.""" - url = f"{PAGES_URL}/frequently-asked-questions" - - embed = Embed(title="FAQ") - embed.set_footer(text=url) - embed.colour = Colour.blurple() - embed.description = ( - "As the largest Python community on Discord, we get hundreds of questions every day. " - "Many of these questions have been asked before. We've compiled a list of the most " - "frequently asked questions along with their answers, which can be found on " - f"our [FAQ page]({url})." - ) - - await ctx.send(embed=embed) - - @site_group.command(aliases=['r', 'rule'], name='rules') - async def site_rules(self, ctx: Context, *rules: int) -> None: - """Provides a link to all rules or, if specified, displays specific rule(s).""" - rules_embed = Embed(title='Rules', color=Colour.blurple()) - rules_embed.url = f"{PAGES_URL}/rules" - - if not rules: - # Rules were not submitted. Return the default description. - rules_embed.description = ( - "The rules and guidelines that apply to this community can be found on" - f" our [rules page]({PAGES_URL}/rules). We expect" - " all members of the community to have read and understood these." - ) - - await ctx.send(embed=rules_embed) - return - - full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) - invalid_indices = tuple( - pick - for pick in rules - if pick < 1 or pick > len(full_rules) - ) - - if invalid_indices: - indices = ', '.join(map(str, invalid_indices)) - await ctx.send(f":x: Invalid rule indices: {indices}") - return - - for rule in rules: - self.bot.stats.incr(f"rule_uses.{rule}") - - final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) - - await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) - - -def setup(bot: Bot) -> None: - """Load the Site cog.""" - bot.add_cog(Site(bot)) diff --git a/bot/cogs/info/source.py b/bot/cogs/info/source.py deleted file mode 100644 index 205e0ba81..000000000 --- a/bot/cogs/info/source.py +++ /dev/null @@ -1,141 +0,0 @@ -import inspect -from pathlib import Path -from typing import Optional, Tuple, Union - -from discord import Embed -from discord.ext import commands - -from bot.bot import Bot -from bot.constants import URLs - -SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] - - -class SourceConverter(commands.Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - async def convert(self, ctx: commands.Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower().startswith("help"): - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - - raise commands.BadArgument( - f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." - ) - - -class BotSource(commands.Cog): - """Displays information about the bot's source code.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command(name="source", aliases=("src",)) - async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: - """Display information and a GitHub link to the source code of a command, tag, or cog.""" - if not source_item: - embed = Embed(title="Bot's GitHub Repository") - embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") - embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") - await ctx.send(embed=embed) - return - - embed = await self.build_embed(source_item) - await ctx.send(embed=embed) - - def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: - """ - Build GitHub link of source item, return this link, file location and first line number. - - Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). - """ - if isinstance(source_item, commands.Command): - if source_item.cog_name == "Alias": - cmd_name = source_item.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - src = cmd.callback.__code__ - filename = src.co_filename - else: - src = source_item.callback.__code__ - filename = src.co_filename - elif isinstance(source_item, str): - tags_cog = self.bot.get_cog("Tags") - filename = tags_cog._cache[source_item]["location"] - else: - src = type(source_item) - try: - filename = inspect.getsourcefile(src) - except TypeError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - if not isinstance(source_item, str): - try: - lines, first_line_no = inspect.getsourcelines(src) - except OSError: - raise commands.BadArgument("Cannot get source for a dynamically-created object.") - - lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" - else: - first_line_no = None - lines_extension = "" - - # Handle tag file location differently than others to avoid errors in some cases - if not first_line_no: - file_location = Path(filename).relative_to("/bot/") - else: - file_location = Path(filename).relative_to(Path.cwd()).as_posix() - - url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" - - return url, file_location, first_line_no or None - - async def build_embed(self, source_object: SourceType) -> Optional[Embed]: - """Build embed based on source object.""" - url, location, first_line = self.get_source_link(source_object) - - if isinstance(source_object, commands.HelpCommand): - title = "Help Command" - description = source_object.__doc__.splitlines()[1] - elif isinstance(source_object, commands.Command): - if source_object.cog_name == "Alias": - cmd_name = source_object.callback.__name__.replace("_alias", "") - cmd = self.bot.get_command(cmd_name.replace("_", " ")) - description = cmd.short_doc - else: - description = source_object.short_doc - - title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, str): - title = f"Tag: {source_object}" - description = "" - else: - title = f"Cog: {source_object.qualified_name}" - description = source_object.description.splitlines()[0] - - embed = Embed(title=title, description=description) - embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") - line_text = f":{first_line}" if first_line else "" - embed.set_footer(text=f"{location}{line_text}") - - return embed - - -def setup(bot: Bot) -> None: - """Load the BotSource cog.""" - bot.add_cog(BotSource(bot)) diff --git a/bot/cogs/info/stats.py b/bot/cogs/info/stats.py deleted file mode 100644 index d42f55466..000000000 --- a/bot/cogs/info/stats.py +++ /dev/null @@ -1,129 +0,0 @@ -import string -from datetime import datetime - -from discord import Member, Message, Status -from discord.ext.commands import Cog, Context -from discord.ext.tasks import loop - -from bot.bot import Bot -from bot.constants import Categories, Channels, Guild, Stats as StatConf - - -CHANNEL_NAME_OVERRIDES = { - Channels.off_topic_0: "off_topic_0", - Channels.off_topic_1: "off_topic_1", - Channels.off_topic_2: "off_topic_2", - Channels.staff_lounge: "staff_lounge" -} - -ALLOWED_CHARS = string.ascii_letters + string.digits + "_" - - -class Stats(Cog): - """A cog which provides a way to hook onto Discord events and forward to stats.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.last_presence_update = None - self.update_guild_boost.start() - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Report message events in the server to statsd.""" - if message.guild is None: - return - - if message.guild.id != Guild.id: - return - - cat = getattr(message.channel, "category", None) - if cat is not None and cat.id == Categories.modmail: - if message.channel.id != Channels.incidents: - # Do not report modmail channels to stats, there are too many - # of them for interesting statistics to be drawn out of this. - return - - reformatted_name = message.channel.name.replace('-', '_') - - if CHANNEL_NAME_OVERRIDES.get(message.channel.id): - reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) - - reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) - - stat_name = f"channels.{reformatted_name}" - self.bot.stats.incr(stat_name) - - # Increment the total message count - self.bot.stats.incr("messages") - - @Cog.listener() - async def on_command_completion(self, ctx: Context) -> None: - """Report completed commands to statsd.""" - command_name = ctx.command.qualified_name.replace(" ", "_") - - self.bot.stats.incr(f"commands.{command_name}") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Update member count stat on member join.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_leave(self, member: Member) -> None: - """Update member count stat on member leave.""" - if member.guild.id != Guild.id: - return - - self.bot.stats.gauge("guild.total_members", len(member.guild.members)) - - @Cog.listener() - async def on_member_update(self, _before: Member, after: Member) -> None: - """Update presence estimates on member update.""" - if after.guild.id != Guild.id: - return - - if self.last_presence_update: - if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: - return - - self.last_presence_update = datetime.now() - - online = 0 - idle = 0 - dnd = 0 - offline = 0 - - for member in after.guild.members: - if member.status is Status.online: - online += 1 - elif member.status is Status.dnd: - dnd += 1 - elif member.status is Status.idle: - idle += 1 - elif member.status is Status.offline: - offline += 1 - - self.bot.stats.gauge("guild.status.online", online) - self.bot.stats.gauge("guild.status.idle", idle) - self.bot.stats.gauge("guild.status.do_not_disturb", dnd) - self.bot.stats.gauge("guild.status.offline", offline) - - @loop(hours=1) - async def update_guild_boost(self) -> None: - """Post the server boost level and tier every hour.""" - await self.bot.wait_until_guild_available() - g = self.bot.get_guild(Guild.id) - self.bot.stats.gauge("boost.amount", g.premium_subscription_count) - self.bot.stats.gauge("boost.tier", g.premium_tier) - - def cog_unload(self) -> None: - """Stop the boost statistic task on unload of the Cog.""" - self.update_guild_boost.stop() - - -def setup(bot: Bot) -> None: - """Load the stats cog.""" - bot.add_cog(Stats(bot)) diff --git a/bot/cogs/info/tags.py b/bot/cogs/info/tags.py deleted file mode 100644 index 3d76c5c08..000000000 --- a/bot/cogs/info/tags.py +++ /dev/null @@ -1,277 +0,0 @@ -import logging -import re -import time -from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot import constants -from bot.bot import Bot -from bot.converters import TagNameConverter -from bot.pagination import LinePaginator -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -TEST_CHANNELS = ( - constants.Channels.bot_commands, - constants.Channels.helpers -) - -REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) -FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." - - -class Tags(Cog): - """Save new tags and fetch existing tags.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.tag_cooldowns = {} - self._cache = self.get_tags() - - @staticmethod - def get_tags() -> dict: - """Get all tags.""" - cache = {} - - base_path = Path("bot", "resources", "tags") - for file in base_path.glob("**/*"): - if file.is_file(): - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text(encoding="utf8"), - }, - "restricted_to": "developers", - "location": f"/bot/{file}" - } - - # Convert to a list to allow negative indexing. - parents = list(file.relative_to(base_path).parents) - if len(parents) > 1: - # -1 would be '.' hence -2 is used as the index. - tag["restricted_to"] = parents[-2].name - - cache[tag_title] = tag - - return cache - - @staticmethod - def check_accessibility(user: Member, tag: dict) -> bool: - """Check if user can access a tag.""" - return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] - - @staticmethod - def _fuzzy_search(search: str, target: str) -> float: - """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" - current, index = 0, 0 - _search = REGEX_NON_ALPHABET.sub('', search.lower()) - _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) - _target = next(_targets) - try: - while True: - while index < len(_target) and _search[current] == _target[index]: - current += 1 - index += 1 - index, _target = 0, next(_targets) - except (StopIteration, IndexError): - pass - return current / len(_search) * 100 - - def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: - """Return a list of suggested tags.""" - scores: Dict[str, int] = { - tag_title: Tags._fuzzy_search(tag_name, tag['title']) - for tag_title, tag in self._cache.items() - } - - thresholds = thresholds or [100, 90, 80, 70, 60] - - for threshold in thresholds: - suggestions = [ - self._cache[tag_title] - for tag_title, matching_score in scores.items() - if matching_score >= threshold - ] - if suggestions: - return suggestions - - return [] - - def _get_tag(self, tag_name: str) -> list: - """Get a specific tag.""" - found = [self._cache.get(tag_name.lower(), None)] - if not found[0]: - return self._get_suggestions(tag_name) - return found - - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: - """ - Search for tags via contents. - - `predicate` will be the built-in any, all, or a custom callable. Must return a bool. - """ - keywords_processed: List[str] = [] - for keyword in keywords.split(','): - keyword_sanitized = keyword.strip().casefold() - if not keyword_sanitized: - # this happens when there are leading / trailing / consecutive comma. - continue - keywords_processed.append(keyword_sanitized) - - if not keywords_processed: - # after sanitizing, we can end up with an empty list, for example when keywords is ',' - # in that case, we simply want to search for such keywords directly instead. - keywords_processed = [keywords] - - matching_tags = [] - for tag in self._cache.values(): - matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) - if self.check_accessibility(user, tag) and check(matches): - matching_tags.append(tag) - - return matching_tags - - async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: - """Send the result of matching tags to user.""" - if not matching_tags: - pass - elif len(matching_tags) == 1: - await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) - else: - is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 - embed = Embed( - title=f"Here are the tags containing the given keyword{'s' * is_plural}:", - description='\n'.join(tag['title'] for tag in matching_tags[:10]) - ) - await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in matching_tags), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) - async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Show all known tags, a single tag, or run a subcommand.""" - await ctx.invoke(self.get_command, tag_name=tag_name) - - @tags_group.group(name='search', invoke_without_command=True) - async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Only search for tags that has ALL the keywords. - """ - matching_tags = self._get_tags_via_content(all, keywords, ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @search_tag_content.command(name='any') - async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Search for tags that has ANY of the keywords. - """ - matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @tags_group.command(name='get', aliases=('show', 'g')) - async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: - """Get a specified tag, or a list of all tags if no tag is specified.""" - - def _command_on_cooldown(tag_name: str) -> bool: - """ - Check if the command is currently on cooldown, on a per-tag, per-channel basis. - - The cooldown duration is set in constants.py. - """ - now = time.time() - - cooldown_conditions = ( - tag_name - and tag_name in self.tag_cooldowns - and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags - and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id - ) - - if cooldown_conditions: - return True - return False - - if _command_on_cooldown(tag_name): - time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] - time_left = constants.Cooldowns.tags - time_elapsed - log.info( - f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds." - ) - return - - if tag_name is not None: - temp_founds = self._get_tag(tag_name) - - founds = [] - - for found_tag in temp_founds: - if self.check_accessibility(ctx.author, found_tag): - founds.append(found_tag) - - if len(founds) == 1: - tag = founds[0] - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } - - self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") - - await wait_for_deletion( - await ctx.send(embed=Embed.from_dict(tag['embed'])), - [ctx.author.id], - client=self.bot - ) - elif founds and len(tag_name) >= 3: - await wait_for_deletion( - await ctx.send( - embed=Embed( - title='Did you mean ...', - description='\n'.join(tag['title'] for tag in founds[:10]) - ) - ), - [ctx.author.id], - client=self.bot - ) - - else: - tags = self._cache.values() - if not tags: - await ctx.send(embed=Embed( - description="**There are no tags in the database!**", - colour=Colour.red() - )) - else: - embed: Embed = Embed(title="**Current tags**") - await LinePaginator.paginate( - sorted( - f"**»** {tag['title']}" for tag in tags - if self.check_accessibility(ctx.author, tag) - ), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Tags cog.""" - bot.add_cog(Tags(bot)) diff --git a/bot/cogs/info/wolfram.py b/bot/cogs/info/wolfram.py deleted file mode 100644 index e6cae3bb8..000000000 --- a/bot/cogs/info/wolfram.py +++ /dev/null @@ -1,280 +0,0 @@ -import logging -from io import BytesIO -from typing import Callable, List, Optional, Tuple -from urllib import parse - -import discord -from dateutil.relativedelta import relativedelta -from discord import Embed -from discord.ext import commands -from discord.ext.commands import BucketType, Cog, Context, check, group - -from bot.bot import Bot -from bot.constants import Colours, STAFF_ROLES, Wolfram -from bot.pagination import ImagePaginator -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -APPID = Wolfram.key -DEFAULT_OUTPUT_FORMAT = "JSON" -QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" -WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" - -MAX_PODS = 20 - -# Allows for 10 wolfram calls pr user pr day -usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) - -# Allows for max api requests / days in month per day for the entire guild (Temporary) -guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) - - -async def send_embed( - ctx: Context, - message_txt: str, - colour: int = Colours.soft_red, - footer: str = None, - img_url: str = None, - f: discord.File = None -) -> None: - """Generate & send a response embed with Wolfram as the author.""" - embed = Embed(colour=colour) - embed.description = message_txt - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - if footer: - embed.set_footer(text=footer) - - if img_url: - embed.set_image(url=img_url) - - await ctx.send(embed=embed, file=f) - - -def custom_cooldown(*ignore: List[int]) -> Callable: - """ - Implement per-user and per-guild cooldowns for requests to the Wolfram API. - - A list of roles may be provided to ignore the per-user cooldown - """ - async def predicate(ctx: Context) -> bool: - if ctx.invoked_with == 'help': - # if the invoked command is help we don't want to increase the ratelimits since it's not actually - # invoking the command/making a request, so instead just check if the user/guild are on cooldown. - guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown - if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored - return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 - return guild_cooldown - - user_bucket = usercd.get_bucket(ctx.message) - - if all(role.id not in ignore for role in ctx.author.roles): - user_rate = user_bucket.update_rate_limit() - - if user_rate: - # Can't use api; cause: member limit - delta = relativedelta(seconds=int(user_rate)) - cooldown = humanize_delta(delta) - message = ( - "You've used up your limit for Wolfram|Alpha requests.\n" - f"Cooldown: {cooldown}" - ) - await send_embed(ctx, message) - return False - - guild_bucket = guildcd.get_bucket(ctx.message) - guild_rate = guild_bucket.update_rate_limit() - - # Repr has a token attribute to read requests left - log.debug(guild_bucket) - - if guild_rate: - # Can't use api; cause: guild limit - message = ( - "The max limit of requests for the server has been reached for today.\n" - f"Cooldown: {int(guild_rate)}" - ) - await send_embed(ctx, message) - return False - - return True - return check(predicate) - - -async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: - """Get the Wolfram API pod pages for the provided query.""" - async with ctx.channel.typing(): - url_str = parse.urlencode({ - "input": query, - "appid": APPID, - "output": DEFAULT_OUTPUT_FORMAT, - "format": "image,plaintext" - }) - request_url = QUERY.format(request="query", data=url_str) - - async with bot.http_session.get(request_url) as response: - json = await response.json(content_type='text/plain') - - result = json["queryresult"] - - if result["error"]: - # API key not set up correctly - if result["error"]["msg"] == "Invalid appid": - message = "Wolfram API key is invalid or missing." - log.warning( - "API key seems to be missing, or invalid when " - f"processing a wolfram request: {url_str}, Response: {json}" - ) - await send_embed(ctx, message) - return - - message = "Something went wrong internally with your request, please notify staff!" - log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") - await send_embed(ctx, message) - return - - if not result["success"]: - message = f"I couldn't find anything for {query}." - await send_embed(ctx, message) - return - - if not result["numpods"]: - message = "Could not find any results." - await send_embed(ctx, message) - return - - pods = result["pods"] - pages = [] - for pod in pods[:MAX_PODS]: - subs = pod.get("subpods") - - for sub in subs: - title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") - img = sub["img"]["src"] - pages.append((title, img)) - return pages - - -class Wolfram(Cog): - """Commands for interacting with the Wolfram|Alpha API.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_command(self, ctx: Context, *, query: str) -> None: - """Requests all answers on a single image, sends an image of all related pods.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="simple", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - image_bytes = await response.read() - - f = discord.File(BytesIO(image_bytes), filename="image.png") - image_url = "attachment://image.png" - - if status == 501: - message = "Failed to get response" - footer = "" - color = Colours.soft_red - elif status == 400: - message = "No input found" - footer = "" - color = Colours.soft_red - elif status == 403: - message = "Wolfram API key is invalid or missing." - footer = "" - color = Colours.soft_red - else: - message = "" - footer = "View original for a bigger picture." - color = Colours.soft_orange - - # Sends a "blank" embed if no request is received, unsure how to fix - await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) - - @wolfram_command.command(name="page", aliases=("pa", "p")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - embed = Embed() - embed.set_author(name="Wolfram Alpha", - icon_url=WOLF_IMAGE, - url="https://www.wolframalpha.com/") - embed.colour = Colours.soft_orange - - await ImagePaginator.paginate(pages, ctx, embed) - - @wolfram_command.command(name="cut", aliases=("c",)) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: - """ - Requests a drawn image of given query. - - Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. - """ - pages = await get_pod_pages(ctx, self.bot, query) - - if not pages: - return - - if len(pages) >= 2: - page = pages[1] - else: - page = pages[0] - - await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) - - @wolfram_command.command(name="short", aliases=("sh", "s")) - @custom_cooldown(*STAFF_ROLES) - async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: - """Requests an answer to a simple question.""" - url_str = parse.urlencode({ - "i": query, - "appid": APPID, - }) - query = QUERY.format(request="result", data=url_str) - - # Give feedback that the bot is working. - async with ctx.channel.typing(): - async with self.bot.http_session.get(query) as response: - status = response.status - response_text = await response.text() - - if status == 501: - message = "Failed to get response" - color = Colours.soft_red - elif status == 400: - message = "No input found" - color = Colours.soft_red - elif response_text == "Error 1: Invalid appid": - message = "Wolfram API key is invalid or missing." - color = Colours.soft_red - else: - message = response_text - color = Colours.soft_orange - - await send_embed(ctx, message, color) - - -def setup(bot: Bot) -> None: - """Load the Wolfram cog.""" - bot.add_cog(Wolfram(bot)) diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/moderation/defcon.py b/bot/cogs/moderation/defcon.py deleted file mode 100644 index e78435a7d..000000000 --- a/bot/cogs/moderation/defcon.py +++ /dev/null @@ -1,258 +0,0 @@ -from __future__ import annotations - -import logging -from collections import namedtuple -from datetime import datetime, timedelta -from enum import Enum - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -REJECTION_MESSAGE = """ -Hi, {user} - Thanks for your interest in our server! - -Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since -your account is relatively new, we're unable to provide access to the server at this time. - -Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation -will be resolved soon. In the meantime, please feel free to peruse the resources on our site at -, and have a nice day! -""" - -BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" - - -class Action(Enum): - """Defcon Action.""" - - ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) - - ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") - DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") - UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") - - -class Defcon(Cog): - """Time-sensitive server defense mechanisms.""" - - days = None # type: timedelta - enabled = False # type: bool - - def __init__(self, bot: Bot): - self.bot = bot - self.channel = None - self.days = timedelta(days=0) - - self.bot.loop.create_task(self.sync_settings()) - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def sync_settings(self) -> None: - """On cog load, try to synchronize DEFCON settings to the API.""" - await self.bot.wait_until_guild_available() - self.channel = await self.bot.fetch_channel(Channels.defcon) - - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - except Exception: # Yikes! - log.exception("Unable to get DEFCON settings!") - await self.bot.get_channel(Channels.dev_log).send( - f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" - ) - - else: - if data["enabled"]: - self.enabled = True - self.days = timedelta(days=data["days"]) - log.info(f"DEFCON enabled: {self.days.days} days") - - else: - self.enabled = False - self.days = timedelta(days=0) - log.info("DEFCON disabled") - - await self.update_channel_topic() - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" - if self.enabled and self.days.days > 0: - now = datetime.utcnow() - - if now - member.created_at < self.days: - log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") - - message_sent = False - - try: - await member.send(REJECTION_MESSAGE.format(user=member.mention)) - - message_sent = True - except Exception: - log.exception(f"Unable to send rejection message to user: {member}") - - await member.kick(reason="DEFCON active, user is too new") - self.bot.stats.incr("defcon.leaves") - - message = ( - f"{member} (`{member.id}`) was denied entry because their account is too new." - ) - - if not message_sent: - message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." - - await self.mod_log.send_log_message( - Icons.defcon_denied, Colours.soft_red, "Entry denied", - message, member.avatar_url_as(static_format="png") - ) - - @group(name='defcon', aliases=('dc',), invoke_without_command=True) - @with_role(Roles.admins, Roles.owners) - async def defcon_group(self, ctx: Context) -> None: - """Check the DEFCON status or run a subcommand.""" - await ctx.send_help(ctx.command) - - async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: - """Providing a structured way to do an defcon action.""" - try: - response = await self.bot.api_client.get('bot/bot-settings/defcon') - data = response['data'] - - if "enable_date" in data and action is Action.DISABLED: - enabled = datetime.fromisoformat(data["enable_date"]) - - delta = datetime.now() - enabled - - self.bot.stats.timing("defcon.enabled", delta) - except Exception: - pass - - error = None - try: - await self.bot.api_client.put( - 'bot/bot-settings/defcon', - json={ - 'name': 'defcon', - 'data': { - # TODO: retrieve old days count - 'days': days, - 'enabled': action is not Action.DISABLED, - 'enable_date': datetime.now().isoformat() - } - } - ) - except Exception as err: - log.exception("Unable to update DEFCON settings.") - error = err - finally: - await ctx.send(self.build_defcon_msg(action, error)) - await self.send_defcon_log(action, ctx.author, error) - - self.bot.stats.gauge("defcon.threshold", days) - - @defcon_group.command(name='enable', aliases=('on', 'e')) - @with_role(Roles.admins, Roles.owners) - async def enable_command(self, ctx: Context) -> None: - """ - Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! - - Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, - in days. - """ - self.enabled = True - await self._defcon_action(ctx, days=0, action=Action.ENABLED) - await self.update_channel_topic() - - @defcon_group.command(name='disable', aliases=('off', 'd')) - @with_role(Roles.admins, Roles.owners) - async def disable_command(self, ctx: Context) -> None: - """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" - self.enabled = False - await self._defcon_action(ctx, days=0, action=Action.DISABLED) - await self.update_channel_topic() - - @defcon_group.command(name='status', aliases=('s',)) - @with_role(Roles.admins, Roles.owners) - async def status_command(self, ctx: Context) -> None: - """Check the current status of DEFCON mode.""" - embed = Embed( - colour=Colour.blurple(), title="DEFCON Status", - description=f"**Enabled:** {self.enabled}\n" - f"**Days:** {self.days.days}" - ) - - await ctx.send(embed=embed) - - @defcon_group.command(name='days') - @with_role(Roles.admins, Roles.owners) - async def days_command(self, ctx: Context, days: int) -> None: - """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" - self.days = timedelta(days=days) - self.enabled = True - await self._defcon_action(ctx, days=days, action=Action.UPDATED) - await self.update_channel_topic() - - async def update_channel_topic(self) -> None: - """Update the #defcon channel topic with the current DEFCON status.""" - if self.enabled: - day_str = "days" if self.days.days > 1 else "day" - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" - else: - new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" - - self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) - await self.channel.edit(topic=new_topic) - - def build_defcon_msg(self, action: Action, e: Exception = None) -> str: - """Build in-channel response string for DEFCON action.""" - if action is Action.ENABLED: - msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" - elif action is Action.DISABLED: - msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" - elif action is Action.UPDATED: - msg = ( - f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " - f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" - ) - - if e: - msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - return msg - - async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: - """Send log message for DEFCON action.""" - info = action.value - log_msg: str = ( - f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" - f"{info.template.format(days=self.days.days)}" - ) - status_msg = f"DEFCON {action.name.lower()}" - - if e: - log_msg += ( - "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" - f"```py\n{e}\n```" - ) - - await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) - - -def setup(bot: Bot) -> None: - """Load the Defcon cog.""" - bot.add_cog(Defcon(bot)) diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py deleted file mode 100644 index e49913552..000000000 --- a/bot/cogs/moderation/incidents.py +++ /dev/null @@ -1,412 +0,0 @@ -import asyncio -import logging -import typing as t -from datetime import datetime -from enum import Enum - -import discord -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, Colours, Emojis, Guild, Webhooks -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -# Amount of messages for `crawl_task` to process at most on start-up - limited to 50 -# as in practice, there should never be this many messages, and if there are, -# something has likely gone very wrong -CRAWL_LIMIT = 50 - -# Seconds for `crawl_task` to sleep after adding reactions to a message -CRAWL_SLEEP = 2 - - -class Signal(Enum): - """ - Recognized incident status signals. - - This binds emoji to actions. The bot will only react to emoji linked here. - All other signals are seen as invalid. - """ - - ACTIONED = Emojis.incident_actioned - NOT_ACTIONED = Emojis.incident_unactioned - INVESTIGATING = Emojis.incident_investigating - - -# Reactions from non-mod roles will be removed -ALLOWED_ROLES: t.Set[int] = set(Guild.moderation_roles) - -# Message must have all of these emoji to pass the `has_signals` check -ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} - -# An embed coupled with an optional file to be dispatched -# If the file is not None, the embed attempts to show it in its body -FileEmbed = t.Tuple[discord.Embed, t.Optional[discord.File]] - - -async def download_file(attachment: discord.Attachment) -> t.Optional[discord.File]: - """ - Download & return `attachment` file. - - If the download fails, the reason is logged and None will be returned. - 404 and 403 errors are only logged at debug level. - """ - log.debug(f"Attempting to download attachment: {attachment.filename}") - try: - return await attachment.to_file() - except (discord.NotFound, discord.Forbidden) as exc: - log.debug(f"Failed to download attachment: {exc}") - except Exception: - log.exception("Failed to download attachment") - - -async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: - """ - Create an embed representation of `incident` for the #incidents-archive channel. - - The name & discriminator of `actioned_by` and `outcome` will be presented in the - embed footer. Additionally, the embed is coloured based on `outcome`. - - The author of `incident` is not shown in the embed. It is assumed that this piece - of information will be relayed in other ways, e.g. webhook username. - - As mentions in embeds do not ping, we do not need to use `incident.clean_content`. - - If `incident` contains attachments, the first attachment will be downloaded and - returned alongside the embed. The embed attempts to display the attachment. - Should the download fail, we fallback on linking the `proxy_url`, which should - remain functional for some time after the original message is deleted. - """ - log.trace(f"Creating embed for {incident.id=}") - - if outcome is Signal.ACTIONED: - colour = Colours.soft_green - footer = f"Actioned by {actioned_by}" - else: - colour = Colours.soft_red - footer = f"Rejected by {actioned_by}" - - embed = discord.Embed( - description=incident.content, - timestamp=datetime.utcnow(), - colour=colour, - ) - embed.set_footer(text=footer, icon_url=actioned_by.avatar_url) - - if incident.attachments: - attachment = incident.attachments[0] # User-sent messages can only contain one attachment - file = await download_file(attachment) - - if file is not None: - embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file - else: - embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file - else: - file = None - - return embed, file - - -def is_incident(message: discord.Message) -> bool: - """True if `message` qualifies as an incident, False otherwise.""" - conditions = ( - message.channel.id == Channels.incidents, # Message sent in #incidents - not message.author.bot, # Not by a bot - not message.content.startswith("#"), # Doesn't start with a hash - not message.pinned, # And isn't header - ) - return all(conditions) - - -def own_reactions(message: discord.Message) -> t.Set[str]: - """Get the set of reactions placed on `message` by the bot itself.""" - return {str(reaction.emoji) for reaction in message.reactions if reaction.me} - - -def has_signals(message: discord.Message) -> bool: - """True if `message` already has all `Signal` reactions, False otherwise.""" - return ALL_SIGNALS.issubset(own_reactions(message)) - - -async def add_signals(incident: discord.Message) -> None: - """ - Add `Signal` member emoji to `incident` as reactions. - - If the emoji has already been placed on `incident` by the bot, it will be skipped. - """ - existing_reacts = own_reactions(incident) - - for signal_emoji in Signal: - if signal_emoji.value in existing_reacts: # This would not raise, but it is a superfluous API call - log.trace(f"Skipping emoji as it's already been placed: {signal_emoji}") - else: - log.trace(f"Adding reaction: {signal_emoji}") - await incident.add_reaction(signal_emoji.value) - - -class Incidents(Cog): - """ - Automation for the #incidents channel. - - This cog does not provide a command API, it only reacts to the following events. - - On start-up: - * Crawl #incidents and add missing `Signal` emoji where appropriate - * This is to retro-actively add the available options for messages which - were sent while the bot wasn't listening - * Pinned messages and message starting with # do not qualify as incidents - * See: `crawl_incidents` - - On message: - * Add `Signal` member emoji if message qualifies as an incident - * Ignore messages starting with # - * Use this if verbal communication is necessary - * Each such message must be deleted manually once appropriate - * See: `on_message` - - On reaction: - * Remove reaction if not permitted - * User does not have any of the roles in `ALLOWED_ROLES` - * Used emoji is not a `Signal` member - * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to - relay the incident message to #incidents-archive - * If relay successful, delete original message - * See: `on_raw_reaction_add` - - Please refer to function docstrings for implementation details. - """ - - def __init__(self, bot: Bot) -> None: - """Prepare `event_lock` and schedule `crawl_task` on start-up.""" - self.bot = bot - - self.event_lock = asyncio.Lock() - self.crawl_task = self.bot.loop.create_task(self.crawl_incidents()) - - async def crawl_incidents(self) -> None: - """ - Crawl #incidents and add missing emoji where necessary. - - This is to catch-up should an incident be reported while the bot wasn't listening. - After adding each reaction, we take a short break to avoid drowning in ratelimits. - - Once this task is scheduled, listeners that change messages should await it. - The crawl assumes that the channel history doesn't change as we go over it. - - Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. - """ - await self.bot.wait_until_guild_available() - incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) - - log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") - async for message in incidents.history(limit=CRAWL_LIMIT): - - if not is_incident(message): - log.trace(f"Skipping message {message.id}: not an incident") - continue - - if has_signals(message): - log.trace(f"Skipping message {message.id}: already has all signals") - continue - - await add_signals(message) - await asyncio.sleep(CRAWL_SLEEP) - - log.debug("Crawl task finished!") - - async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: - """ - Relay an embed representation of `incident` to the #incidents-archive channel. - - The following pieces of information are relayed: - * Incident message content (as embed description) - * Incident attachment (if image, shown in archive embed) - * Incident author name (as webhook author) - * Incident author avatar (as webhook avatar) - * Resolution signal `outcome` (as embed colour & footer) - * Moderator `actioned_by` (name & discriminator shown in footer) - - If `incident` contains an attachment, we try to add it to the archive embed. There is - no handing of extensions / file types - we simply dispatch the attachment file with the - webhook, and try to display it in the embed. Testing indicates that if the attachment - cannot be displayed (e.g. a text file), it's invisible in the embed, with no error. - - Return True if the relay finishes successfully. If anything goes wrong, meaning - not all information was relayed, return False. This signals that the original - message is not safe to be deleted, as we will lose some information. - """ - log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") - embed, attachment_file = await make_embed(incident, outcome, actioned_by) - - try: - webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) - await webhook.send( - embed=embed, - username=sub_clyde(incident.author.name), - avatar_url=incident.author.avatar_url, - file=attachment_file, - ) - except Exception: - log.exception(f"Failed to archive incident {incident.id} to #incidents-archive") - return False - else: - log.trace("Message archived successfully!") - return True - - def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: - """ - Create a task to wait `timeout` seconds for `incident` to be deleted. - - If `timeout` passes, this will raise `asyncio.TimeoutError`, signaling that we haven't - been able to confirm that the message was deleted. - """ - log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") - - def check(payload: discord.RawReactionActionEvent) -> bool: - return payload.message_id == incident.id - - coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) - return self.bot.loop.create_task(coroutine) - - async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: - """ - Process a `reaction_add` event in #incidents. - - First, we check that the reaction is a recognized `Signal` member, and that it was sent by - a permitted user (at least one role in `ALLOWED_ROLES`). If not, the reaction is removed. - - If the reaction was either `Signal.ACTIONED` or `Signal.NOT_ACTIONED`, we attempt to relay - the report to #incidents-archive. If successful, the original message is deleted. - - We do not release `event_lock` until we receive the corresponding `message_delete` event. - This ensures that if there is a racing event awaiting the lock, it will fail to find the - message, and will abort. There is a `timeout` to ensure that this doesn't hold the lock - forever should something go wrong. - """ - members_roles: t.Set[int] = {role.id for role in member.roles} - if not members_roles & ALLOWED_ROLES: # Intersection is truthy on at least 1 common element - log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") - await incident.remove_reaction(reaction, member) - return - - try: - signal = Signal(reaction) - except ValueError: - log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") - await incident.remove_reaction(reaction, member) - return - - log.trace(f"Received signal: {signal}") - - if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED): - log.debug("Reaction was valid, but no action is currently defined for it") - return - - relay_successful = await self.archive(incident, signal, actioned_by=member) - if not relay_successful: - log.trace("Original message will not be deleted as we failed to relay it to the archive") - return - - timeout = 5 # Seconds - confirmation_task = self.make_confirmation_task(incident, timeout) - - log.trace("Deleting original message") - await incident.delete() - - log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") - try: - await confirmation_task - except asyncio.TimeoutError: - log.warning(f"Did not receive incident deletion confirmation within {timeout} seconds!") - else: - log.trace("Deletion was confirmed") - - async def resolve_message(self, message_id: int) -> t.Optional[discord.Message]: - """ - Get `discord.Message` for `message_id` from cache, or API. - - We first look into the local cache to see if the message is present. - - If not, we try to fetch the message from the API. This is necessary for messages - which were sent before the bot's current session. - - In an edge-case, it is also possible that the message was already deleted, and - the API will respond with a 404. In such a case, None will be returned. - This signals that the event for `message_id` should be ignored. - """ - await self.bot.wait_until_guild_available() # First make sure that the cache is ready - log.trace(f"Resolving message for: {message_id=}") - message: t.Optional[discord.Message] = self.bot._connection._get_message(message_id) - - if message is not None: - log.trace("Message was found in cache") - return message - - log.trace("Message not found, attempting to fetch") - try: - message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) - except discord.NotFound: - log.trace("Message doesn't exist, it was likely already relayed") - except Exception: - log.exception(f"Failed to fetch message {message_id}!") - else: - log.trace("Message fetched successfully!") - return message - - @Cog.listener() - async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: - """ - Pre-process `payload` and pass it to `process_event` if appropriate. - - We abort instantly if `payload` doesn't relate to a message sent in #incidents, - or if it was sent by a bot. - - If `payload` relates to a message in #incidents, we first ensure that `crawl_task` has - finished, to make sure we don't mutate channel state as we're crawling it. - - Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. - - Once we have the lock, the `discord.Message` object for this event must be resolved. - If the lock was previously held by an event which successfully relayed the incident, - this will fail and we abort the current event. - - Finally, with both the lock and the `discord.Message` instance in our hands, we delegate - to `process_event` to handle the event. - - The justification for using a raw listener is the need to receive events for messages - which were not cached in the current session. As a result, a certain amount of - complexity is introduced, but at the moment this doesn't appear to be avoidable. - """ - if payload.channel_id != Channels.incidents or payload.member.bot: - return - - log.trace(f"Received reaction add event in #incidents, waiting for crawler: {self.crawl_task.done()=}") - await self.crawl_task - - log.trace(f"Acquiring event lock: {self.event_lock.locked()=}") - async with self.event_lock: - message = await self.resolve_message(payload.message_id) - - if message is None: - log.debug("Listener will abort as related message does not exist!") - return - - if not is_incident(message): - log.debug("Ignoring event for a non-incident message") - return - - await self.process_event(str(payload.emoji), message, payload.member) - log.trace("Releasing event lock") - - @Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Pass `message` to `add_signals` if and only if it satisfies `is_incident`.""" - if is_incident(message): - await add_signals(message) - - -def setup(bot: Bot) -> None: - """Load the Incidents cog.""" - bot.add_cog(Incidents(bot)) diff --git a/bot/cogs/moderation/infraction/__init__.py b/bot/cogs/moderation/infraction/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/moderation/infraction/_scheduler.py b/bot/cogs/moderation/infraction/_scheduler.py deleted file mode 100644 index 33944a8db..000000000 --- a/bot/cogs/moderation/infraction/_scheduler.py +++ /dev/null @@ -1,463 +0,0 @@ -import logging -import textwrap -import typing as t -from abc import abstractmethod -from datetime import datetime -from gettext import ngettext - -import dateutil.parser -import discord -from discord.ext.commands import Context - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import Colours, STAFF_CHANNELS -from bot.utils import time -from bot.utils.scheduling import Scheduler -from . import _utils -from ._utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class InfractionScheduler: - """Handles the application, pardoning, and expiration of infractions.""" - - def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - @property - def mod_log(self) -> ModLog: - """Get the currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: - """Schedule expiration for previous infractions.""" - await self.bot.wait_until_guild_available() - - log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") - - infractions = await self.bot.api_client.get( - 'bot/infractions', - params={'active': 'true'} - ) - for infraction in infractions: - if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_expiration(infraction) - - async def reapply_infraction( - self, - infraction: _utils.Infraction, - apply_coro: t.Optional[t.Awaitable] - ) -> None: - """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" - # Calculate the time remaining, in seconds, for the mute. - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - delta = (expiry - datetime.utcnow()).total_seconds() - - # Mark as inactive if less than a minute remains. - if delta < 60: - log.info( - "Infraction will be deactivated instead of re-applied " - "because less than 1 minute remains." - ) - await self.deactivate_infraction(infraction) - return - - # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") - - async def apply_infraction( - self, - ctx: Context, - infraction: _utils.Infraction, - user: UserSnowflake, - action_coro: t.Optional[t.Awaitable] = None - ) -> None: - """Apply an infraction to the user, log the infraction, and optionally notify the user.""" - infr_type = infraction["type"] - icon = _utils.INFRACTION_ICONS[infr_type][0] - reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) - id_ = infraction['id'] - - log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - - # Default values for the confirmation message and mod log. - confirm_msg = ":ok_hand: applied" - - # Specifying an expiry for a note or warning makes no sense. - if infr_type in ("note", "warning"): - expiry_msg = "" - else: - expiry_msg = f" until {expiry}" if expiry else " permanently" - - dm_result = "" - dm_log_text = "" - expiry_log_text = f"\nExpires: {expiry}" if expiry else "" - log_title = "applied" - log_content = None - failed = False - - # DM the user about the infraction if it's not a shadow/hidden infraction. - # This needs to happen before we apply the infraction, as the bot cannot - # send DMs to user that it doesn't share a guild with. If we were to - # apply kick/ban infractions first, this would mean that we'd make it - # impossible for us to deliver a DM. See python-discord/bot#982. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" - - end_msg = "" - if infraction["actor"] == self.bot.user.id: - log.trace( - f"Infraction #{id_} actor is bot; including the reason in the confirmation message." - ) - if reason: - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" - elif ctx.channel.id not in STAFF_CHANNELS: - log.trace( - f"Infraction #{id_} context is not in a staff channel; omitting infraction count." - ) - else: - log.trace(f"Fetching total infraction count for {user}.") - - infractions = await self.bot.api_client.get( - "bot/infractions", - params={"user__id": str(user.id)} - ) - total = len(infractions) - end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" - - # Execute the necessary actions to apply the infraction on Discord. - if action_coro: - log.trace(f"Awaiting the infraction #{id_} application action coroutine.") - try: - await action_coro - if expiry: - # Schedule the expiration of the infraction. - self.schedule_expiration(infraction) - except discord.HTTPException as e: - # Accordingly display that applying the infraction failed. - confirm_msg = ":x: failed to apply" - expiry_msg = "" - log_content = ctx.author.mention - log_title = "failed to apply" - - log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): - log.warning(f"{log_msg}: bot lacks permissions.") - else: - log.exception(log_msg) - failed = True - - if failed: - log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") - try: - await self.bot.api_client.delete(f"bot/infractions/{id_}") - except ResponseCodeError as e: - confirm_msg += " and failed to delete" - log_title += " and failed to delete" - log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") - infr_message = "" - else: - infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" - - # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") - - # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=icon, - colour=Colours.soft_red, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} - Reason: {reason} - """), - content=log_content, - footer=f"ID {infraction['id']}" - ) - - log.info(f"Applied {infr_type} infraction #{id_} to {user}.") - - async def pardon_infraction( - self, - ctx: Context, - infr_type: str, - user: UserSnowflake, - send_msg: bool = True - ) -> None: - """ - Prematurely end an infraction for a user and log the action in the mod log. - - If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. - """ - log.trace(f"Pardoning {infr_type} infraction for {user}.") - - # Check the current active infraction - log.trace(f"Fetching active {infr_type} infractions for {user}.") - response = await self.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': user.id - } - ) - - if not response: - log.debug(f"No active {infr_type} infraction found for {user}.") - await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") - return - - # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["Actor"] = str(ctx.message.author) - log_content = None - id_ = response[0]['id'] - footer = f"ID: {id_}" - - # If multiple active infractions were found, mark them as inactive in the database - # and cancel their expiration tasks. - if len(response) > 1: - log.info( - f"Found more than one active {infr_type} infraction for user {user.id}; " - "deactivating the extra active infractions too." - ) - - footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" - - log_note = f"Found multiple **active** {infr_type} infractions in the database." - if "Note" in log_text: - log_text["Note"] = f" {log_note}" - else: - log_text["Note"] = log_note - - # deactivate_infraction() is not called again because: - # 1. Discord cannot store multiple active bans or assign multiples of the same role - # 2. It would send a pardon DM for each active infraction, which is redundant - for infraction in response[1:]: - id_ = infraction['id'] - try: - # Mark infraction as inactive in the database. - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") - # This is simpler and cleaner than trying to concatenate all the errors. - log_text["Failure"] = "See bot's logs for details." - - # Cancel pending expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Accordingly display whether the user was successfully notified via DM. - dm_emoji = "" - if log_text.get("DM") == "Sent": - dm_emoji = ":incoming_envelope: " - elif "DM" in log_text: - dm_emoji = f"{constants.Emojis.failmail} " - - # Accordingly display whether the pardon failed. - if "Failure" in log_text: - confirm_msg = ":x: failed to pardon" - log_title = "pardon failed" - log_content = ctx.author.mention - - log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") - else: - confirm_msg = ":ok_hand: pardoned" - log_title = "pardoned" - - log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") - - # Send a confirmation message to the invoking context. - if send_msg: - log.trace(f"Sending infraction #{id_} pardon confirmation message.") - await ctx.send( - f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " - f"{log_text.get('Failure', '')}" - ) - - # Move reason to end of entry to avoid cutting out some keys - log_text["Reason"] = log_text.pop("Reason") - - # Send a log message to the mod log. - await self.mod_log.send_log_message( - icon_url=_utils.INFRACTION_ICONS[infr_type][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {infr_type}", - thumbnail=user.avatar_url_as(static_format="png"), - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=footer, - content=log_content, - ) - - async def deactivate_infraction( - self, - infraction: _utils.Infraction, - send_log: bool = True - ) -> t.Dict[str, str]: - """ - Deactivate an active infraction and return a dictionary of lines to send in a mod log. - - The infraction is removed from Discord, marked as inactive in the database, and has its - expiration task cancelled. If `send_log` is True, a mod log is sent for the - deactivation of the infraction. - - Infractions of unsupported types will raise a ValueError. - """ - guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderators) - user_id = infraction["user"] - actor = infraction["actor"] - type_ = infraction["type"] - id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] - - log.info(f"Marking infraction #{id_} as inactive (expired).") - - expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) - - log_content = None - log_text = { - "Member": f"<@{user_id}>", - "Actor": str(self.bot.get_user(actor) or actor), - "Reason": infraction["reason"], - "Created": created, - } - - try: - log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) - - if returned_log is not None: - log_text = {**log_text, **returned_log} # Merge the logs together - else: - raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" - ) - except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") - log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" - log_content = mod_role.mention - except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." - log_content = mod_role.mention - - # Check if the user is currently being watched by Big Brother. - try: - log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") - - active_watch = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "watch", - "user__id": user_id - } - ) - - log_text["Watching"] = "Yes" if active_watch else "No" - except ResponseCodeError: - log.exception(f"Failed to fetch watch status for user {user_id}") - log_text["Watching"] = "Unknown - failed to fetch watch status." - - try: - # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{id_} as inactive in the database.") - await self.bot.api_client.patch( - f"bot/infractions/{id_}", - json={"active": False} - ) - except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_line = f"API request failed with code {e.status}." - log_content = mod_role.mention - - # Append to an existing failure message if possible - if "Failure" in log_text: - log_text["Failure"] += f" {log_line}" - else: - log_text["Failure"] = log_line - - # Cancel the expiration task. - if infraction["expires_at"] is not None: - self.scheduler.cancel(infraction["id"]) - - # Send a log message to the mod log. - if send_log: - log_title = "expiration failed" if "Failure" in log_text else "expired" - - user = self.bot.get_user(user_id) - avatar = user.avatar_url_as(static_format="png") if user else None - - # Move reason to end so when reason is too long, this is not gonna cut out required items. - log_text["Reason"] = log_text.pop("Reason") - - log.trace(f"Sending deactivation mod log for infraction #{id_}.") - await self.mod_log.send_log_message( - icon_url=_utils.INFRACTION_ICONS[type_][1], - colour=Colours.soft_green, - title=f"Infraction {log_title}: {type_}", - thumbnail=avatar, - text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {id_}", - content=log_content, - ) - - return log_text - - @abstractmethod - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - raise NotImplementedError - - def schedule_expiration(self, infraction: _utils.Infraction) -> None: - """ - Marks an infraction expired after the delay from time of scheduling to time of expiration. - - At the time of expiration, the infraction is marked as inactive on the website and the - expiration task is cancelled. - """ - expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/infraction/_utils.py b/bot/cogs/moderation/infraction/_utils.py deleted file mode 100644 index fb55287b6..000000000 --- a/bot/cogs/moderation/infraction/_utils.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext.commands import Context - -from bot.api import ResponseCodeError -from bot.constants import Colours, Icons - -log = logging.getLogger(__name__) - -# apply icon, pardon icon -INFRACTION_ICONS = { - "ban": (Icons.user_ban, Icons.user_unban), - "kick": (Icons.sign_out, None), - "mute": (Icons.user_mute, Icons.user_unmute), - "note": (Icons.user_warn, None), - "superstar": (Icons.superstarify, Icons.unsuperstarify), - "warning": (Icons.user_warn, None), -} -RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") - -# Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object] -Infraction = t.Dict[str, t.Union[str, int, bool]] - - -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: - """ - Create a new user in the database. - - Used when an infraction needs to be applied on a user absent in the guild. - """ - log.trace(f"Attempting to add user {user.id} to the database.") - - if not isinstance(user, (discord.Member, discord.User)): - log.debug("The user being added to the DB is not a Member or User object.") - - payload = { - 'discriminator': int(getattr(user, 'discriminator', 0)), - 'id': user.id, - 'in_guild': False, - 'name': getattr(user, 'name', 'Name unknown'), - 'roles': [] - } - - try: - response = await ctx.bot.api_client.post('bot/users', json=payload) - log.info(f"User {user.id} added to the DB.") - return response - except ResponseCodeError as e: - log.error(f"Failed to add user {user.id} to the DB. {e}") - await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") - - -async def post_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - reason: str, - expires_at: datetime = None, - hidden: bool = False, - active: bool = True -) -> t.Optional[dict]: - """Posts an infraction to the API.""" - log.trace(f"Posting {infr_type} infraction for {user} to the API.") - - payload = { - "actor": ctx.message.author.id, - "hidden": hidden, - "reason": reason, - "type": infr_type, - "user": user.id, - "active": active - } - if expires_at: - payload['expires_at'] = expires_at.isoformat() - - # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. - for should_post_user in (True, False): - try: - response = await ctx.bot.api_client.post('bot/infractions', json=payload) - return response - except ResponseCodeError as e: - if e.status == 400 and 'user' in e.response_json: - # Only one attempt to add the user to the database, not two: - if not should_post_user or await post_user(ctx, user) is None: - return - else: - log.exception(f"Unexpected error while adding an infraction for {user}:") - await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") - return - - -async def get_active_infraction( - ctx: Context, - user: UserSnowflake, - infr_type: str, - send_msg: bool = True -) -> t.Optional[dict]: - """ - Retrieves an active infraction of the given type for the user. - - If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, - then a message for the moderator will be sent to the context channel letting them know. - Otherwise, no message will be sent. - """ - log.trace(f"Checking if {user} has active infractions of type {infr_type}.") - - active_infractions = await ctx.bot.api_client.get( - 'bot/infractions', - params={ - 'active': 'true', - 'type': infr_type, - 'user__id': str(user.id) - } - ) - if active_infractions: - # Checks to see if the moderator should be told there is an active infraction - if send_msg: - log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) - return active_infractions[0] - else: - log.trace(f"{user} does not have active infractions of type {infr_type}.") - - -async def notify_infraction( - user: UserObject, - infr_type: str, - expires_at: t.Optional[str] = None, - reason: t.Optional[str] = None, - icon_url: str = Icons.token_removed -) -> bool: - """DM a user about their new infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their {infr_type} infraction.") - - text = textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """) - - embed = discord.Embed( - description=textwrap.shorten(text, width=2048, placeholder="..."), - colour=Colours.soft_red - ) - - embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) - embed.title = f"Please review our rules over at {RULES_URL}" - embed.url = RULES_URL - - if infr_type in APPEALABLE_INFRACTIONS: - embed.set_footer( - text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" - ) - - return await send_private_embed(user, embed) - - -async def notify_pardon( - user: UserObject, - title: str, - content: str, - icon_url: str = Icons.user_verified -) -> bool: - """DM a user about their pardoned infraction and return True if the DM is successful.""" - log.trace(f"Sending {user} a DM about their pardoned infraction.") - - embed = discord.Embed( - description=content, - colour=Colours.soft_green - ) - - embed.set_author(name=title, icon_url=icon_url) - - return await send_private_embed(user, embed) - - -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: - """ - A helper method for sending an embed to a user's DMs. - - Returns a boolean indicator of DM success. - """ - try: - await user.send(embed=embed) - return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): - log.debug( - f"Infraction-related information could not be sent to user {user} ({user.id}). " - "The user either could not be retrieved or probably disabled their DMs." - ) - return False diff --git a/bot/cogs/moderation/infraction/infractions.py b/bot/cogs/moderation/infraction/infractions.py deleted file mode 100644 index cb459b447..000000000 --- a/bot/cogs/moderation/infraction/infractions.py +++ /dev/null @@ -1,375 +0,0 @@ -import logging -import textwrap -import typing as t - -import discord -from discord import Member -from discord.ext import commands -from discord.ext.commands import Context, command - -from bot import constants -from bot.bot import Bot -from bot.constants import Event -from bot.converters import Expiry, FetchedMember -from bot.decorators import respect_role_hierarchy -from bot.utils.checks import with_role_check -from . import _utils -from ._scheduler import InfractionScheduler -from ._utils import UserSnowflake - -log = logging.getLogger(__name__) - - -class Infractions(InfractionScheduler, commands.Cog): - """Apply and pardon infractions on users for moderation purposes.""" - - category = "Moderation" - category_description = "Server moderation tools." - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) - - self.category = "Moderation" - self._muted_role = discord.Object(constants.Roles.muted) - - @commands.Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active mute infractions for returning members.""" - active_mutes = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "mute", - "user__id": member.id - } - ) - - if active_mutes: - reason = f"Re-applying active mute: {active_mutes[0]['id']}" - action = member.add_roles(self._muted_role, reason=reason) - - await self.reapply_infraction(active_mutes[0], action) - - # region: Permanent infractions - - @command() - async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Warn a user for the given reason.""" - infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command() - async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason.""" - await self.apply_kick(ctx, user, reason) - - @command() - async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason and stop watching them with Big Brother.""" - await self.apply_ban(ctx, user, reason) - - # endregion - # region: Temporary infractions - - @command(aliases=["mute"]) - async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: - """ - Temporarily mute a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration) - - @command() - async def tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration) - - # endregion - # region: Permanent shadow infractions - - @command(hidden=True) - async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Create a private note for a user with the given reason without notifying the user.""" - infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) - if infraction is None: - return - - await self.apply_infraction(ctx, infraction, user) - - @command(hidden=True, aliases=['shadowkick', 'skick']) - async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: - """Kick a user for the given reason without notifying the user.""" - await self.apply_kick(ctx, user, reason, hidden=True) - - @command(hidden=True, aliases=['shadowban', 'sban']) - async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: - """Permanently ban a user for the given reason without notifying the user.""" - await self.apply_ban(ctx, user, reason, hidden=True) - - # endregion - # region: Temporary shadow infractions - - @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) - async def shadow_tempmute( - self, ctx: Context, - user: Member, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily mute a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) - - @command(hidden=True, aliases=["shadowtempban, stempban"]) - async def shadow_tempban( - self, - ctx: Context, - user: FetchedMember, - duration: Expiry, - *, - reason: t.Optional[str] = None - ) -> None: - """ - Temporarily ban a user for the given reason and duration without notifying the user. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - """ - await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) - - # endregion - # region: Remove infractions (un- commands) - - @command() - async def unmute(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active mute infraction for the user.""" - await self.pardon_infraction(ctx, "mute", user) - - @command() - async def unban(self, ctx: Context, user: FetchedMember) -> None: - """Prematurely end the active ban infraction for the user.""" - await self.pardon_infraction(ctx, "ban", user) - - # endregion - # region: Base apply functions - - async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await _utils.get_active_infraction(ctx, user, "mute"): - return - - infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_update, user.id) - - async def action() -> None: - await user.add_roles(self._muted_role, reason=reason) - - log.trace(f"Attempting to kick {user} from voice because they've been muted.") - await user.move_to(None, reason=reason) - - await self.apply_infraction(ctx, infraction, user, action()) - - @respect_role_hierarchy() - async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: - """Apply a kick infraction with kwargs passed to `post_infraction`.""" - infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = user.kick(reason=reason) - await self.apply_infraction(ctx, infraction, user, action) - - @respect_role_hierarchy() - async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: - """ - Apply a ban infraction with kwargs passed to `post_infraction`. - - Will also remove the banned user from the Big Brother watch list if applicable. - """ - # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active - is_temporary = kwargs.get("expires_at") is not None - active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary) - - if active_infraction: - if is_temporary: - log.trace("Tempban ignored as it cannot overwrite an active ban.") - return - - if active_infraction.get('expires_at') is None: - log.trace("Permaban already exists, notify.") - await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") - return - - log.trace("Old tempban is being replaced by new permaban.") - await self.pardon_infraction(ctx, "ban", user, is_temporary) - - infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) - if infraction is None: - return - - self.mod_log.ignore(Event.member_remove, user.id) - - if reason: - reason = textwrap.shorten(reason, width=512, placeholder="...") - - action = ctx.guild.ban(user, reason=reason, delete_message_days=0) - await self.apply_infraction(ctx, infraction, user, action) - - if infraction.get('expires_at') is not None: - log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") - return - - bb_cog = self.bot.get_cog("Big Brother") - if not bb_cog: - log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") - return - - log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") - - bb_reason = "User has been permanently banned from the server. Automatically removed." - await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) - - # endregion - # region: Base pardon functions - - async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's muted role, DM them a notification, and return a log dict.""" - user = guild.get_member(user_id) - log_text = {} - - if user: - # Remove the muted role. - self.mod_log.ignore(Event.member_update, user.id) - await user.remove_roles(self._muted_role, reason=reason) - - # DM the user about the expiration. - notified = await _utils.notify_pardon( - user=user, - title="You have been unmuted", - content="You may now send messages in the server.", - icon_url=_utils.INFRACTION_ICONS["mute"][1] - ) - - log_text["Member"] = f"{user.mention}(`{user.id}`)" - log_text["DM"] = "Sent" if notified else "**Failed**" - else: - log.info(f"Failed to unmute user {user_id}: user not found") - log_text["Failure"] = "User was not found in the guild." - - return log_text - - async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's ban on the Discord guild and return a log dict.""" - user = discord.Object(user_id) - log_text = {} - - self.mod_log.ignore(Event.member_unban, user_id) - - try: - await guild.unban(user, reason=reason) - except discord.NotFound: - log.info(f"Failed to unban user {user_id}: no active ban found on Discord") - log_text["Note"] = "No active ban found on Discord." - - return log_text - - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """ - Execute deactivation steps specific to the infraction's type and return a log dict. - - If an infraction type is unsupported, return None instead. - """ - guild = self.bot.get_guild(constants.Guild.id) - user_id = infraction["user"] - reason = f"Infraction #{infraction['id']} expired or was pardoned." - - if infraction["type"] == "mute": - return await self.pardon_mute(user_id, guild, reason) - elif infraction["type"] == "ban": - return await self.pardon_ban(user_id, guild, reason) - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters or discord.Member in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the Infractions cog.""" - bot.add_cog(Infractions(bot)) diff --git a/bot/cogs/moderation/infraction/management.py b/bot/cogs/moderation/infraction/management.py deleted file mode 100644 index 9e7ae8113..000000000 --- a/bot/cogs/moderation/infraction/management.py +++ /dev/null @@ -1,310 +0,0 @@ -import logging -import textwrap -import typing as t -from datetime import datetime - -import discord -from discord.ext import commands -from discord.ext.commands import Context - -from bot import constants -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user -from bot.pagination import LinePaginator -from bot.utils import time -from bot.utils.checks import in_whitelist_check, with_role_check -from . import _utils -from .infractions import Infractions - -log = logging.getLogger(__name__) - - -class ModManagement(commands.Cog): - """Management of infractions.""" - - category = "Moderation" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @property - def infractions_cog(self) -> Infractions: - """Get currently loaded Infractions cog instance.""" - return self.bot.get_cog("Infractions") - - # region: Edit infraction commands - - @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) - async def infraction_group(self, ctx: Context) -> None: - """Infraction manipulation commands.""" - await ctx.send_help(ctx.command) - - @infraction_group.command(name='edit') - async def infraction_edit( - self, - ctx: Context, - infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 - duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 - *, - reason: str = None - ) -> None: - """ - Edit the duration and/or the reason of an infraction. - - Durations are relative to the time of updating and should be appended with a unit of time. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction - authored by the command invoker should be edited. - - Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 - timestamp can be provided for the duration. - """ - if duration is None and reason is None: - # Unlike UserInputError, the error handler will show a specified message for BadArgument - raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") - - # Retrieve the previous infraction for its information. - if isinstance(infraction_id, str): - params = { - "actor__id": ctx.author.id, - "ordering": "-inserted_at" - } - infractions = await self.bot.api_client.get("bot/infractions", params=params) - - if infractions: - old_infraction = infractions[0] - infraction_id = old_infraction["id"] - else: - await ctx.send( - ":x: Couldn't find most recent infraction; you have never given an infraction." - ) - return - else: - old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") - - request_data = {} - confirm_messages = [] - log_text = "" - - if duration is not None and not old_infraction['active']: - if reason is None: - await ctx.send(":x: Cannot edit the expiration of an expired infraction.") - return - confirm_messages.append("expiry unchanged (infraction already expired)") - elif isinstance(duration, str): - request_data['expires_at'] = None - confirm_messages.append("marked as permanent") - elif duration is not None: - request_data['expires_at'] = duration.isoformat() - expiry = time.format_infraction_with_duration(request_data['expires_at']) - confirm_messages.append(f"set to expire on {expiry}") - else: - confirm_messages.append("expiry unchanged") - - if reason: - request_data['reason'] = reason - confirm_messages.append("set a new reason") - log_text += f""" - Previous reason: {old_infraction['reason']} - New reason: {reason} - """.rstrip() - else: - confirm_messages.append("reason unchanged") - - # Update the infraction - new_infraction = await self.bot.api_client.patch( - f'bot/infractions/{infraction_id}', - json=request_data, - ) - - # Re-schedule infraction if the expiration has been updated - if 'expires_at' in request_data: - # A scheduled task should only exist if the old infraction wasn't permanent - if old_infraction['expires_at']: - self.infractions_cog.scheduler.cancel(new_infraction['id']) - - # If the infraction was not marked as permanent, schedule a new expiration task - if request_data['expires_at']: - self.infractions_cog.schedule_expiration(new_infraction) - - log_text += f""" - Previous expiry: {old_infraction['expires_at'] or "Permanent"} - New expiry: {new_infraction['expires_at'] or "Permanent"} - """.rstrip() - - changes = ' & '.join(confirm_messages) - await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") - - # Get information about the infraction's user - user_id = new_infraction['user'] - user = ctx.guild.get_member(user_id) - - if user: - user_text = f"{user.mention} (`{user.id}`)" - thumbnail = user.avatar_url_as(static_format="png") - else: - user_text = f"`{user_id}`" - thumbnail = None - - # The infraction's actor - actor_id = new_infraction['actor'] - actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" - - await self.mod_log.send_log_message( - icon_url=constants.Icons.pencil, - colour=discord.Colour.blurple(), - title="Infraction edited", - thumbnail=thumbnail, - text=textwrap.dedent(f""" - Member: {user_text} - Actor: {actor} - Edited by: {ctx.message.author}{log_text} - """) - ) - - # endregion - # region: Search infractions - - @infraction_group.group(name="search", invoke_without_command=True) - async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: - """Searches for infractions in the database.""" - if isinstance(query, discord.User): - await ctx.invoke(self.search_user, query) - else: - await ctx.invoke(self.search_reason, query) - - @infraction_search_group.command(name="user", aliases=("member", "id")) - async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: - """Search for infractions by member.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'user__id': str(user.id)} - ) - embed = discord.Embed( - title=f"Infractions for {user} ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) - async def search_reason(self, ctx: Context, reason: str) -> None: - """Search for infractions by their reason. Use Re2 for matching.""" - infraction_list = await self.bot.api_client.get( - 'bot/infractions', - params={'search': reason} - ) - embed = discord.Embed( - title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", - colour=discord.Colour.orange() - ) - await self.send_infraction_list(ctx, embed, infraction_list) - - # endregion - # region: Utility functions - - async def send_infraction_list( - self, - ctx: Context, - embed: discord.Embed, - infractions: t.Iterable[_utils.Infraction] - ) -> None: - """Send a paginated embed of infractions for the specified user.""" - if not infractions: - await ctx.send(":warning: No infractions could be found for that query.") - return - - lines = tuple( - self.infraction_to_string(infraction) - for infraction in infractions - ) - - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - def infraction_to_string(self, infraction: _utils.Infraction) -> str: - """Convert the infraction object to a string representation.""" - actor_id = infraction["actor"] - guild = self.bot.get_guild(constants.Guild.id) - actor = guild.get_member(actor_id) - active = infraction["active"] - user_id = infraction["user"] - hidden = infraction["hidden"] - created = time.format_infraction(infraction["inserted_at"]) - - if active: - remaining = time.until_expiration(infraction["expires_at"]) or "Expired" - else: - remaining = "Inactive" - - if infraction["expires_at"] is None: - expires = "*Permanent*" - else: - date_from = datetime.strptime(created, time.INFRACTION_FORMAT) - expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) - - lines = textwrap.dedent(f""" - {"**===============**" if active else "==============="} - Status: {"__**Active**__" if active else "Inactive"} - User: {self.bot.get_user(user_id)} (`{user_id}`) - Type: **{infraction["type"]}** - Shadow: {hidden} - Created: {created} - Expires: {expires} - Remaining: {remaining} - Actor: {actor.mention if actor else actor_id} - ID: `{infraction["id"]}` - Reason: {infraction["reason"] or "*None*"} - {"**===============**" if active else "==============="} - """) - - return lines.strip() - - # endregion - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators inside moderator channels to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=constants.MODERATION_CHANNELS, - categories=[constants.Categories.modmail], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Send a notification to the invoking context on a Union failure.""" - if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: - await ctx.send(str(error.errors[0])) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the ModManagement cog.""" - bot.add_cog(ModManagement(bot)) diff --git a/bot/cogs/moderation/infraction/superstarify.py b/bot/cogs/moderation/infraction/superstarify.py deleted file mode 100644 index 7dc5b4691..000000000 --- a/bot/cogs/moderation/infraction/superstarify.py +++ /dev/null @@ -1,244 +0,0 @@ -import json -import logging -import random -import textwrap -import typing as t -from pathlib import Path - -from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.converters import Expiry -from bot.utils.checks import with_role_check -from bot.utils.time import format_infraction -from . import _utils -from ._scheduler import InfractionScheduler - -log = logging.getLogger(__name__) -NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" - -with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: - STAR_NAMES = json.load(stars_file) - - -class Superstarify(InfractionScheduler, Cog): - """A set of commands to moderate terrible nicknames.""" - - def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"superstar"}) - - @Cog.listener() - async def on_member_update(self, before: Member, after: Member) -> None: - """Revert nickname edits if the user has an active superstarify infraction.""" - if before.display_name == after.display_name: - return # User didn't change their nickname. Abort! - - log.trace( - f"{before} ({before.display_name}) is trying to change their nickname to " - f"{after.display_name}. Checking if the user is in superstar-prison..." - ) - - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": str(before.id) - } - ) - - if not active_superstarifies: - log.trace(f"{before} has no active superstar infractions.") - return - - infraction = active_superstarifies[0] - forced_nick = self.get_nick(infraction["id"], before.id) - if after.display_name == forced_nick: - return # Nick change was triggered by this event. Ignore. - - log.info( - f"{after.display_name} ({after.id}) tried to escape superstar prison. " - f"Changing the nick back to {before.display_name}." - ) - await after.edit( - nick=forced_nick, - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - notified = await _utils.notify_infraction( - user=after, - infr_type="Superstarify", - expires_at=format_infraction(infraction["expires_at"]), - reason=( - "You have tried to change your nickname on the **Python Discord** server " - f"from **{before.display_name}** to **{after.display_name}**, but as you " - "are currently in superstar-prison, you do not have permission to do so." - ), - icon_url=_utils.INFRACTION_ICONS["superstar"][0] - ) - - if not notified: - log.info("Failed to DM user about why they cannot change their nickname.") - - @Cog.listener() - async def on_member_join(self, member: Member) -> None: - """Reapply active superstar infractions for returning members.""" - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ - "active": "true", - "type": "superstar", - "user__id": member.id - } - ) - - if active_superstarifies: - infraction = active_superstarifies[0] - action = member.edit( - nick=self.get_nick(infraction["id"], member.id), - reason=f"Superstarified member tried to escape the prison: {infraction['id']}" - ) - - await self.reapply_infraction(infraction, action) - - @command(name="superstarify", aliases=("force_nick", "star")) - async def superstarify( - self, - ctx: Context, - member: Member, - duration: Expiry, - *, - reason: str = None, - ) -> None: - """ - Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. - - A unit of time should be appended to the duration. - Units (∗case-sensitive): - \u2003`y` - years - \u2003`m` - months∗ - \u2003`w` - weeks - \u2003`d` - days - \u2003`h` - hours - \u2003`M` - minutes∗ - \u2003`s` - seconds - - Alternatively, an ISO 8601 timestamp can be provided for the duration. - - An optional reason can be provided. If no reason is given, the original name will be shown - in a generated reason. - """ - if await _utils.get_active_infraction(ctx, member, "superstar"): - return - - # Post the infraction to the API - reason = reason or f"old nick: {member.display_name}" - infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) - id_ = infraction["id"] - - old_nick = member.display_name - forced_nick = self.get_nick(id_, member.id) - expiry_str = format_infraction(infraction["expires_at"]) - - # Apply the infraction and schedule the expiration task. - log.debug(f"Changing nickname of {member} to {forced_nick}.") - self.mod_log.ignore(constants.Event.member_update, member.id) - await member.edit(nick=forced_nick, reason=reason) - self.schedule_expiration(infraction) - - # Send a DM to the user to notify them of their new infraction. - await _utils.notify_infraction( - user=member, - infr_type="Superstarify", - expires_at=expiry_str, - icon_url=_utils.INFRACTION_ICONS["superstar"][0], - reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." - ) - - # Send an embed with the infraction information to the invoking context. - log.trace(f"Sending superstar #{id_} embed.") - embed = Embed( - title="Congratulations!", - colour=constants.Colours.soft_orange, - description=( - f"Your previous nickname, **{old_nick}**, " - f"was so bad that we have decided to change it. " - f"Your new nickname will be **{forced_nick}**.\n\n" - f"You will be unable to change your nickname until **{expiry_str}**.\n\n" - "If you're confused by this, please read our " - f"[official nickname policy]({NICKNAME_POLICY_URL})." - ) - ) - await ctx.send(embed=embed) - - # Log to the mod log channel. - log.trace(f"Sending apply mod log for superstar #{id_}.") - await self.mod_log.send_log_message( - icon_url=_utils.INFRACTION_ICONS["superstar"][0], - colour=Colour.gold(), - title="Member achieved superstardom", - thumbnail=member.avatar_url_as(static_format="png"), - text=textwrap.dedent(f""" - Member: {member.mention} (`{member.id}`) - Actor: {ctx.message.author} - Expires: {expiry_str} - Old nickname: `{old_nick}` - New nickname: `{forced_nick}` - Reason: {reason} - """), - footer=f"ID {id_}" - ) - - @command(name="unsuperstarify", aliases=("release_nick", "unstar")) - async def unsuperstarify(self, ctx: Context, member: Member) -> None: - """Remove the superstarify infraction and allow the user to change their nickname.""" - await self.pardon_infraction(ctx, "superstar", member) - - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """Pardon a superstar infraction and return a log dict.""" - if infraction["type"] != "superstar": - return - - guild = self.bot.get_guild(constants.Guild.id) - user = guild.get_member(infraction["user"]) - - # Don't bother sending a notification if the user left the guild. - if not user: - log.debug( - "User left the guild and therefore won't be notified about superstar " - f"{infraction['id']} pardon." - ) - return {} - - # DM the user about the expiration. - notified = await _utils.notify_pardon( - user=user, - title="You are no longer superstarified", - content="You may now change your nickname on the server.", - icon_url=_utils.INFRACTION_ICONS["superstar"][1] - ) - - return { - "Member": f"{user.mention}(`{user.id}`)", - "DM": "Sent" if notified else "**Failed**" - } - - @staticmethod - def get_nick(infraction_id: int, member_id: int) -> str: - """Randomly select a nickname from the Superstarify nickname list.""" - log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") - - rng = random.Random(str(infraction_id) + str(member_id)) - return rng.choice(STAR_NAMES) - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the Superstarify cog.""" - bot.add_cog(Superstarify(bot)) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py deleted file mode 100644 index c86f04b9d..000000000 --- a/bot/cogs/moderation/modlog.py +++ /dev/null @@ -1,837 +0,0 @@ -import asyncio -import difflib -import itertools -import logging -import typing as t -from datetime import datetime -from itertools import zip_longest - -import discord -from dateutil.relativedelta import relativedelta -from deepdiff import DeepDiff -from discord import Colour -from discord.abc import GuildChannel -from discord.ext.commands import Cog, Context -from discord.utils import escape_markdown - -from bot.bot import Bot -from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] - -CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) -CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") -ROLE_CHANGES_UNSUPPORTED = ("colour", "permissions") - -VOICE_STATE_ATTRIBUTES = { - "channel.name": "Channel", - "self_stream": "Streaming", - "self_video": "Broadcasting", -} - - -class ModLog(Cog, name="ModLog"): - """Logging for server events and staff actions.""" - - def __init__(self, bot: Bot): - self.bot = bot - self._ignored = {event: [] for event in Event} - - self._cached_deletes = [] - self._cached_edits = [] - - async def upload_log( - self, - messages: t.Iterable[discord.Message], - actor_id: int, - attachments: t.Iterable[t.List[str]] = None - ) -> str: - """Upload message logs to the database and return a URL to a page for viewing the logs.""" - if attachments is None: - attachments = [] - - response = await self.bot.api_client.post( - 'bot/deleted-messages', - json={ - 'actor': actor_id, - 'creation': datetime.utcnow().isoformat(), - 'deletedmessage_set': [ - { - 'id': message.id, - 'author': message.author.id, - 'channel_id': message.channel.id, - 'content': message.content, - 'embeds': [embed.to_dict() for embed in message.embeds], - 'attachments': attachment, - } - for message, attachment in zip_longest(messages, attachments, fillvalue=[]) - ] - } - ) - - return f"{URLs.site_logs_view}/{response['id']}" - - def ignore(self, event: Event, *items: int) -> None: - """Add event to ignored events to suppress log emission.""" - for item in items: - if item not in self._ignored[event]: - self._ignored[event].append(item) - - async def send_log_message( - self, - icon_url: t.Optional[str], - colour: t.Union[discord.Colour, int], - title: t.Optional[str], - text: str, - thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, - channel_id: int = Channels.mod_log, - ping_everyone: bool = False, - files: t.Optional[t.List[discord.File]] = None, - content: t.Optional[str] = None, - additional_embeds: t.Optional[t.List[discord.Embed]] = None, - additional_embeds_msg: t.Optional[str] = None, - timestamp_override: t.Optional[datetime] = None, - footer: t.Optional[str] = None, - ) -> Context: - """Generate log embed and send to logging channel.""" - # Truncate string directly here to avoid removing newlines - embed = discord.Embed( - description=text[:2045] + "..." if len(text) > 2048 else text - ) - - if title and icon_url: - embed.set_author(name=title, icon_url=icon_url) - - embed.colour = colour - embed.timestamp = timestamp_override or datetime.utcnow() - - if footer: - embed.set_footer(text=footer) - - if thumbnail: - embed.set_thumbnail(url=thumbnail) - - if ping_everyone: - if content: - content = f"@everyone\n{content}" - else: - content = "@everyone" - - channel = self.bot.get_channel(channel_id) - log_message = await channel.send( - content=content, - embed=embed, - files=files, - allowed_mentions=discord.AllowedMentions(everyone=True) - ) - - if additional_embeds: - if additional_embeds_msg: - await channel.send(additional_embeds_msg) - for additional_embed in additional_embeds: - await channel.send(embed=additional_embed) - - return await self.bot.get_context(log_message) # Optionally return for use with antispam - - @Cog.listener() - async def on_guild_channel_create(self, channel: GUILD_CHANNEL) -> None: - """Log channel create event to mod log.""" - if channel.guild.id != GuildConstant.id: - return - - if isinstance(channel, discord.CategoryChannel): - title = "Category created" - message = f"{channel.name} (`{channel.id}`)" - elif isinstance(channel, discord.VoiceChannel): - title = "Voice channel created" - - if channel.category: - message = f"{channel.category}/{channel.name} (`{channel.id}`)" - else: - message = f"{channel.name} (`{channel.id}`)" - else: - title = "Text channel created" - - if channel.category: - message = f"{channel.category}/{channel.name} (`{channel.id}`)" - else: - message = f"{channel.name} (`{channel.id}`)" - - await self.send_log_message(Icons.hash_green, Colours.soft_green, title, message) - - @Cog.listener() - async def on_guild_channel_delete(self, channel: GUILD_CHANNEL) -> None: - """Log channel delete event to mod log.""" - if channel.guild.id != GuildConstant.id: - return - - if isinstance(channel, discord.CategoryChannel): - title = "Category deleted" - elif isinstance(channel, discord.VoiceChannel): - title = "Voice channel deleted" - else: - title = "Text channel deleted" - - if channel.category and not isinstance(channel, discord.CategoryChannel): - message = f"{channel.category}/{channel.name} (`{channel.id}`)" - else: - message = f"{channel.name} (`{channel.id}`)" - - await self.send_log_message( - Icons.hash_red, Colours.soft_red, - title, message - ) - - @Cog.listener() - async def on_guild_channel_update(self, before: GUILD_CHANNEL, after: GuildChannel) -> None: - """Log channel update event to mod log.""" - if before.guild.id != GuildConstant.id: - return - - if before.id in self._ignored[Event.guild_channel_update]: - self._ignored[Event.guild_channel_update].remove(before.id) - return - - # Two channel updates are sent for a single edit: 1 for topic and 1 for category change. - # TODO: remove once support is added for ignoring multiple occurrences for the same channel. - help_categories = (Categories.help_available, Categories.help_dormant, Categories.help_in_use) - if after.category and after.category.id in help_categories: - return - - diff = DeepDiff(before, after) - changes = [] - done = [] - - diff_values = diff.get("values_changed", {}) - diff_values.update(diff.get("type_changes", {})) - - for key, value in diff_values.items(): - if not key: # Not sure why, but it happens - continue - - key = key[5:] # Remove "root." prefix - - if "[" in key: - key = key.split("[", 1)[0] - - if "." in key: - key = key.split(".", 1)[0] - - if key in done or key in CHANNEL_CHANGES_SUPPRESSED: - continue - - if key in CHANNEL_CHANGES_UNSUPPORTED: - changes.append(f"**{key.title()}** updated") - else: - new = value["new_value"] - old = value["old_value"] - - # Discord does not treat consecutive backticks ("``") as an empty inline code block, so the markdown - # formatting is broken when `new` and/or `old` are empty values. "None" is used for these cases so - # formatting is preserved. - changes.append(f"**{key.title()}:** `{old or 'None'}` **→** `{new or 'None'}`") - - done.append(key) - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - if after.category: - message = f"**{after.category}/#{after.name} (`{after.id}`)**\n{message}" - else: - message = f"**#{after.name}** (`{after.id}`)\n{message}" - - await self.send_log_message( - Icons.hash_blurple, Colour.blurple(), - "Channel updated", message - ) - - @Cog.listener() - async def on_guild_role_create(self, role: discord.Role) -> None: - """Log role create event to mod log.""" - if role.guild.id != GuildConstant.id: - return - - await self.send_log_message( - Icons.crown_green, Colours.soft_green, - "Role created", f"`{role.id}`" - ) - - @Cog.listener() - async def on_guild_role_delete(self, role: discord.Role) -> None: - """Log role delete event to mod log.""" - if role.guild.id != GuildConstant.id: - return - - await self.send_log_message( - Icons.crown_red, Colours.soft_red, - "Role removed", f"{role.name} (`{role.id}`)" - ) - - @Cog.listener() - async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: - """Log role update event to mod log.""" - if before.guild.id != GuildConstant.id: - return - - diff = DeepDiff(before, after) - changes = [] - done = [] - - diff_values = diff.get("values_changed", {}) - diff_values.update(diff.get("type_changes", {})) - - for key, value in diff_values.items(): - if not key: # Not sure why, but it happens - continue - - key = key[5:] # Remove "root." prefix - - if "[" in key: - key = key.split("[", 1)[0] - - if "." in key: - key = key.split(".", 1)[0] - - if key in done or key == "color": - continue - - if key in ROLE_CHANGES_UNSUPPORTED: - changes.append(f"**{key.title()}** updated") - else: - new = value["new_value"] - old = value["old_value"] - - changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") - - done.append(key) - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - message = f"**{after.name}** (`{after.id}`)\n{message}" - - await self.send_log_message( - Icons.crown_blurple, Colour.blurple(), - "Role updated", message - ) - - @Cog.listener() - async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: - """Log guild update event to mod log.""" - if before.id != GuildConstant.id: - return - - diff = DeepDiff(before, after) - changes = [] - done = [] - - diff_values = diff.get("values_changed", {}) - diff_values.update(diff.get("type_changes", {})) - - for key, value in diff_values.items(): - if not key: # Not sure why, but it happens - continue - - key = key[5:] # Remove "root." prefix - - if "[" in key: - key = key.split("[", 1)[0] - - if "." in key: - key = key.split(".", 1)[0] - - if key in done: - continue - - new = value["new_value"] - old = value["old_value"] - - changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") - - done.append(key) - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - message = f"**{after.name}** (`{after.id}`)\n{message}" - - await self.send_log_message( - Icons.guild_update, Colour.blurple(), - "Guild updated", message, - thumbnail=after.icon_url_as(format="png") - ) - - @Cog.listener() - async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: - """Log ban event to user log.""" - if guild.id != GuildConstant.id: - return - - if member.id in self._ignored[Event.member_ban]: - self._ignored[Event.member_ban].remove(member.id) - return - - await self.send_log_message( - Icons.user_ban, Colours.soft_red, - "User banned", f"{member} (`{member.id}`)", - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_member_join(self, member: discord.Member) -> None: - """Log member join event to user log.""" - if member.guild.id != GuildConstant.id: - return - - member_str = escape_markdown(str(member)) - message = f"{member_str} (`{member.id}`)" - now = datetime.utcnow() - difference = abs(relativedelta(now, member.created_at)) - - message += "\n\n**Account age:** " + humanize_delta(difference) - - if difference.days < 1 and difference.months < 1 and difference.years < 1: # New user account! - message = f"{Emojis.new} {message}" - - await self.send_log_message( - Icons.sign_in, Colours.soft_green, - "User joined", message, - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_member_remove(self, member: discord.Member) -> None: - """Log member leave event to user log.""" - if member.guild.id != GuildConstant.id: - return - - if member.id in self._ignored[Event.member_remove]: - self._ignored[Event.member_remove].remove(member.id) - return - - member_str = escape_markdown(str(member)) - await self.send_log_message( - Icons.sign_out, Colours.soft_red, - "User left", f"{member_str} (`{member.id}`)", - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: - """Log member unban event to mod log.""" - if guild.id != GuildConstant.id: - return - - if member.id in self._ignored[Event.member_unban]: - self._ignored[Event.member_unban].remove(member.id) - return - - member_str = escape_markdown(str(member)) - await self.send_log_message( - Icons.user_unban, Colour.blurple(), - "User unbanned", f"{member_str} (`{member.id}`)", - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.mod_log - ) - - @staticmethod - def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: - """Return a list of strings describing the roles added and removed.""" - changes = [] - before_roles = set(before) - after_roles = set(after) - - for role in (before_roles - after_roles): - changes.append(f"**Role removed:** {role.name} (`{role.id}`)") - - for role in (after_roles - before_roles): - changes.append(f"**Role added:** {role.name} (`{role.id}`)") - - return changes - - @Cog.listener() - async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: - """Log member update event to user log.""" - if before.guild.id != GuildConstant.id: - return - - if before.id in self._ignored[Event.member_update]: - self._ignored[Event.member_update].remove(before.id) - return - - changes = self.get_role_diff(before.roles, after.roles) - - # The regex is a simple way to exclude all sequence and mapping types. - diff = DeepDiff(before, after, exclude_regex_paths=r".*\[.*") - - # A type change seems to always take precedent over a value change. Furthermore, it will - # include the value change along with the type change anyway. Therefore, it's OK to - # "overwrite" values_changed; in practice there will never even be anything to overwrite. - diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} - - for attr, value in diff_values.items(): - if not attr: # Not sure why, but it happens. - continue - - attr = attr[5:] # Remove "root." prefix. - attr = attr.replace("_", " ").replace(".", " ").capitalize() - - new = value.get("new_value") - old = value.get("old_value") - - changes.append(f"**{attr}:** `{old}` **→** `{new}`") - - if not changes: - return - - message = "" - - for item in sorted(changes): - message += f"{Emojis.bullet} {item}\n" - - member_str = escape_markdown(str(after)) - message = f"**{member_str}** (`{after.id}`)\n{message}" - - await self.send_log_message( - icon_url=Icons.user_update, - colour=Colour.blurple(), - title="Member updated", - text=message, - thumbnail=after.avatar_url_as(static_format="png"), - channel_id=Channels.user_log - ) - - @Cog.listener() - async def on_message_delete(self, message: discord.Message) -> None: - """Log message delete event to message change log.""" - channel = message.channel - author = message.author - - # Ignore DMs. - if not message.guild: - return - - if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist: - return - - self._cached_deletes.append(message.id) - - if message.id in self._ignored[Event.message_delete]: - self._ignored[Event.message_delete].remove(message.id) - return - - if author.bot: - return - - author_str = escape_markdown(str(author)) - if channel.category: - response = ( - f"**Author:** {author_str} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - ) - else: - response = ( - f"**Author:** {author_str} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - ) - - if message.attachments: - # Prepend the message metadata with the number of attachments - response = f"**Attachments:** {len(message.attachments)}\n" + response - - # Shorten the message content if necessary - content = message.clean_content - remaining_chars = 2040 - len(response) - - if len(content) > remaining_chars: - botlog_url = await self.upload_log(messages=[message], actor_id=message.author.id) - ending = f"\n\nMessage truncated, [full message here]({botlog_url})." - truncation_point = remaining_chars - len(ending) - content = f"{content[:truncation_point]}...{ending}" - - response += f"{content}" - - await self.send_log_message( - Icons.message_delete, Colours.soft_red, - "Message deleted", - response, - channel_id=Channels.message_log - ) - - @Cog.listener() - async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: - """Log raw message delete event to message change log.""" - if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist: - return - - await asyncio.sleep(1) # Wait here in case the normal event was fired - - if event.message_id in self._cached_deletes: - # It was in the cache and the normal event was fired, so we can just ignore it - self._cached_deletes.remove(event.message_id) - return - - if event.message_id in self._ignored[Event.message_delete]: - self._ignored[Event.message_delete].remove(event.message_id) - return - - channel = self.bot.get_channel(event.channel_id) - - if channel.category: - response = ( - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{event.message_id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - else: - response = ( - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{event.message_id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - await self.send_log_message( - Icons.message_delete, Colours.soft_red, - "Message deleted", - response, - channel_id=Channels.message_log - ) - - @Cog.listener() - async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: - """Log message edit event to message change log.""" - if ( - not msg_before.guild - or msg_before.guild.id != GuildConstant.id - or msg_before.channel.id in GuildConstant.modlog_blacklist - or msg_before.author.bot - ): - return - - self._cached_edits.append(msg_before.id) - - if msg_before.content == msg_after.content: - return - - author = msg_before.author - author_str = escape_markdown(str(author)) - - channel = msg_before.channel - channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - - # Getting the difference per words and group them by type - add, remove, same - # Note that this is intended grouping without sorting - diff = difflib.ndiff(msg_before.clean_content.split(), msg_after.clean_content.split()) - diff_groups = tuple( - (diff_type, tuple(s[2:] for s in diff_words)) - for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) - ) - - content_before: t.List[str] = [] - content_after: t.List[str] = [] - - for index, (diff_type, words) in enumerate(diff_groups): - sub = ' '.join(words) - if diff_type == '-': - content_before.append(f"[{sub}](http://o.hi)") - elif diff_type == '+': - content_after.append(f"[{sub}](http://o.hi)") - elif diff_type == ' ': - if len(words) > 2: - sub = ( - f"{words[0] if index > 0 else ''}" - " ... " - f"{words[-1] if index < len(diff_groups) - 1 else ''}" - ) - content_before.append(sub) - content_after.append(sub) - - response = ( - f"**Author:** {author_str} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{msg_before.id}`\n" - "\n" - f"**Before**:\n{' '.join(content_before)}\n" - f"**After**:\n{' '.join(content_after)}\n" - "\n" - f"[Jump to message]({msg_after.jump_url})" - ) - - if msg_before.edited_at: - # Message was previously edited, to assist with self-bot detection, use the edited_at - # datetime as the baseline and create a human-readable delta between this edit event - # and the last time the message was edited - timestamp = msg_before.edited_at - delta = humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) - footer = f"Last edited {delta} ago" - else: - # Message was not previously edited, use the created_at datetime as the baseline, no - # delta calculation needed - timestamp = msg_before.created_at - footer = None - - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited", response, - channel_id=Channels.message_log, timestamp_override=timestamp, footer=footer - ) - - @Cog.listener() - async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: - """Log raw message edit event to message change log.""" - try: - channel = self.bot.get_channel(int(event.data["channel_id"])) - message = await channel.fetch_message(event.message_id) - except discord.NotFound: # Was deleted before we got the event - return - - if ( - not message.guild - or message.guild.id != GuildConstant.id - or message.channel.id in GuildConstant.modlog_blacklist - or message.author.bot - ): - return - - await asyncio.sleep(1) # Wait here in case the normal event was fired - - if event.message_id in self._cached_edits: - # It was in the cache and the normal event was fired, so we can just ignore it - self._cached_edits.remove(event.message_id) - return - - author = message.author - channel = message.channel - channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - f"{message.clean_content}" - ) - - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (Before)", - before_response, channel_id=Channels.message_log - ) - - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (After)", - after_response, channel_id=Channels.message_log - ) - - @Cog.listener() - async def on_voice_state_update( - self, - member: discord.Member, - before: discord.VoiceState, - after: discord.VoiceState - ) -> None: - """Log member voice state changes to the voice log channel.""" - if ( - member.guild.id != GuildConstant.id - or (before.channel and before.channel.id in GuildConstant.modlog_blacklist) - ): - return - - if member.id in self._ignored[Event.voice_state_update]: - self._ignored[Event.voice_state_update].remove(member.id) - return - - # Exclude all channel attributes except the name. - diff = DeepDiff( - before, - after, - exclude_paths=("root.session_id", "root.afk"), - exclude_regex_paths=r"root\.channel\.(?!name)", - ) - - # A type change seems to always take precedent over a value change. Furthermore, it will - # include the value change along with the type change anyway. Therefore, it's OK to - # "overwrite" values_changed; in practice there will never even be anything to overwrite. - diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} - - icon = Icons.voice_state_blue - colour = Colour.blurple() - changes = [] - - for attr, values in diff_values.items(): - if not attr: # Not sure why, but it happens. - continue - - old = values["old_value"] - new = values["new_value"] - - attr = attr[5:] # Remove "root." prefix. - attr = VOICE_STATE_ATTRIBUTES.get(attr, attr.replace("_", " ").capitalize()) - - changes.append(f"**{attr}:** `{old}` **→** `{new}`") - - # Set the embed icon and colour depending on which attribute changed. - if any(name in attr for name in ("Channel", "deaf", "mute")): - if new is None or new is True: - # Left a channel or was muted/deafened. - icon = Icons.voice_state_red - colour = Colours.soft_red - elif old is None or old is True: - # Joined a channel or was unmuted/undeafened. - icon = Icons.voice_state_green - colour = Colours.soft_green - - if not changes: - return - - member_str = escape_markdown(str(member)) - message = "\n".join(f"{Emojis.bullet} {item}" for item in sorted(changes)) - message = f"**{member_str}** (`{member.id}`)\n{message}" - - await self.send_log_message( - icon_url=icon, - colour=colour, - title="Voice state updated", - text=message, - thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.voice_log - ) - - -def setup(bot: Bot) -> None: - """Load the ModLog cog.""" - bot.add_cog(ModLog(bot)) diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py deleted file mode 100644 index 4af87c724..000000000 --- a/bot/cogs/moderation/silence.py +++ /dev/null @@ -1,170 +0,0 @@ -import asyncio -import logging -from contextlib import suppress -from typing import Optional - -from discord import TextChannel -from discord.ext import commands, tasks -from discord.ext.commands import Context - -from bot.bot import Bot -from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles -from bot.converters import HushDurationConverter -from bot.utils.checks import with_role_check -from bot.utils.scheduling import Scheduler - -log = logging.getLogger(__name__) - - -class SilenceNotifier(tasks.Loop): - """Loop notifier for posting notices to `alert_channel` containing added channels.""" - - def __init__(self, alert_channel: TextChannel): - super().__init__(self._notifier, seconds=1, minutes=0, hours=0, count=None, reconnect=True, loop=None) - self._silenced_channels = {} - self._alert_channel = alert_channel - - def add_channel(self, channel: TextChannel) -> None: - """Add channel to `_silenced_channels` and start loop if not launched.""" - if not self._silenced_channels: - self.start() - log.info("Starting notifier loop.") - self._silenced_channels[channel] = self._current_loop - - def remove_channel(self, channel: TextChannel) -> None: - """Remove channel from `_silenced_channels` and stop loop if no channels remain.""" - with suppress(KeyError): - del self._silenced_channels[channel] - if not self._silenced_channels: - self.stop() - log.info("Stopping notifier loop.") - - async def _notifier(self) -> None: - """Post notice of `_silenced_channels` with their silenced duration to `_alert_channel` periodically.""" - # Wait for 15 minutes between notices with pause at start of loop. - if self._current_loop and not self._current_loop/60 % 15: - log.debug( - f"Sending notice with channels: " - f"{', '.join(f'#{channel} ({channel.id})' for channel in self._silenced_channels)}." - ) - channels_text = ', '.join( - f"{channel.mention} for {(self._current_loop-start)//60} min" - for channel, start in self._silenced_channels.items() - ) - await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") - - -class Silence(commands.Cog): - """Commands for stopping channel messages for `verified` role in a channel.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - self.muted_channels = set() - - self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) - self._get_instance_vars_event = asyncio.Event() - - async def _get_instance_vars(self) -> None: - """Get instance variables after they're available to get from the guild.""" - await self.bot.wait_until_guild_available() - guild = self.bot.get_guild(Guild.id) - self._verified_role = guild.get_role(Roles.verified) - self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) - self._mod_log_channel = self.bot.get_channel(Channels.mod_log) - self.notifier = SilenceNotifier(self._mod_log_channel) - self._get_instance_vars_event.set() - - @commands.command(aliases=("hush",)) - async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: - """ - Silence the current channel for `duration` minutes or `forever`. - - Duration is capped at 15 minutes, passing forever makes the silence indefinite. - Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. - """ - await self._get_instance_vars_event.wait() - log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") - if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): - await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") - return - if duration is None: - await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") - return - - await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") - - self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) - - @commands.command(aliases=("unhush",)) - async def unsilence(self, ctx: Context) -> None: - """ - Unsilence the current channel. - - If the channel was silenced indefinitely, notifications for the channel will stop. - """ - await self._get_instance_vars_event.wait() - log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") - if not await self._unsilence(ctx.channel): - await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") - else: - await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") - - async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: - """ - Silence `channel` for `self._verified_role`. - - If `persistent` is `True` add `channel` to notifier. - `duration` is only used for logging; if None is passed `persistent` should be True to not log None. - Return `True` if channel permissions were changed, `False` otherwise. - """ - current_overwrite = channel.overwrites_for(self._verified_role) - if current_overwrite.send_messages is False: - log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") - return False - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) - self.muted_channels.add(channel) - if persistent: - log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") - self.notifier.add_channel(channel) - return True - - log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") - return True - - async def _unsilence(self, channel: TextChannel) -> bool: - """ - Unsilence `channel`. - - Check if `channel` is silenced through a `PermissionOverwrite`, - if it is unsilence it and remove it from the notifier. - Return `True` if channel permissions were changed, `False` otherwise. - """ - current_overwrite = channel.overwrites_for(self._verified_role) - if current_overwrite.send_messages is False: - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) - log.info(f"Unsilenced channel #{channel} ({channel.id}).") - self.scheduler.cancel(channel.id) - self.notifier.remove_channel(channel) - self.muted_channels.discard(channel) - return True - log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") - return False - - def cog_unload(self) -> None: - """Send alert with silenced channels and cancel scheduled tasks on unload.""" - self.scheduler.cancel_all() - if self.muted_channels: - channels_string = ''.join(channel.mention for channel in self.muted_channels) - message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" - asyncio.create_task(self._mod_alerts_channel.send(message)) - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the Silence cog.""" - bot.add_cog(Silence(bot)) diff --git a/bot/cogs/moderation/slowmode.py b/bot/cogs/moderation/slowmode.py deleted file mode 100644 index 1d055afac..000000000 --- a/bot/cogs/moderation/slowmode.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -from datetime import datetime -from typing import Optional - -from dateutil.relativedelta import relativedelta -from discord import TextChannel -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import Emojis, MODERATION_ROLES -from bot.converters import DurationDelta -from bot.decorators import with_role_check -from bot.utils import time - -log = logging.getLogger(__name__) - -SLOWMODE_MAX_DELAY = 21600 # seconds - - -class Slowmode(Cog): - """Commands for getting and setting slowmode delays of text channels.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - @group(name='slowmode', aliases=['sm'], invoke_without_command=True) - async def slowmode_group(self, ctx: Context) -> None: - """Get or set the slowmode delay for the text channel this was invoked in or a given text channel.""" - await ctx.send_help(ctx.command) - - @slowmode_group.command(name='get', aliases=['g']) - async def get_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: - """Get the slowmode delay for a text channel.""" - # Use the channel this command was invoked in if one was not given - if channel is None: - channel = ctx.channel - - delay = relativedelta(seconds=channel.slowmode_delay) - humanized_delay = time.humanize_delta(delay) - - await ctx.send(f'The slowmode delay for {channel.mention} is {humanized_delay}.') - - @slowmode_group.command(name='set', aliases=['s']) - async def set_slowmode(self, ctx: Context, channel: Optional[TextChannel], delay: DurationDelta) -> None: - """Set the slowmode delay for a text channel.""" - # Use the channel this command was invoked in if one was not given - if channel is None: - channel = ctx.channel - - # Convert `dateutil.relativedelta.relativedelta` to `datetime.timedelta` - # Must do this to get the delta in a particular unit of time - utcnow = datetime.utcnow() - slowmode_delay = (utcnow + delay - utcnow).total_seconds() - - humanized_delay = time.humanize_delta(delay) - - # Ensure the delay is within discord's limits - if slowmode_delay <= SLOWMODE_MAX_DELAY: - log.info(f'{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.') - - await channel.edit(slowmode_delay=slowmode_delay) - await ctx.send( - f'{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}.' - ) - - else: - log.info( - f'{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, ' - 'which is not between 0 and 6 hours.' - ) - - await ctx.send( - f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.' - ) - - @slowmode_group.command(name='reset', aliases=['r']) - async def reset_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: - """Reset the slowmode delay for a text channel to 0 seconds.""" - # Use the channel this command was invoked in if one was not given - if channel is None: - channel = ctx.channel - - log.info(f'{ctx.author} reset the slowmode delay for #{channel} to 0 seconds.') - - await channel.edit(slowmode_delay=0) - await ctx.send( - f'{Emojis.check_mark} The slowmode delay for {channel.mention} has been reset to 0 seconds.' - ) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the Slowmode cog.""" - bot.add_cog(Slowmode(bot)) diff --git a/bot/cogs/moderation/verification.py b/bot/cogs/moderation/verification.py deleted file mode 100644 index ba95ab5e4..000000000 --- a/bot/cogs/moderation/verification.py +++ /dev/null @@ -1,191 +0,0 @@ -import logging -from contextlib import suppress - -from discord import Colour, Forbidden, Message, NotFound, Object -from discord.ext.commands import Cog, Context, command - -from bot import constants -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.decorators import in_whitelist, without_role -from bot.utils.checks import InWhitelistCheckFailure, without_role_check - -log = logging.getLogger(__name__) - -WELCOME_MESSAGE = f""" -Hello! Welcome to the server, and thanks for verifying yourself! - -For your records, these are the documents you accepted: - -`1)` Our rules, here: -`2)` Our privacy policy, here: - you can find information on how to have \ -your information removed here as well. - -Feel free to review them at any point! - -Additionally, if you'd like to receive notifications for the announcements \ -we post in <#{constants.Channels.announcements}> -from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ -to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. - -If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ -<#{constants.Channels.bot_commands}>. -""" - -BOT_MESSAGE_DELETE_DELAY = 10 - - -class Verification(Cog): - """User verification and role self-management.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - @Cog.listener() - async def on_message(self, message: Message) -> None: - """Check new message event for messages to the checkpoint channel & process.""" - if message.channel.id != constants.Channels.verification: - return # Only listen for #checkpoint messages - - if message.author.bot: - # They're a bot, delete their message after the delay. - await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) - return - - # if a user mentions a role or guild member - # alert the mods in mod-alerts channel - if message.mentions or message.role_mentions: - log.debug( - f"{message.author} mentioned one or more users " - f"and/or roles in {message.channel.name}" - ) - - embed_text = ( - f"{message.author.mention} sent a message in " - f"{message.channel.mention} that contained user and/or role mentions." - f"\n\n**Original message:**\n>>> {message.content}" - ) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=constants.Icons.filtering, - colour=Colour(constants.Colours.soft_red), - title=f"User/Role mentioned in {message.channel.name}", - text=embed_text, - thumbnail=message.author.avatar_url_as(static_format="png"), - channel_id=constants.Channels.mod_alerts, - ) - - ctx: Context = await self.bot.get_context(message) - if ctx.command is not None and ctx.command.name == "accept": - return - - if any(r.id == constants.Roles.verified for r in ctx.author.roles): - log.info( - f"{ctx.author} posted '{ctx.message.content}' " - "in the verification channel, but is already verified." - ) - return - - log.debug( - f"{ctx.author} posted '{ctx.message.content}' in the verification " - "channel. We are providing instructions how to verify." - ) - await ctx.send( - f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " - f"and gain access to the rest of the server.", - delete_after=20 - ) - - log.trace(f"Deleting the message posted by {ctx.author}") - with suppress(NotFound): - await ctx.message.delete() - - @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) - @without_role(constants.Roles.verified) - @in_whitelist(channels=(constants.Channels.verification,)) - async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Accept our rules and gain access to the rest of the server.""" - log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") - await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") - try: - await ctx.author.send(WELCOME_MESSAGE) - except Forbidden: - log.info(f"Sending welcome message failed for {ctx.author}.") - finally: - log.trace(f"Deleting accept message by {ctx.author}.") - with suppress(NotFound): - self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) - await ctx.message.delete() - - @command(name='subscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Subscribe to announcement notifications by assigning yourself the role.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if has_role: - await ctx.send(f"{ctx.author.mention} You're already subscribed!") - return - - log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") - await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", - ) - - @command(name='unsubscribe') - @in_whitelist(channels=(constants.Channels.bot_commands,)) - async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args - """Unsubscribe from announcement notifications by removing the role from yourself.""" - has_role = False - - for role in ctx.author.roles: - if role.id == constants.Roles.announcements: - has_role = True - break - - if not has_role: - await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") - return - - log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") - await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - await ctx.send( - f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." - ) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InWhitelistCheckFailure.""" - if isinstance(error, InWhitelistCheckFailure): - error.handled = True - - @staticmethod - def bot_check(ctx: Context) -> bool: - """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): - return ctx.command.name == "accept" - else: - return True - - -def setup(bot: Bot) -> None: - """Load the Verification cog.""" - bot.add_cog(Verification(bot)) diff --git a/bot/cogs/moderation/watchchannels/__init__.py b/bot/cogs/moderation/watchchannels/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/moderation/watchchannels/_watchchannel.py b/bot/cogs/moderation/watchchannels/_watchchannel.py deleted file mode 100644 index 488ae704d..000000000 --- a/bot/cogs/moderation/watchchannels/_watchchannel.py +++ /dev/null @@ -1,348 +0,0 @@ -import asyncio -import logging -import re -import textwrap -from abc import abstractmethod -from collections import defaultdict, deque -from dataclasses import dataclass -from typing import Optional - -import dateutil.parser -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages -from bot.utils.time import time_since - -log = logging.getLogger(__name__) - -URL_RE = re.compile(r"(https?://[^\s]+)") - - -@dataclass -class MessageHistory: - """Represents a watch channel's message history.""" - - last_author: Optional[int] = None - last_channel: Optional[int] = None - message_count: int = 0 - - -class WatchChannel(metaclass=CogABCMeta): - """ABC with functionality for relaying users' messages to a certain channel.""" - - @abstractmethod - def __init__( - self, - bot: Bot, - destination: int, - webhook_id: int, - api_endpoint: str, - api_default_params: dict, - logger: logging.Logger - ) -> None: - self.bot = bot - - self.destination = destination # E.g., Channels.big_brother_logs - self.webhook_id = webhook_id # E.g., Webhooks.big_brother - self.api_endpoint = api_endpoint # E.g., 'bot/infractions' - self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} - self.log = logger # Logger of the child cog for a correct name in the logs - - self._consume_task = None - self.watched_users = defaultdict(dict) - self.message_queue = defaultdict(lambda: defaultdict(deque)) - self.consumption_queue = {} - self.retries = 5 - self.retry_delay = 10 - self.channel = None - self.webhook = None - self.message_history = MessageHistory() - - self._start = self.bot.loop.create_task(self.start_watchchannel()) - - @property - def modlog(self) -> ModLog: - """Provides access to the ModLog cog for alert purposes.""" - return self.bot.get_cog("ModLog") - - @property - def consuming_messages(self) -> bool: - """Checks if a consumption task is currently running.""" - if self._consume_task is None: - return False - - if self._consume_task.done(): - exc = self._consume_task.exception() - if exc: - self.log.exception( - "The message queue consume task has failed with:", - exc_info=exc - ) - return False - - return True - - async def start_watchchannel(self) -> None: - """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_guild_available() - - try: - self.channel = await self.bot.fetch_channel(self.destination) - except HTTPException: - self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - if self.channel is None or self.webhook is None: - self.log.error("Failed to start the watch channel; unloading the cog.") - - message = textwrap.dedent( - f""" - An error occurred while loading the text channel or webhook. - - TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} - Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} - - The Cog has been unloaded. - """ - ) - - await self.modlog.send_log_message( - title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", - text=message, - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - self.bot.remove_cog(self.__class__.__name__) - return - - if not await self.fetch_user_cache(): - await self.modlog.send_log_message( - title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", - text="Could not retrieve the list of watched users from the API and messages will not be relayed.", - ping_everyone=True, - icon_url=Icons.token_removed, - colour=Color.red() - ) - - async def fetch_user_cache(self) -> bool: - """ - Fetches watched users from the API and updates the watched user cache accordingly. - - This function returns `True` if the update succeeded. - """ - try: - data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except ResponseCodeError as err: - self.log.exception("Failed to fetch the watched users from the API", exc_info=err) - return False - - self.watched_users = defaultdict(dict) - - for entry in data: - user_id = entry.pop('user') - self.watched_users[user_id] = entry - - return True - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """Queues up messages sent by watched users.""" - if msg.author.id in self.watched_users: - if not self.consuming_messages: - self._consume_task = self.bot.loop.create_task(self.consume_messages()) - - self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") - self.message_queue[msg.author.id][msg.channel.id].append(msg) - - async def consume_messages(self, delay_consumption: bool = True) -> None: - """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) - - self.log.trace("Started consuming the message queue") - - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() - - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() - - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) - - self.consumption_queue.clear() - - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") - - async def webhook_send( - self, - content: Optional[str] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - embed: Optional[Embed] = None, - ) -> None: - """Sends a message to the webhook with the specified kwargs.""" - username = messages.sub_clyde(username) - try: - await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send a message to the webhook", - exc_info=exc - ) - - async def relay_message(self, msg: Message) -> None: - """Relays the message to the relevant watch channel.""" - limit = BigBrotherConfig.header_message_limit - - if ( - msg.author.id != self.message_history.last_author - or msg.channel.id != self.message_history.last_channel - or self.message_history.message_count >= limit - ): - self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) - - await self.send_header(msg) - - cleaned_content = msg.clean_content - - if cleaned_content: - # Put all non-media URLs in a code block to prevent embeds - media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} - for url in URL_RE.findall(cleaned_content): - if url not in media_urls: - cleaned_content = cleaned_content.replace(url, f"`{url}`") - await self.webhook_send( - cleaned_content, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - - if msg.attachments: - try: - await messages.send_attachments(msg, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.webhook_send( - embed=e, - username=msg.author.display_name, - avatar_url=msg.author.avatar_url - ) - except discord.HTTPException as exc: - self.log.exception( - "Failed to send an attachment to the webhook", - exc_info=exc - ) - - self.message_history.message_count += 1 - - async def send_header(self, msg: Message) -> None: - """Sends a header embed with information about the relayed messages to the watch channel.""" - user_id = msg.author.id - - guild = self.bot.get_guild(GuildConfig.id) - actor = guild.get_member(self.watched_users[user_id]['actor']) - actor = actor.display_name if actor else self.watched_users[user_id]['actor'] - - inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - - reason = self.watched_users[user_id]['reason'] - - if isinstance(msg.channel, DMChannel): - # If a watched user DMs the bot there won't be a channel name or jump URL - # This could technically include a GroupChannel but bot's can't be in those - message_jump = "via DM" - else: - message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" - - footer = f"Added {time_delta} by {actor} | Reason: {reason}" - embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) - - await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) - - async def list_watched_users( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Gives an overview of the watched user list for this channel. - - The optional kwarg `oldest_first` orders the list by oldest entry. - - The optional kwarg `update_cache` specifies whether the cache should - be refreshed by polling the API. - """ - if update_cache: - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") - update_cache = False - - lines = [] - for user_id, user_data in self.watched_users.items(): - inserted_at = user_data['inserted_at'] - time_delta = self._get_time_delta(inserted_at) - lines.append(f"• <@{user_id}> (added {time_delta})") - - if oldest_first: - lines.reverse() - - lines = lines or ("There's nothing here yet.",) - - embed = Embed( - title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", - color=Color.blue() - ) - await LinePaginator.paginate(lines, ctx, embed, empty=False) - - @staticmethod - def _get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) - time_delta = time_since(date_time, precision="minutes", max_units=1) - - return time_delta - - def _remove_user(self, user_id: int) -> None: - """Removes a user from a watch channel.""" - self.watched_users.pop(user_id, None) - self.message_queue.pop(user_id, None) - self.consumption_queue.pop(user_id, None) - - def cog_unload(self) -> None: - """Takes care of unloading the cog and canceling the consumption task.""" - self.log.trace("Unloading the cog") - if self._consume_task and not self._consume_task.done(): - self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) diff --git a/bot/cogs/moderation/watchchannels/bigbrother.py b/bot/cogs/moderation/watchchannels/bigbrother.py deleted file mode 100644 index 7db34bcf2..000000000 --- a/bot/cogs/moderation/watchchannels/bigbrother.py +++ /dev/null @@ -1,170 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.infraction._utils import post_infraction -from bot.constants import Channels, MODERATION_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from ._watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class BigBrother(WatchChannel, Cog, name="Big Brother"): - """Monitors users by relaying their messages to a watch channel to assist with moderation.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.big_brother_logs, - webhook_id=Webhooks.big_brother, - api_endpoint='bot/infractions', - api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, - logger=log - ) - - @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def bigbrother_group(self, ctx: Context) -> None: - """Monitors users by relaying their messages to the Big Brother watch channel.""" - await ctx.send_help(ctx.command) - - @bigbrother_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored by Big Brother. - - The optional kwarg `oldest_first` can be used to order the list by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @bigbrother_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows Big Brother monitored users ordered by oldest watched. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @bigbrother_group.command(name='watch', aliases=('w',)) - @with_role(*MODERATION_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#big-brother` channel. - - A `reason` for adding the user to Big Brother is required and will be displayed - in the header when relaying messages of this user to the watchchannel. - """ - await self.apply_watch(ctx, user, reason) - - @bigbrother_group.command(name='unwatch', aliases=('uw',)) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Stop relaying messages by the given `user`.""" - await self.apply_unwatch(ctx, user, reason) - - async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: - """ - Add `user` to watched users and apply a watch infraction with `reason`. - - A message indicating the result of the operation is sent to `ctx`. - The message will include `user`'s previous watch infraction history, if it exists. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched.") - return - - response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) - - if response is not None: - self.watched_users[user.id] = response - msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - 'type': 'watch', - 'ordering': '-inserted_at' - } - ) - - if len(history) > 1: - total = f"({len(history) // 2} previous infractions in total)" - end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") - start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - else: - msg = ":x: Failed to post the infraction: response was empty." - - await ctx.send(msg) - - async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: - """ - Remove `user` from watched users and mark their infraction as inactive with `reason`. - - If `send_message` is True, a message indicating the result of the operation is sent to - `ctx`. - """ - active_watches = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - if active_watches: - log.trace("Active watches for user found. Attempting to remove.") - [infraction] = active_watches - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{infraction['id']}", - json={'active': False} - ) - - await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) - - self._remove_user(user.id) - - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"Perma-banned user {user} was unwatched.") - return - log.trace("User is not banned. Sending message to channel") - message = f":white_check_mark: Messages sent by {user} will no longer be relayed." - - else: - log.trace("No active watches found for user.") - if not send_message: # Prevents a message being sent to the channel if part of a permanent ban - log.debug(f"{user} was not on the watch list; no removal necessary.") - return - log.trace("User is not perma banned. Send the error message.") - message = ":x: The specified user is currently not being watched." - - await ctx.send(message) - - -def setup(bot: Bot) -> None: - """Load the BigBrother cog.""" - bot.add_cog(BigBrother(bot)) diff --git a/bot/cogs/moderation/watchchannels/talentpool.py b/bot/cogs/moderation/watchchannels/talentpool.py deleted file mode 100644 index 2972f56e1..000000000 --- a/bot/cogs/moderation/watchchannels/talentpool.py +++ /dev/null @@ -1,269 +0,0 @@ -import logging -import textwrap -from collections import ChainMap - -from discord import Color, Embed, Member -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks -from bot.converters import FetchedMember -from bot.decorators import with_role -from bot.pagination import LinePaginator -from bot.utils import time -from ._watchchannel import WatchChannel - -log = logging.getLogger(__name__) - - -class TalentPool(WatchChannel, Cog, name="Talentpool"): - """Relays messages of helper candidates to a watch channel to observe them.""" - - def __init__(self, bot: Bot) -> None: - super().__init__( - bot, - destination=Channels.talent_pool, - webhook_id=Webhooks.talent_pool, - api_endpoint='bot/nominations', - api_default_params={'active': 'true', 'ordering': '-inserted_at'}, - logger=log, - ) - - @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_group(self, ctx: Context) -> None: - """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" - await ctx.send_help(ctx.command) - - @nomination_group.command(name='watched', aliases=('all', 'list')) - @with_role(*MODERATION_ROLES) - async def watched_command( - self, ctx: Context, oldest_first: bool = False, update_cache: bool = True - ) -> None: - """ - Shows the users that are currently being monitored in the talent pool. - - The optional kwarg `oldest_first` can be used to order the list by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) - - @nomination_group.command(name='oldest') - @with_role(*MODERATION_ROLES) - async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: - """ - Shows talent pool monitored users ordered by oldest nomination. - - The optional kwarg `update_cache` can be used to update the user - cache using the API before listing the users. - """ - await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) - - @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) - @with_role(*STAFF_ROLES) - async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Relay messages sent by the given `user` to the `#talent-pool` channel. - - A `reason` for adding the user to the talent pool is required and will be displayed - in the header when relaying messages of this user to the channel. - """ - if user.bot: - await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") - return - - if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): - await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") - return - - if not await self.fetch_user_cache(): - await ctx.send(f":x: Failed to update the user cache; can't add {user}") - return - - if user.id in self.watched_users: - await ctx.send(f":x: {user} is already being watched in the talent pool") - return - - # Manual request with `raise_for_status` as False because we want the actual response - session = self.bot.api_client.session - url = self.bot.api_client._url_for(self.api_endpoint) - kwargs = { - 'json': { - 'actor': ctx.author.id, - 'reason': reason, - 'user': user.id - }, - 'raise_for_status': False, - } - async with session.post(url, **kwargs) as resp: - response_data = await resp.json() - - if resp.status == 400 and response_data.get('user', False): - await ctx.send(":x: The specified user can't be found in the database tables") - return - else: - resp.raise_for_status() - - self.watched_users[user.id] = response_data - msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" - - history = await self.bot.api_client.get( - self.api_endpoint, - params={ - "user__id": str(user.id), - "active": "false", - "ordering": "-inserted_at" - } - ) - - if history: - total = f"({len(history)} previous nominations in total)" - start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" - end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" - msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" - - await ctx.send(msg) - - @nomination_group.command(name='history', aliases=('info', 'search')) - @with_role(*MODERATION_ROLES) - async def history_command(self, ctx: Context, user: FetchedMember) -> None: - """Shows the specified user's nomination history.""" - result = await self.bot.api_client.get( - self.api_endpoint, - params={ - 'user__id': str(user.id), - 'ordering': "-active,-inserted_at" - } - ) - if not result: - await ctx.send(":warning: This user has never been nominated") - return - - embed = Embed( - title=f"Nominations for {user.display_name} `({user.id})`", - color=Color.blue() - ) - lines = [self._nomination_to_string(nomination) for nomination in result] - await LinePaginator.paginate( - lines, - ctx=ctx, - embed=embed, - empty=True, - max_lines=3, - max_size=1000 - ) - - @nomination_group.command(name='unwatch', aliases=('end', )) - @with_role(*MODERATION_ROLES) - async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """ - Ends the active nomination of the specified user with the given reason. - - Providing a `reason` is required. - """ - active_nomination = await self.bot.api_client.get( - self.api_endpoint, - params=ChainMap( - self.api_default_params, - {"user__id": str(user.id)} - ) - ) - - if not active_nomination: - await ctx.send(":x: The specified user does not have an active nomination") - return - - [nomination] = active_nomination - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination['id']}", - json={'end_reason': reason, 'active': False} - ) - await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") - self._remove_user(user.id) - - @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def nomination_edit_group(self, ctx: Context) -> None: - """Commands to edit nominations.""" - await ctx.send_help(ctx.command) - - @nomination_edit_group.command(name='reason') - @with_role(*MODERATION_ROLES) - async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: - """ - Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. - - If the nomination is active, the reason for nominating the user will be edited; - If the nomination is no longer active, the reason for ending the nomination will be edited instead. - """ - try: - nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") - except ResponseCodeError as e: - if e.response.status == 404: - self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") - await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") - return - else: - raise - - field = "reason" if nomination["active"] else "end_reason" - - self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") - - await self.bot.api_client.patch( - f"{self.api_endpoint}/{nomination_id}", - json={field: reason} - ) - - await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") - - def _nomination_to_string(self, nomination_object: dict) -> str: - """Creates a string representation of a nomination.""" - guild = self.bot.get_guild(Guild.id) - - actor_id = nomination_object["actor"] - actor = guild.get_member(actor_id) - - active = nomination_object["active"] - log.debug(active) - log.debug(type(nomination_object["inserted_at"])) - - start_date = time.format_infraction(nomination_object["inserted_at"]) - if active: - lines = textwrap.dedent( - f""" - =============== - Status: **Active** - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - else: - end_date = time.format_infraction(nomination_object["ended_at"]) - lines = textwrap.dedent( - f""" - =============== - Status: Inactive - Date: {start_date} - Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} - - End date: {end_date} - Unwatch reason: {nomination_object["end_reason"]} - Nomination ID: `{nomination_object["id"]}` - =============== - """ - ) - - return lines.strip() - - -def setup(bot: Bot) -> None: - """Load the TalentPool cog.""" - bot.add_cog(TalentPool(bot)) diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py deleted file mode 100644 index ce95450e0..000000000 --- a/bot/cogs/off_topic_names.py +++ /dev/null @@ -1,162 +0,0 @@ -import asyncio -import difflib -import logging -from datetime import datetime, timedelta - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES -from bot.converters import OffTopicName -from bot.decorators import with_role -from bot.pagination import LinePaginator - -CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) -log = logging.getLogger(__name__) - - -async def update_names(bot: Bot) -> None: - """Background updater task that performs the daily channel name update.""" - while True: - # Since we truncate the compute timedelta to seconds, we add one second to ensure - # we go past midnight in the `seconds_to_sleep` set below. - today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) - next_midnight = today_at_midnight + timedelta(days=1) - seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 - await asyncio.sleep(seconds_to_sleep) - - try: - channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( - 'bot/off-topic-channel-names', params={'random_items': 3} - ) - except ResponseCodeError as e: - log.error(f"Failed to get new off topic channel names: code {e.response.status}") - continue - channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) - - await channel_0.edit(name=f'ot0-{channel_0_name}') - await channel_1.edit(name=f'ot1-{channel_1_name}') - await channel_2.edit(name=f'ot2-{channel_2_name}') - log.debug( - "Updated off-topic channel names to" - f" {channel_0_name}, {channel_1_name} and {channel_2_name}" - ) - - -class OffTopicNames(Cog): - """Commands related to managing the off-topic category channel names.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.updater_task = None - - self.bot.loop.create_task(self.init_offtopic_updater()) - - def cog_unload(self) -> None: - """Cancel any running updater tasks on cog unload.""" - if self.updater_task is not None: - self.updater_task.cancel() - - async def init_offtopic_updater(self) -> None: - """Start off-topic channel updating event loop if it hasn't already started.""" - await self.bot.wait_until_guild_available() - if self.updater_task is None: - coro = update_names(self.bot) - self.updater_task = self.bot.loop.create_task(coro) - - @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def otname_group(self, ctx: Context) -> None: - """Add or list items from the off-topic channel name rotation.""" - await ctx.send_help(ctx.command) - - @otname_group.command(name='add', aliases=('a',)) - @with_role(*MODERATION_ROLES) - async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """ - Adds a new off-topic name to the rotation. - - The name is not added if it is too similar to an existing name. - """ - existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') - close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) - - if close_match: - match = close_match[0] - log.info( - f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" - ) - await ctx.send( - f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " - "Use `!otn forceadd` to override this check." - ) - else: - await self._add_name(ctx, name) - - @otname_group.command(name='forceadd', aliases=('fa',)) - @with_role(*MODERATION_ROLES) - async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Forcefully adds a new off-topic name to the rotation.""" - await self._add_name(ctx, name) - - async def _add_name(self, ctx: Context, name: str) -> None: - """Adds an off-topic channel name to the site storage.""" - await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) - - log.info(f"{ctx.author} added the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Added `{name}` to the names list.") - - @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Removes a off-topic name from the rotation.""" - await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') - - log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Removed `{name}` from the names list.") - - @otname_group.command(name='list', aliases=('l',)) - @with_role(*MODERATION_ROLES) - async def list_command(self, ctx: Context) -> None: - """ - Lists all currently known off-topic channel names in a paginator. - - Restricted to Moderator and above to not spoil the surprise. - """ - result = await self.bot.api_client.get('bot/off-topic-channel-names') - lines = sorted(f"• {name}" for name in result) - embed = Embed( - title=f"Known off-topic names (`{len(result)}` total)", - colour=Colour.blue() - ) - if result: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - - @otname_group.command(name='search', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: - """Search for an off-topic name.""" - result = await self.bot.api_client.get('bot/off-topic-channel-names') - in_matches = {name for name in result if query in name} - close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) - lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) - embed = Embed( - title="Query results", - colour=Colour.blue() - ) - - if lines: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Nothing found." - await ctx.send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the OffTopicNames cog.""" - bot.add_cog(OffTopicNames(bot)) diff --git a/bot/cogs/utils/__init__.py b/bot/cogs/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bot/cogs/utils/bot.py b/bot/cogs/utils/bot.py deleted file mode 100644 index 71ed54f60..000000000 --- a/bot/cogs/utils/bot.py +++ /dev/null @@ -1,385 +0,0 @@ -import ast -import logging -import re -import time -from typing import Optional, Tuple - -from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Cog, Context, command, group - -from bot.bot import Bot -from bot.cogs.filters.token_remover import TokenRemover -from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.decorators import with_role -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -RE_MARKDOWN = re.compile(r'([*_~`|>])') - - -class BotCog(Cog, name="Bot"): - """Bot information commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - # Stores allowed channels plus epoch time since last call. - self.channel_cooldowns = { - Channels.python_discussion: 0, - } - - # These channels will also work, but will not be subject to cooldown - self.channel_whitelist = ( - Channels.bot_commands, - ) - - # Stores improperly formatted Python codeblock message ids and the corresponding bot message - self.codeblock_message_ids = {} - - @group(invoke_without_command=True, name="bot", hidden=True) - @with_role(Roles.verified) - async def botinfo_group(self, ctx: Context) -> None: - """Bot informational commands.""" - await ctx.send_help(ctx.command) - - @botinfo_group.command(name='about', aliases=('info',), hidden=True) - @with_role(Roles.verified) - async def about_command(self, ctx: Context) -> None: - """Get information about the bot.""" - embed = Embed( - description="A utility bot designed just for the Python server! Try `!help` for more info.", - url="https://github.com/python-discord/bot" - ) - - embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) - embed.set_author( - name="Python Bot", - url="https://github.com/python-discord/bot", - icon_url=URLs.bot_avatar - ) - - await ctx.send(embed=embed) - - @command(name='echo', aliases=('print',)) - @with_role(*MODERATION_ROLES) - async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Repeat the given message in either a specified channel or the current channel.""" - if channel is None: - await ctx.send(text) - else: - await channel.send(text) - - @command(name='embed') - @with_role(*MODERATION_ROLES) - async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: - """Send the input within an embed to either a specified channel or the current channel.""" - embed = Embed(description=text) - - if channel is None: - await ctx.send(embed=embed) - else: - await channel.send(embed=embed) - - def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: - """ - Strip msg in order to find Python code. - - Tries to strip out Python code out of msg and returns the stripped block or - None if the block is a valid Python codeblock. - """ - if msg.count("\n") >= 3: - # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. - if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: - log.trace( - "Someone wrote a message that was already a " - "valid Python syntax highlighted code block. No action taken." - ) - return None - - else: - # Stripping backticks from every line of the message. - log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") - content = "" - for line in msg.splitlines(keepends=True): - content += line.strip("`") - - content = content.strip() - - # Remove "Python" or "Py" from start of the message if it exists. - log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") - pycode = False - if content.lower().startswith("python"): - content = content[6:] - pycode = True - elif content.lower().startswith("py"): - content = content[2:] - pycode = True - - if pycode: - content = content.splitlines(keepends=True) - - # Check if there might be code in the first line, and preserve it. - first_line = content[0] - if " " in content[0]: - first_space = first_line.index(" ") - content[0] = first_line[first_space:] - content = "".join(content) - - # If there's no code we can just get rid of the first line. - else: - content = "".join(content[1:]) - - # Strip it again to remove any leading whitespace. This is neccessary - # if the first line of the message looked like ```python - old = content.strip() - - # Strips REPL code out of the message if there is any. - content, repl_code = self.repl_stripping(old) - if old != content: - return (content, old), repl_code - - # Try to apply indentation fixes to the code. - content = self.fix_indentation(content) - - # Check if the code contains backticks, if it does ignore the message. - if "`" in content: - log.trace("Detected ` inside the code, won't reply") - return None - else: - log.trace(f"Returning message.\n\n{content}\n\n") - return (content,), repl_code - - def fix_indentation(self, msg: str) -> str: - """Attempts to fix badly indented code.""" - def unindent(code: str, skip_spaces: int = 0) -> str: - """Unindents all code down to the number of spaces given in skip_spaces.""" - final = "" - current = code[0] - leading_spaces = 0 - - # Get numbers of spaces before code in the first line. - while current == " ": - current = code[leading_spaces + 1] - leading_spaces += 1 - leading_spaces -= skip_spaces - - # If there are any, remove that number of spaces from every line. - if leading_spaces > 0: - for line in code.splitlines(keepends=True): - line = line[leading_spaces:] - final += line - return final - else: - return code - - # Apply fix for "all lines are overindented" case. - msg = unindent(msg) - - # If the first line does not end with a colon, we can be - # certain the next line will be on the same indentation level. - # - # If it does end with a colon, we will need to indent all successive - # lines one additional level. - first_line = msg.splitlines()[0] - code = "".join(msg.splitlines(keepends=True)[1:]) - if not first_line.endswith(":"): - msg = f"{first_line}\n{unindent(code)}" - else: - msg = f"{first_line}\n{unindent(code, 4)}" - return msg - - def repl_stripping(self, msg: str) -> Tuple[str, bool]: - """ - Strip msg in order to extract Python code out of REPL output. - - Tries to strip out REPL Python code out of msg and returns the stripped msg. - - Returns True for the boolean if REPL code was found in the input msg. - """ - final = "" - for line in msg.splitlines(keepends=True): - if line.startswith(">>>") or line.startswith("..."): - final += line[4:] - log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") - if not final: - log.trace(f"Found no REPL code in \n\n{msg}\n\n") - return msg, False - else: - log.trace(f"Found REPL code in \n\n{msg}\n\n") - return final.rstrip(), True - - def has_bad_ticks(self, msg: Message) -> bool: - """Check to see if msg contains ticks that aren't '`'.""" - not_backticks = [ - "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", - "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", - "\u3003\u3003\u3003" - ] - - return msg.content[:3] in not_backticks - - @Cog.listener() - async def on_message(self, msg: Message) -> None: - """ - Detect poorly formatted Python code in new messages. - - If poorly formatted code is detected, send the user a helpful message explaining how to do - properly formatted Python syntax highlighting codeblocks. - """ - is_help_channel = ( - getattr(msg.channel, "category", None) - and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) - ) - parse_codeblock = ( - ( - is_help_channel - or msg.channel.id in self.channel_cooldowns - or msg.channel.id in self.channel_whitelist - ) - and not msg.author.bot - and len(msg.content.splitlines()) > 3 - and not TokenRemover.find_token_in_message(msg) - ) - - if parse_codeblock: # no token in the msg - on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 - if not on_cooldown or DEBUG_MODE: - try: - if self.has_bad_ticks(msg): - ticks = msg.content[:3] - content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) - if content is None: - return - - content, repl_code = content - - if len(content) == 2: - content = content[1] - else: - content = content[0] - - space_left = 204 - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto = ( - "It looks like you are trying to paste code into this channel.\n\n" - "You seem to be using the wrong symbols to indicate where the codeblock should start. " - f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" - "**Here is an example of how it should look:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - else: - howto = "" - content = self.codeblock_stripping(msg.content, False) - if content is None: - return - - content, repl_code = content - # Attempts to parse the message into an AST node. - # Invalid Python code will raise a SyntaxError. - tree = ast.parse(content[0]) - - # Multiple lines of single words could be interpreted as expressions. - # This check is to avoid all nodes being parsed as expressions. - # (e.g. words over multiple lines) - if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: - # Shorten the code to 10 lines and/or 204 characters. - space_left = 204 - if content and repl_code: - content = content[1] - else: - content = content[0] - - if len(content) >= space_left: - current_length = 0 - lines_walked = 0 - for line in content.splitlines(keepends=True): - if current_length + len(line) > space_left or lines_walked == 10: - break - current_length += len(line) - lines_walked += 1 - content = content[:current_length] + "#..." - - content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) - howto += ( - "It looks like you're trying to paste code into this channel.\n\n" - "Discord has support for Markdown, which allows you to post code with full " - "syntax highlighting. Please use these whenever you paste code, as this " - "helps improve the legibility and makes it easier for us to help you.\n\n" - f"**To do this, use the following method:**\n" - f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" - "**This will result in the following:**\n" - f"```python\n{content}\n```" - ) - - log.debug(f"{msg.author} posted something that needed to be put inside python code " - "blocks. Sending the user some instructions.") - else: - log.trace("The code consists only of expressions, not sending instructions") - - if howto != "": - # Increase amount of codeblock correction in stats - self.bot.stats.incr("codeblock_corrections") - howto_embed = Embed(description=howto) - bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) - self.codeblock_message_ids[msg.id] = bot_message.id - - self.bot.loop.create_task( - wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) - ) - else: - return - - if msg.channel.id not in self.channel_whitelist: - self.channel_cooldowns[msg.channel.id] = time.time() - - except SyntaxError: - log.trace( - f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " - "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " - f"The message that was posted was:\n\n{msg.content}\n\n" - ) - - @Cog.listener() - async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: - """Check to see if an edited message (previously called out) still contains poorly formatted code.""" - if ( - # Checks to see if the message was called out by the bot - payload.message_id not in self.codeblock_message_ids - # Makes sure that there is content in the message - or payload.data.get("content") is None - # Makes sure there's a channel id in the message payload - or payload.data.get("channel_id") is None - ): - return - - # Retrieve channel and message objects for use later - channel = self.bot.get_channel(int(payload.data.get("channel_id"))) - user_message = await channel.fetch_message(payload.message_id) - - # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None - has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) - - # If the message is fixed, delete the bot message and the entry from the id dictionary - if has_fixed_codeblock is None: - bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - await bot_message.delete() - del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") - - -def setup(bot: Bot) -> None: - """Load the Bot cog.""" - bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/utils/clean.py b/bot/cogs/utils/clean.py deleted file mode 100644 index c156ff02e..000000000 --- a/bot/cogs/utils/clean.py +++ /dev/null @@ -1,272 +0,0 @@ -import logging -import random -import re -from typing import Iterable, Optional - -from discord import Colour, Embed, Message, TextChannel, User -from discord.ext import commands -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.cogs.moderation.modlog import ModLog -from bot.constants import ( - Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES -) -from bot.decorators import with_role - -log = logging.getLogger(__name__) - - -class Clean(Cog): - """ - A cog that allows messages to be deleted in bulk, while applying various filters. - - You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a - specific regular expression. - - The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be - used to view the messages in the Discord dark theme style. - """ - - def __init__(self, bot: Bot): - self.bot = bot - self.cleaning = False - - @property - def mod_log(self) -> ModLog: - """Get currently loaded ModLog cog instance.""" - return self.bot.get_cog("ModLog") - - async def _clean_messages( - self, - amount: int, - ctx: Context, - channels: Iterable[TextChannel], - bots_only: bool = False, - user: User = None, - regex: Optional[str] = None, - until_message: Optional[Message] = None, - ) -> None: - """A helper function that does the actual message cleaning.""" - def predicate_bots_only(message: Message) -> bool: - """Return True if the message was sent by a bot.""" - return message.author.bot - - def predicate_specific_user(message: Message) -> bool: - """Return True if the message was sent by the user provided in the _clean_messages call.""" - return message.author == user - - def predicate_regex(message: Message) -> bool: - """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" - content = [message.content] - - # Add the content for all embed attributes - for embed in message.embeds: - content.append(embed.title) - content.append(embed.description) - content.append(embed.footer.text) - content.append(embed.author.name) - for field in embed.fields: - content.append(field.name) - content.append(field.value) - - # Get rid of empty attributes and turn it into a string - content = [attr for attr in content if attr] - content = "\n".join(content) - - # Now let's see if there's a regex match - if not content: - return False - else: - return bool(re.search(regex.lower(), content.lower())) - - # Is this an acceptable amount of messages to clean? - if amount > CleanMessages.message_limit: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description=f"You cannot clean more than {CleanMessages.message_limit} messages." - ) - await ctx.send(embed=embed) - return - - # Are we already performing a clean? - if self.cleaning: - embed = Embed( - color=Colour(Colours.soft_red), - title=random.choice(NEGATIVE_REPLIES), - description="Please wait for the currently ongoing clean operation to complete." - ) - await ctx.send(embed=embed) - return - - # Set up the correct predicate - if bots_only: - predicate = predicate_bots_only # Delete messages from bots - elif user: - predicate = predicate_specific_user # Delete messages from specific user - elif regex: - predicate = predicate_regex # Delete messages that match regex - else: - predicate = None # Delete all messages - - # Default to using the invoking context's channel - if not channels: - channels = [ctx.channel] - - # Delete the invocation first - self.mod_log.ignore(Event.message_delete, ctx.message.id) - await ctx.message.delete() - - messages = [] - message_ids = [] - self.cleaning = True - - # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. - for channel in channels: - async for message in channel.history(limit=amount): - - # If at any point the cancel command is invoked, we should stop. - if not self.cleaning: - return - - # If we are looking for specific message. - if until_message: - - # we could use ID's here however in case if the message we are looking for gets deleted, - # we won't have a way to figure that out thus checking for datetime should be more reliable - if message.created_at < until_message.created_at: - # means we have found the message until which we were supposed to be deleting. - break - - # Since we will be using `delete_messages` method of a TextChannel and we need message objects to - # use it as well as to send logs we will start appending messages here instead adding them from - # purge. - messages.append(message) - - # If the message passes predicate, let's save it. - if predicate is None or predicate(message): - message_ids.append(message.id) - - self.cleaning = False - - # Now let's delete the actual messages with purge. - self.mod_log.ignore(Event.message_delete, *message_ids) - for channel in channels: - if until_message: - for i in range(0, len(messages), 100): - # while purge automatically handles the amount of messages - # delete_messages only allows for up to 100 messages at once - # thus we need to paginate the amount to always be <= 100 - await channel.delete_messages(messages[i:i + 100]) - else: - messages += await channel.purge(limit=amount, check=predicate) - - # Reverse the list to restore chronological order - if messages: - messages = reversed(messages) - log_url = await self.mod_log.upload_log(messages, ctx.author.id) - else: - # Can't build an embed, nothing to clean! - embed = Embed( - color=Colour(Colours.soft_red), - description="No matching messages could be found." - ) - await ctx.send(embed=embed, delete_after=10) - return - - # Build the embed and send it - target_channels = ", ".join(channel.mention for channel in channels) - - message = ( - f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" - f"A log of the deleted messages can be found [here]({log_url})." - ) - - await self.mod_log.send_log_message( - icon_url=Icons.message_bulk_delete, - colour=Colour(Colours.soft_red), - title="Bulk message delete", - text=message, - channel_id=Channels.mod_log, - ) - - @group(invoke_without_command=True, name="clean", aliases=["purge"]) - @with_role(*MODERATION_ROLES) - async def clean_group(self, ctx: Context) -> None: - """Commands for cleaning messages in channels.""" - await ctx.send_help(ctx.command) - - @clean_group.command(name="user", aliases=["users"]) - @with_role(*MODERATION_ROLES) - async def clean_user( - self, - ctx: Context, - user: User, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user, channels=channels) - - @clean_group.command(name="all", aliases=["everything"]) - @with_role(*MODERATION_ROLES) - async def clean_all( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, channels=channels) - - @clean_group.command(name="bots", aliases=["bot"]) - @with_role(*MODERATION_ROLES) - async def clean_bots( - self, - ctx: Context, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True, channels=channels) - - @clean_group.command(name="regex", aliases=["word", "expression"]) - @with_role(*MODERATION_ROLES) - async def clean_regex( - self, - ctx: Context, - regex: str, - amount: Optional[int] = 10, - channels: commands.Greedy[TextChannel] = None - ) -> None: - """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex, channels=channels) - - @clean_group.command(name="message", aliases=["messages"]) - @with_role(*MODERATION_ROLES) - async def clean_message(self, ctx: Context, message: Message) -> None: - """Delete all messages until certain message, stop cleaning after hitting the `message`.""" - await self._clean_messages( - CleanMessages.message_limit, - ctx, - channels=[message.channel], - until_message=message - ) - - @clean_group.command(name="stop", aliases=["cancel", "abort"]) - @with_role(*MODERATION_ROLES) - async def clean_cancel(self, ctx: Context) -> None: - """If there is an ongoing cleaning process, attempt to immediately cancel it.""" - self.cleaning = False - - embed = Embed( - color=Colour.blurple(), - description="Clean interrupted." - ) - await ctx.send(embed=embed, delete_after=10) - - -def setup(bot: Bot) -> None: - """Load the Clean cog.""" - bot.add_cog(Clean(bot)) diff --git a/bot/cogs/utils/eval.py b/bot/cogs/utils/eval.py deleted file mode 100644 index eb8bfb1cf..000000000 --- a/bot/cogs/utils/eval.py +++ /dev/null @@ -1,202 +0,0 @@ -import contextlib -import inspect -import logging -import pprint -import re -import textwrap -import traceback -from io import StringIO -from typing import Any, Optional, Tuple - -import discord -from discord.ext.commands import Cog, Context, group - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role -from bot.interpreter import Interpreter - -log = logging.getLogger(__name__) - - -class CodeEval(Cog): - """Owner and admin feature that evaluates code and returns the result to the channel.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.env = {} - self.ln = 0 - self.stdout = StringIO() - - self.interpreter = Interpreter(bot) - - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: - """Format the eval output into a string & attempt to format it into an Embed.""" - self._ = out - - res = "" - - # Erase temp input we made - if inp.startswith("_ = "): - inp = inp[4:] - - # Get all non-empty lines - lines = [line for line in inp.split("\n") if line.strip()] - if len(lines) != 1: - lines += [""] - - # Create the input dialog - for i, line in enumerate(lines): - if i == 0: - # Start dialog - start = f"In [{self.ln}]: " - - else: - # Indent the 3 dots correctly; - # Normally, it's something like - # In [X]: - # ...: - # - # But if it's - # In [XX]: - # ...: - # - # You can see it doesn't look right. - # This code simply indents the dots - # far enough to align them. - # we first `str()` the line number - # then we get the length - # and use `str.rjust()` - # to indent it. - start = "...: ".rjust(len(str(self.ln)) + 7) - - if i == len(lines) - 2: - if line.startswith("return"): - line = line[6:].strip() - - # Combine everything - res += (start + line + "\n") - - self.stdout.seek(0) - text = self.stdout.read() - self.stdout.close() - self.stdout = StringIO() - - if text: - res += (text + "\n") - - if out is None: - # No output, return the input statement - return (res, None) - - res += f"Out[{self.ln}]: " - - if isinstance(out, discord.Embed): - # We made an embed? Send that as embed - res += "" - res = (res, out) - - else: - if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): - # Leave out the traceback message - out = "\n" + "\n".join(out.split("\n")[1:]) - - if isinstance(out, str): - pretty = out - else: - pretty = pprint.pformat(out, compact=True, width=60) - - if pretty != str(out): - # We're using the pretty version, start on the next line - res += "\n" - - if pretty.count("\n") > 20: - # Text too long, shorten - li = pretty.split("\n") - - pretty = ("\n".join(li[:3]) # First 3 lines - + "\n ...\n" # Ellipsis to indicate removed lines - + "\n".join(li[-3:])) # last 3 lines - - # Add the output - res += pretty - res = (res, None) - - return res # Return (text, embed) - - async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: - """Eval the input code string & send an embed to the invoking context.""" - self.ln += 1 - - if code.startswith("exit"): - self.ln = 0 - self.env = {} - return await ctx.send("```Reset history!```") - - env = { - "message": ctx.message, - "author": ctx.message.author, - "channel": ctx.channel, - "guild": ctx.guild, - "ctx": ctx, - "self": self, - "bot": self.bot, - "inspect": inspect, - "discord": discord, - "contextlib": contextlib - } - - self.env.update(env) - - # Ignore this code, it works - code_ = """ -async def func(): # (None,) -> Any - try: - with contextlib.redirect_stdout(self.stdout): -{0} - if '_' in locals(): - if inspect.isawaitable(_): - _ = await _ - return _ - finally: - self.env.update(locals()) -""".format(textwrap.indent(code, ' ')) - - try: - exec(code_, self.env) # noqa: B102,S102 - func = self.env['func'] - res = await func() - - except Exception: - res = traceback.format_exc() - - out, embed = self._format(code, res) - await ctx.send(f"```py\n{out}```", embed=embed) - - @group(name='internal', aliases=('int',)) - @with_role(Roles.owners, Roles.admins) - async def internal_group(self, ctx: Context) -> None: - """Internal commands. Top secret!""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @internal_group.command(name='eval', aliases=('e',)) - @with_role(Roles.admins, Roles.owners) - async def eval(self, ctx: Context, *, code: str) -> None: - """Run eval in a REPL-like format.""" - code = code.strip("`") - if re.match('py(thon)?\n', code): - code = "\n".join(code.split("\n")[1:]) - - if not re.search( # Check if it's an expression - r"^(return|import|for|while|def|class|" - r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( - code.split("\n")) == 1: - code = "_ = " + code - - await self._eval(ctx, code) - - -def setup(bot: Bot) -> None: - """Load the CodeEval cog.""" - bot.add_cog(CodeEval(bot)) diff --git a/bot/cogs/utils/extensions.py b/bot/cogs/utils/extensions.py deleted file mode 100644 index 2cde07035..000000000 --- a/bot/cogs/utils/extensions.py +++ /dev/null @@ -1,289 +0,0 @@ -import functools -import importlib -import inspect -import logging -import pkgutil -import typing as t -from enum import Enum - -from discord import Colour, Embed -from discord.ext import commands -from discord.ext.commands import Context, group - -from bot import cogs -from bot.bot import Bot -from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -def walk_extensions() -> t.Iterator[str]: - """Yield extension names from the bot.cogs subpackage.""" - - def on_error(name: str) -> t.NoReturn: - raise ImportError(name=name) # pragma: no cover - - for module in pkgutil.walk_packages(cogs.__path__, f"{cogs.__name__}.", onerror=on_error): - if module.name.rsplit(".", maxsplit=1)[-1].startswith("_"): - # Ignore module/package names starting with an underscore. - continue - - if module.ispkg: - imported = importlib.import_module(module.name) - if not inspect.isfunction(getattr(imported, "setup", None)): - # If it lacks a setup function, it's not an extension. - continue - - yield module.name - - -UNLOAD_BLACKLIST = {f"{cogs.__name__}.utils.extensions", f"{cogs.__name__}.moderation.modlog"} -EXTENSIONS = frozenset(walk_extensions()) -COG_PATH_LEN = len(cogs.__name__.split(".")) - - -class Action(Enum): - """Represents an action to perform on an extension.""" - - # Need to be partial otherwise they are considered to be function definitions. - LOAD = functools.partial(Bot.load_extension) - UNLOAD = functools.partial(Bot.unload_extension) - RELOAD = functools.partial(Bot.reload_extension) - - -class Extension(commands.Converter): - """ - Fully qualify the name of an extension and ensure it exists. - - The * and ** values bypass this when used with the reload command. - """ - - async def convert(self, ctx: Context, argument: str) -> str: - """Fully qualify the name of an extension and ensure it exists.""" - # Special values to reload all extensions - if argument == "*" or argument == "**": - return argument - - argument = argument.lower() - - if argument in EXTENSIONS: - return argument - elif (qualified_arg := f"{cogs.__name__}.{argument}") in EXTENSIONS: - return qualified_arg - - matches = [] - for ext in EXTENSIONS: - name = ext.rsplit(".", maxsplit=1)[-1] - if argument == name: - matches.append(ext) - - if len(matches) > 1: - matches.sort() - names = "\n".join(matches) - raise commands.BadArgument( - f":x: `{argument}` is an ambiguous extension name. " - f"Please use one of the following fully-qualified names.```\n{names}```" - ) - elif matches: - return matches[0] - else: - raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - - -class Extensions(commands.Cog): - """Extension management commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) - async def extensions_group(self, ctx: Context) -> None: - """Load, unload, reload, and list loaded extensions.""" - await ctx.send_help(ctx.command) - - @extensions_group.command(name="load", aliases=("l",)) - async def load_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Load extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "*" in extensions or "**" in extensions: - extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) - - msg = self.batch_manage(Action.LOAD, *extensions) - await ctx.send(msg) - - @extensions_group.command(name="unload", aliases=("ul",)) - async def unload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Unload currently loaded extensions given their fully qualified or unqualified names. - - If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) - - if blacklisted: - msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" - else: - if "*" in extensions or "**" in extensions: - extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST - - msg = self.batch_manage(Action.UNLOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="reload", aliases=("r",)) - async def reload_command(self, ctx: Context, *extensions: Extension) -> None: - r""" - Reload extensions given their fully qualified or unqualified names. - - If an extension fails to be reloaded, it will be rolled-back to the prior working state. - - If '\*' is given as the name, all currently loaded extensions will be reloaded. - If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. - """ # noqa: W605 - if not extensions: - await ctx.send_help(ctx.command) - return - - if "**" in extensions: - extensions = EXTENSIONS - elif "*" in extensions: - extensions = set(self.bot.extensions.keys()) | set(extensions) - extensions.remove("*") - - msg = self.batch_manage(Action.RELOAD, *extensions) - - await ctx.send(msg) - - @extensions_group.command(name="list", aliases=("all",)) - async def list_command(self, ctx: Context) -> None: - """ - Get a list of all extensions, including their loaded status. - - Grey indicates that the extension is unloaded. - Green indicates that the extension is currently loaded. - """ - embed = Embed(colour=Colour.blurple()) - embed.set_author( - name="Extensions List", - url=URLs.github_bot_repo, - icon_url=URLs.bot_avatar - ) - - lines = [] - categories = self.group_extension_statuses() - for category, extensions in sorted(categories.items()): - # Treat each category as a single line by concatenating everything. - # This ensures the paginator will not cut off a page in the middle of a category. - category = category.replace("_", " ").title() - extensions = "\n".join(sorted(extensions)) - lines.append(f"**{category}**\n{extensions}\n") - - log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") - await LinePaginator.paginate(lines, ctx, embed, scale_to_size=700, empty=False) - - def group_extension_statuses(self) -> t.Mapping[str, str]: - """Return a mapping of extension names and statuses to their categories.""" - categories = {} - - for ext in EXTENSIONS: - if ext in self.bot.extensions: - status = Emojis.status_online - else: - status = Emojis.status_offline - - path = ext.split(".") - if len(path) > COG_PATH_LEN + 1: - category = " - ".join(path[COG_PATH_LEN:-1]) - else: - category = "uncategorised" - - categories.setdefault(category, []).append(f"{status} {path[-1]}") - - return categories - - def batch_manage(self, action: Action, *extensions: str) -> str: - """ - Apply an action to multiple extensions and return a message with the results. - - If only one extension is given, it is deferred to `manage()`. - """ - if len(extensions) == 1: - msg, _ = self.manage(action, extensions[0]) - return msg - - verb = action.name.lower() - failures = {} - - for extension in extensions: - _, error = self.manage(action, extension) - if error: - failures[extension] = error - - emoji = ":x:" if failures else ":ok_hand:" - msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." - - if failures: - failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) - msg += f"\nFailures:```{failures}```" - - log.debug(f"Batch {verb}ed extensions.") - - return msg - - def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: - """Apply an action to an extension and return the status message and any error message.""" - verb = action.name.lower() - error_msg = None - - try: - action.value(self.bot, ext) - except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): - if action is Action.RELOAD: - # When reloading, just load the extension if it was not loaded. - return self.manage(Action.LOAD, ext) - - msg = f":x: Extension `{ext}` is already {verb}ed." - log.debug(msg[4:]) - except Exception as e: - if hasattr(e, "original"): - e = e.original - - log.exception(f"Extension '{ext}' failed to {verb}.") - - error_msg = f"{e.__class__.__name__}: {e}" - msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" - else: - msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." - log.debug(msg[10:]) - - return msg, error_msg - - # This cannot be static (must have a __func__ attribute). - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators and core developers to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) - - # This cannot be static (must have a __func__ attribute). - async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Handle BadArgument errors locally to prevent the help command from showing.""" - if isinstance(error, commands.BadArgument): - await ctx.send(str(error)) - error.handled = True - - -def setup(bot: Bot) -> None: - """Load the Extensions cog.""" - bot.add_cog(Extensions(bot)) diff --git a/bot/cogs/utils/jams.py b/bot/cogs/utils/jams.py deleted file mode 100644 index b3102db2f..000000000 --- a/bot/cogs/utils/jams.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import typing as t - -from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role -from discord.ext import commands -from more_itertools import unique_everseen - -from bot.bot import Bot -from bot.constants import Roles -from bot.decorators import with_role - -log = logging.getLogger(__name__) - -MAX_CHANNELS = 50 -CATEGORY_NAME = "Code Jam" - - -class CodeJams(commands.Cog): - """Manages the code-jam related parts of our server.""" - - def __init__(self, bot: Bot): - self.bot = bot - - @commands.command() - @with_role(Roles.admins) - async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: - """ - Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. - - The first user passed will always be the team leader. - """ - # Ignore duplicate members - members = list(unique_everseen(members)) - - # We had a little issue during Code Jam 4 here, the greedy converter did it's job - # and ignored anything which wasn't a valid argument which left us with teams of - # two members or at some times even 1 member. This fixes that by checking that there - # are always 3 members in the members list. - if len(members) < 3: - await ctx.send( - ":no_entry_sign: One of your arguments was invalid\n" - f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" - " members" - ) - return - - team_channel = await self.create_channels(ctx.guild, team_name, members) - await self.add_roles(ctx.guild, members) - - await ctx.send( - f":ok_hand: Team created: {team_channel}\n" - f"**Team Leader:** {members[0].mention}\n" - f"**Team Members:** {' '.join(member.mention for member in members[1:])}" - ) - - async def get_category(self, guild: Guild) -> CategoryChannel: - """ - Return a code jam category. - - If all categories are full or none exist, create a new category. - """ - for category in guild.categories: - # Need 2 available spaces: one for the text channel and one for voice. - if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: - return category - - return await self.create_category(guild) - - @staticmethod - async def create_category(guild: Guild) -> CategoryChannel: - """Create a new code jam category and return it.""" - log.info("Creating a new code jam category.") - - category_overwrites = { - guild.default_role: PermissionOverwrite(read_messages=False), - guild.me: PermissionOverwrite(read_messages=True) - } - - return await guild.create_category_channel( - CATEGORY_NAME, - overwrites=category_overwrites, - reason="It's code jam time!" - ) - - @staticmethod - def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: - """Get code jam team channels permission overwrites.""" - # First member is always the team leader - team_channel_overwrites = { - members[0]: PermissionOverwrite( - manage_messages=True, - read_messages=True, - manage_webhooks=True, - connect=True - ), - guild.default_role: PermissionOverwrite(read_messages=False, connect=False), - guild.get_role(Roles.verified): PermissionOverwrite( - read_messages=False, - connect=False - ) - } - - # Rest of members should just have read_messages - for member in members[1:]: - team_channel_overwrites[member] = PermissionOverwrite( - read_messages=True, - connect=True - ) - - return team_channel_overwrites - - async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: - """Create team text and voice channels. Return the mention for the text channel.""" - # Get permission overwrites and category - team_channel_overwrites = self.get_overwrites(members, guild) - code_jam_category = await self.get_category(guild) - - # Create a text channel for the team - team_channel = await guild.create_text_channel( - team_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - # Create a voice channel for the team - team_voice_name = " ".join(team_name.split("-")).title() - - await guild.create_voice_channel( - team_voice_name, - overwrites=team_channel_overwrites, - category=code_jam_category - ) - - return team_channel.mention - - @staticmethod - async def add_roles(guild: Guild, members: t.List[Member]) -> None: - """Assign team leader and jammer roles.""" - # Assign team leader role - await members[0].add_roles(guild.get_role(Roles.team_leaders)) - - # Assign rest of roles - jammer_role = guild.get_role(Roles.jammers) - for member in members: - await member.add_roles(jammer_role) - - -def setup(bot: Bot) -> None: - """Load the CodeJams cog.""" - bot.add_cog(CodeJams(bot)) diff --git a/bot/cogs/utils/reminders.py b/bot/cogs/utils/reminders.py deleted file mode 100644 index 670493bcf..000000000 --- a/bot/cogs/utils/reminders.py +++ /dev/null @@ -1,427 +0,0 @@ -import asyncio -import logging -import random -import textwrap -import typing as t -from datetime import datetime, timedelta -from operator import itemgetter - -import discord -from dateutil.parser import isoparse -from dateutil.relativedelta import relativedelta -from discord.ext.commands import Cog, Context, Greedy, group - -from bot.bot import Bot -from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES -from bot.converters import Duration -from bot.pagination import LinePaginator -from bot.utils.checks import without_role_check -from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler -from bot.utils.time import humanize_delta - -log = logging.getLogger(__name__) - -WHITELISTED_CHANNELS = Guild.reminder_whitelist -MAXIMUM_REMINDERS = 5 - -Mentionable = t.Union[discord.Member, discord.Role] - - -class Reminders(Cog): - """Provide in-channel reminder functionality.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.scheduler = Scheduler(self.__class__.__name__) - - self.bot.loop.create_task(self.reschedule_reminders()) - - def cog_unload(self) -> None: - """Cancel scheduled tasks.""" - self.scheduler.cancel_all() - - async def reschedule_reminders(self) -> None: - """Get all current reminders from the API and reschedule them.""" - await self.bot.wait_until_guild_available() - response = await self.bot.api_client.get( - 'bot/reminders', - params={'active': 'true'} - ) - - now = datetime.utcnow() - - for reminder in response: - is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) - if not is_valid: - continue - - remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) - - # If the reminder is already overdue ... - if remind_at < now: - late = relativedelta(now, remind_at) - await self.send_reminder(reminder, late) - else: - self.schedule_reminder(reminder) - - def ensure_valid_reminder( - self, - reminder: dict, - cancel_task: bool = True - ) -> t.Tuple[bool, discord.User, discord.TextChannel]: - """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" - user = self.bot.get_user(reminder['author']) - channel = self.bot.get_channel(reminder['channel_id']) - is_valid = True - if not user or not channel: - is_valid = False - log.info( - f"Reminder {reminder['id']} invalid: " - f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." - ) - asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) - - return is_valid, user, channel - - @staticmethod - async def _send_confirmation( - ctx: Context, - on_success: str, - reminder_id: str, - delivery_dt: t.Optional[datetime], - ) -> None: - """Send an embed confirming the reminder change was made successfully.""" - embed = discord.Embed() - embed.colour = discord.Colour.green() - embed.title = random.choice(POSITIVE_REPLIES) - embed.description = on_success - - footer_str = f"ID: {reminder_id}" - if delivery_dt: - # Reminder deletion will have a `None` `delivery_dt` - footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" - - embed.set_footer(text=footer_str) - - await ctx.send(embed=embed) - - @staticmethod - async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: - """ - Returns whether or not the list of mentions is allowed. - - Conditions: - - Role reminders are Mods+ - - Reminders for other users are Helpers+ - - If mentions aren't allowed, also return the type of mention(s) disallowed. - """ - if without_role_check(ctx, *STAFF_ROLES): - return False, "members/roles" - elif without_role_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, discord.Member) for mention in mentions), "roles" - else: - return True, "" - - @staticmethod - async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: - """ - Filter mentions to see if the user can mention, and sends a denial if not allowed. - - Returns whether or not the validation is successful. - """ - mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) - - if not mentions or mentions_allowed: - return True - else: - await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") - return False - - def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: - """Converts Role and Member ids to their corresponding objects if possible.""" - guild = self.bot.get_guild(Guild.id) - for mention_id in mention_ids: - if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): - yield mentionable - - def schedule_reminder(self, reminder: dict) -> None: - """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" - reminder_id = reminder["id"] - reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) - - async def _remind() -> None: - await self.send_reminder(reminder) - - log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder_id) - - self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) - - async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: - """Delete a reminder from the database, given its ID, and cancel the running task.""" - await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - - if cancel_task: - # Now we can remove it from the schedule list - self.scheduler.cancel(reminder_id) - - async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: - """ - Edits a reminder in the database given the ID and payload. - - Returns the edited reminder. - """ - # Send the request to update the reminder in the database - reminder = await self.bot.api_client.patch( - 'bot/reminders/' + str(reminder_id), - json=payload - ) - return reminder - - async def _reschedule_reminder(self, reminder: dict) -> None: - """Reschedule a reminder object.""" - log.trace(f"Cancelling old task #{reminder['id']}") - self.scheduler.cancel(reminder["id"]) - - log.trace(f"Scheduling new task #{reminder['id']}") - self.schedule_reminder(reminder) - - async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: - """Send the reminder.""" - is_valid, user, channel = self.ensure_valid_reminder(reminder) - if not is_valid: - return - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.set_author( - icon_url=Icons.remind_blurple, - name="It has arrived!" - ) - - embed.description = f"Here's your reminder: `{reminder['content']}`." - - if reminder.get("jump_url"): # keep backward compatibility - embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" - - if late: - embed.colour = discord.Colour.red() - embed.set_author( - icon_url=Icons.remind_red, - name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" - ) - - additional_mentions = ' '.join( - mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) - ) - - await channel.send( - content=f"{user.mention} {additional_mentions}", - embed=embed - ) - await self._delete_reminder(reminder["id"]) - - @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) - async def remind_group( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """Commands for managing your reminders.""" - await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) - - @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str - ) -> None: - """ - Set yourself a simple reminder. - - Expiration is parsed per: http://strftime.org/ - """ - # If the user is not staff, we need to verify whether or not to make a reminder at all. - if without_role_check(ctx, *STAFF_ROLES): - - # If they don't have permission to set a reminder in this channel - if ctx.channel.id not in WHITELISTED_CHANNELS: - await send_denial(ctx, "Sorry, you can't do that here!") - return - - # Get their current active reminders - active_reminders = await self.bot.api_client.get( - 'bot/reminders', - params={ - 'author__id': str(ctx.author.id) - } - ) - - # Let's limit this, so we don't get 10 000 - # reminders from kip or something like that :P - if len(active_reminders) > MAXIMUM_REMINDERS: - await send_denial(ctx, "You have too many active reminders!") - return - - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - - # Now we can attempt to actually set the reminder. - reminder = await self.bot.api_client.post( - 'bot/reminders', - json={ - 'author': ctx.author.id, - 'channel_id': ctx.message.channel.id, - 'jump_url': ctx.message.jump_url, - 'content': content, - 'expiration': expiration.isoformat(), - 'mentions': mention_ids, - } - ) - - now = datetime.utcnow() - timedelta(seconds=1) - humanized_delta = humanize_delta(relativedelta(expiration, now)) - mention_string = ( - f"Your reminder will arrive in {humanized_delta} " - f"and will mention {len(mentions)} other(s)!" - ) - - # Confirm to the user that it worked. - await self._send_confirmation( - ctx, - on_success=mention_string, - reminder_id=reminder["id"], - delivery_dt=expiration, - ) - - self.schedule_reminder(reminder) - - @remind_group.command(name="list") - async def list_reminders(self, ctx: Context) -> None: - """View a paginated embed of all reminders for your user.""" - # Get all the user's reminders from the database. - data = await self.bot.api_client.get( - 'bot/reminders', - params={'author__id': str(ctx.author.id)} - ) - - now = datetime.utcnow() - - # Make a list of tuples so it can be sorted by time. - reminders = sorted( - ( - (rem['content'], rem['expiration'], rem['id'], rem['mentions']) - for rem in data - ), - key=itemgetter(1) - ) - - lines = [] - - for content, remind_at, id_, mentions in reminders: - # Parse and humanize the time, make it pretty :D - remind_datetime = isoparse(remind_at).replace(tzinfo=None) - time = humanize_delta(relativedelta(remind_datetime, now)) - - mentions = ", ".join( - # Both Role and User objects have the `name` attribute - mention.name for mention in self.get_mentionables(mentions) - ) - mention_string = f"\n**Mentions:** {mentions}" if mentions else "" - - text = textwrap.dedent(f""" - **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} - {content} - """).strip() - - lines.append(text) - - embed = discord.Embed() - embed.colour = discord.Colour.blurple() - embed.title = f"Reminders for {ctx.author}" - - # Remind the user that they have no reminders :^) - if not lines: - embed.description = "No active reminders could be found." - await ctx.send(embed=embed) - return - - # Construct the embed and paginate it. - embed.colour = discord.Colour.blurple() - - await LinePaginator.paginate( - lines, - ctx, embed, - max_lines=3, - empty=True - ) - - @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) - async def edit_reminder_group(self, ctx: Context) -> None: - """Commands for modifying your current reminders.""" - await ctx.send_help(ctx.command) - - @edit_reminder_group.command(name="duration", aliases=("time",)) - async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: - """ - Edit one of your reminder's expiration. - - Expiration is parsed per: http://strftime.org/ - """ - await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) - - @edit_reminder_group.command(name="content", aliases=("reason",)) - async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: - """Edit one of your reminder's content.""" - await self.edit_reminder(ctx, id_, {"content": content}) - - @edit_reminder_group.command(name="mentions", aliases=("pings",)) - async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: - """Edit one of your reminder's mentions.""" - # Remove duplicate mentions - mentions = set(mentions) - mentions.discard(ctx.author) - - # Filter mentions to see if the user can mention members/roles - if not await self.validate_mentions(ctx, mentions): - return - - mention_ids = [mention.id for mention in mentions] - await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) - - async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: - """Edits a reminder with the given payload, then sends a confirmation message.""" - reminder = await self._edit_reminder(id_, payload) - - # Parse the reminder expiration back into a datetime - expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) - - # Send a confirmation message to the channel - await self._send_confirmation( - ctx, - on_success="That reminder has been edited successfully!", - reminder_id=id_, - delivery_dt=expiration, - ) - await self._reschedule_reminder(reminder) - - @remind_group.command("delete", aliases=("remove", "cancel")) - async def delete_reminder(self, ctx: Context, id_: int) -> None: - """Delete one of your active reminders.""" - await self._delete_reminder(id_) - await self._send_confirmation( - ctx, - on_success="That reminder has been deleted successfully!", - reminder_id=id_, - delivery_dt=None, - ) - - -def setup(bot: Bot) -> None: - """Load the Reminders cog.""" - bot.add_cog(Reminders(bot)) diff --git a/bot/cogs/utils/snekbox.py b/bot/cogs/utils/snekbox.py deleted file mode 100644 index 52c8b6f88..000000000 --- a/bot/cogs/utils/snekbox.py +++ /dev/null @@ -1,349 +0,0 @@ -import asyncio -import contextlib -import datetime -import logging -import re -import textwrap -from functools import partial -from signal import Signals -from typing import Optional, Tuple - -from discord import HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only - -from bot.bot import Bot -from bot.constants import Categories, Channels, Roles, URLs -from bot.decorators import in_whitelist -from bot.utils.messages import wait_for_deletion - -log = logging.getLogger(__name__) - -ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") -FORMATTED_CODE_REGEX = re.compile( - r"^\s*" # any leading whitespace from the beginning of the string - r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block - r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) - r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all code inside the markup - r"\s*" # any more whitespace before the end of the code markup - r"(?P=delim)" # match the exact same delimiter from the start again - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive -) -RAW_CODE_REGEX = re.compile( - r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all the rest as code - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL # "." also matches newlines -) - -MAX_PASTE_LEN = 1000 - -# `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) - -SIGKILL = 9 - -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 - - -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_eval(self, code: str) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - - url = URLs.paste_service.format(key="documents") - try: - async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: - data = await resp.json() - - if "key" in data: - return URLs.paste_service.format(key=data["key"]) - except Exception: - # 400 (Bad Request) means there are too many characters - log.exception("Failed to upload full output to paste service!") - - @staticmethod - def prepare_input(code: str) -> str: - """Extract code from the Markdown, format it, and insert it into the code template.""" - match = FORMATTED_CODE_REGEX.fullmatch(code) - if match: - code, block, lang, delim = match.group("code", "block", "lang", "delim") - code = textwrap.dedent(code) - if block: - info = (f"'{lang}' highlighted" if lang else "plain") + " code block" - else: - info = f"{delim}-enclosed inline code" - log.trace(f"Extracted {info} for evaluation:\n{code}") - else: - code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) - log.trace( - f"Eval message contains unformatted or badly formatted code, " - f"stripping whitespace only:\n{code}" - ) - - return code - - @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: - """Return a user-friendly message and error corresponding to the process's return code.""" - stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" - error = "" - - if returncode is None: - msg = "Your eval job has failed" - error = stdout.strip() - elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" - elif returncode == 255: - msg = "Your eval job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - try: - name = Signals(returncode - 128).name - msg = f"{msg} ({name})" - except ValueError: - pass - - return msg, error - - @staticmethod - def get_status_emoji(results: dict) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - if not results["stdout"].strip(): # No output - return ":warning:" - elif results["returncode"] == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: - """ - Format the output and return a tuple of the formatted output and a URL to the full output. - - Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters - and upload the full output to a paste service. - """ - log.trace("Formatting output...") - - output = output.rstrip("\n") - original_output = output # To be uploaded to a pasting service if needed - paste_link = None - - if "<@" in output: - output = output.replace("<@", "<@\u200B") # Zero-width space - - if " 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] - output = output[:11] # Limiting to only 11 lines - output = "\n".join(output) - - if lines > 10: - truncated = True - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" - else: - output = f"{output}\n... (truncated - too many lines)" - elif len(output) >= 1000: - truncated = True - output = f"{output[:1000]}\n... (truncated - too long)" - - if truncated: - paste_link = await self.upload_output(original_output) - - output = output or "[No output]" - - return output, paste_link - - async def send_eval(self, ctx: Context, code: str) -> Message: - """ - Evaluate code, format it, and send the output to the corresponding channel. - - Return the bot response. - """ - async with ctx.typing(): - results = await self.post_eval(code) - msg, error = self.get_results_message(results) - - if error: - output, paste_link = error, None - else: - output, paste_link = await self.format_output(results["stdout"]) - - icon = self.get_status_emoji(results) - msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" - if paste_link: - msg = f"{msg}\nFull output: {paste_link}" - - # Collect stats of eval fails + successes - if icon == ":x:": - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - - filter_cog = self.bot.get_cog("Filtering") - filter_triggered = False - if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) - if filter_triggered: - response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") - else: - response = await ctx.send(msg) - self.bot.loop.create_task( - wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) - ) - - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") - return response - - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: - """ - Check if the eval session should continue. - - Return the new code to evaluate or None if the eval session should be terminated. - """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) - - with contextlib.suppress(NotFound): - try: - _, new_message = await self.bot.wait_for( - 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT - ) - await ctx.message.add_reaction(REEVAL_EMOJI) - await self.bot.wait_for( - 'reaction_add', - check=_predicate_emoji_reaction, - timeout=10 - ) - - code = await self.get_code(new_message) - await ctx.message.clear_reactions() - with contextlib.suppress(HTTPException): - await response.delete() - - except asyncio.TimeoutError: - await ctx.message.clear_reactions() - return None - - return code - - async def get_code(self, message: Message) -> Optional[str]: - """ - Return the code from `message` to be evaluated. - - If the message is an invocation of the eval command, return the first argument or None if it - doesn't exist. Otherwise, return the full content of the message. - """ - log.trace(f"Getting context for message {message.id}.") - new_ctx = await self.bot.get_context(message) - - if new_ctx.command is self.eval_command: - log.trace(f"Message {message.id} invokes eval command.") - split = message.content.split(maxsplit=1) - code = split[1] if len(split) > 1 else None - else: - log.trace(f"Message {message.id} does not invoke eval command.") - code = message.content - - return code - - @command(name="eval", aliases=("e",)) - @guild_only() - @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: - """ - Run Python code and get the results. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ - if ctx.author.id in self.jobs: - await ctx.send( - f"{ctx.author.mention} You've already got a job running - " - "please wait for it to finish!" - ) - return - - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - - if Roles.helpers in (role.id for role in ctx.author.roles): - self.bot.stats.incr("snekbox_usages.roles.helpers") - else: - self.bot.stats.incr("snekbox_usages.roles.developers") - - if ctx.channel.category_id == Categories.help_in_use: - self.bot.stats.incr("snekbox_usages.channels.help") - elif ctx.channel.id == Channels.bot_commands: - self.bot.stats.incr("snekbox_usages.channels.bot_commands") - else: - self.bot.stats.incr("snekbox_usages.channels.topical") - - log.info(f"Received code from {ctx.author} for evaluation:\n{code}") - - while True: - self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) - try: - response = await self.send_eval(ctx, code) - finally: - del self.jobs[ctx.author.id] - - code = await self.continue_eval(ctx, response) - if not code: - break - log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") - - -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: - """Return True if the edited message is the context message and the content was indeed modified.""" - return new_msg.id == ctx.message.id and old_msg.content != new_msg.content - - -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI - - -def setup(bot: Bot) -> None: - """Load the Snekbox cog.""" - bot.add_cog(Snekbox(bot)) diff --git a/bot/cogs/utils/utils.py b/bot/cogs/utils/utils.py deleted file mode 100644 index d96abbd5a..000000000 --- a/bot/cogs/utils/utils.py +++ /dev/null @@ -1,265 +0,0 @@ -import difflib -import logging -import re -import unicodedata -from email.parser import HeaderParser -from io import StringIO -from typing import Tuple, Union - -from discord import Colour, Embed, utils -from discord.ext.commands import BadArgument, Cog, Context, clean_content, command - -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES -from bot.decorators import in_whitelist, with_role -from bot.pagination import LinePaginator -from bot.utils import messages - -log = logging.getLogger(__name__) - -ZEN_OF_PYTHON = """\ -Beautiful is better than ugly. -Explicit is better than implicit. -Simple is better than complex. -Complex is better than complicated. -Flat is better than nested. -Sparse is better than dense. -Readability counts. -Special cases aren't special enough to break the rules. -Although practicality beats purity. -Errors should never pass silently. -Unless explicitly silenced. -In the face of ambiguity, refuse the temptation to guess. -There should be one-- and preferably only one --obvious way to do it. -Although that way may not be obvious at first unless you're Dutch. -Now is better than never. -Although never is often better than *right* now. -If the implementation is hard to explain, it's a bad idea. -If the implementation is easy to explain, it may be a good idea. -Namespaces are one honking great idea -- let's do more of those! -""" - -ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" - - -class Utils(Cog): - """A selection of utilities which don't have a clear category.""" - - def __init__(self, bot: Bot): - self.bot = bot - - self.base_pep_url = "http://www.python.org/dev/peps/pep-" - self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" - - @command(name='pep', aliases=('get_pep', 'p')) - async def pep_command(self, ctx: Context, pep_number: str) -> None: - """Fetches information about a PEP and sends it to the channel.""" - if pep_number.isdigit(): - pep_number = int(pep_number) - else: - await ctx.send_help(ctx.command) - return - - # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. - if pep_number == 0: - return await self.send_pep_zero(ctx) - - possible_extensions = ['.txt', '.rst'] - found_pep = False - for extension in possible_extensions: - # Attempt to fetch the PEP - pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" - log.trace(f"Requesting PEP {pep_number} with {pep_url}") - response = await self.bot.http_session.get(pep_url) - - if response.status == 200: - log.trace("PEP found") - found_pep = True - - pep_content = await response.text() - - # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 - pep_header = HeaderParser().parse(StringIO(pep_content)) - - # Assemble the embed - pep_embed = Embed( - title=f"**PEP {pep_number} - {pep_header['Title']}**", - description=f"[Link]({self.base_pep_url}{pep_number:04})", - ) - - pep_embed.set_thumbnail(url=ICON_URL) - - # Add the interesting information - fields_to_check = ("Status", "Python-Version", "Created", "Type") - for field in fields_to_check: - # Check for a PEP metadata field that is present but has an empty value - # embed field values can't contain an empty string - if pep_header.get(field, ""): - pep_embed.add_field(name=field, value=pep_header[field]) - - elif response.status != 404: - # any response except 200 and 404 is expected - found_pep = True # actually not, but it's easier to display this way - log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " - f"{response.status}.\n{response.text}") - - error_message = "Unexpected HTTP error during PEP search. Please let us know." - pep_embed = Embed(title="Unexpected error", description=error_message) - pep_embed.colour = Colour.red() - break - - if not found_pep: - log.trace("PEP was not found") - not_found = f"PEP {pep_number} does not exist." - pep_embed = Embed(title="PEP not found", description=not_found) - pep_embed.colour = Colour.red() - - await ctx.message.channel.send(embed=pep_embed) - - @command() - @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) - async def charinfo(self, ctx: Context, *, characters: str) -> None: - """Shows you information on up to 50 unicode characters.""" - match = re.match(r"<(a?):(\w+):(\d+)>", characters) - if match: - return await messages.send_denial( - ctx, - "**Non-Character Detected**\n" - "Only unicode characters can be processed, but a custom Discord emoji " - "was found. Please remove it and try again." - ) - - if len(characters) > 50: - return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") - - def get_info(char: str) -> Tuple[str, str]: - digit = f"{ord(char):x}" - if len(digit) <= 4: - u_code = f"\\u{digit:>04}" - else: - u_code = f"\\U{digit:>08}" - url = f"https://www.compart.com/en/unicode/U+{digit:>04}" - name = f"[{unicodedata.name(char, '')}]({url})" - info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" - return info, u_code - - char_list, raw_list = zip(*(get_info(c) for c in characters)) - embed = Embed().set_author(name="Character Info") - - if len(characters) > 1: - # Maximum length possible is 502 out of 1024, so there's no need to truncate. - embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) - - await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) - - @command() - async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: - """ - Show the Zen of Python. - - Without any arguments, the full Zen will be produced. - If an integer is provided, the line with that index will be produced. - If a string is provided, the line which matches best will be produced. - """ - embed = Embed( - colour=Colour.blurple(), - title="The Zen of Python", - description=ZEN_OF_PYTHON - ) - - if search_value is None: - embed.title += ", by Tim Peters" - await ctx.send(embed=embed) - return - - zen_lines = ZEN_OF_PYTHON.splitlines() - - # handle if it's an index int - if isinstance(search_value, int): - upper_bound = len(zen_lines) - 1 - lower_bound = -1 * upper_bound - if not (lower_bound <= search_value <= upper_bound): - raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") - - embed.title += f" (line {search_value % len(zen_lines)}):" - embed.description = zen_lines[search_value] - await ctx.send(embed=embed) - return - - # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead - # exact word. - for i, line in enumerate(zen_lines): - for word in line.split(): - if word.lower() == search_value.lower(): - embed.title += f" (line {i}):" - embed.description = line - await ctx.send(embed=embed) - return - - # handle if it's a search string and not exact word - matcher = difflib.SequenceMatcher(None, search_value.lower()) - - best_match = "" - match_index = 0 - best_ratio = 0 - - for index, line in enumerate(zen_lines): - matcher.set_seq2(line.lower()) - - # the match ratio needs to be adjusted because, naturally, - # longer lines will have worse ratios than shorter lines when - # fuzzy searching for keywords. this seems to work okay. - adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() - - if adjusted_ratio > best_ratio: - best_ratio = adjusted_ratio - best_match = line - match_index = index - - if not best_match: - raise BadArgument("I didn't get a match! Please try again with a different search term.") - - embed.title += f" (line {match_index}):" - embed.description = best_match - await ctx.send(embed=embed) - - @command(aliases=("poll",)) - @with_role(*MODERATION_ROLES) - async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: - """ - Build a quick voting poll with matching reactions with the provided options. - - A maximum of 20 options can be provided, as Discord supports a max of 20 - reactions on a single message. - """ - if len(title) > 256: - raise BadArgument("The title cannot be longer than 256 characters.") - if len(options) < 2: - raise BadArgument("Please provide at least 2 options.") - if len(options) > 20: - raise BadArgument("I can only handle 20 options!") - - codepoint_start = 127462 # represents "regional_indicator_a" unicode value - options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} - embed = Embed(title=title, description="\n".join(options.values())) - message = await ctx.send(embed=embed) - for reaction in options: - await message.add_reaction(reaction) - - async def send_pep_zero(self, ctx: Context) -> None: - """Send information about PEP 0.""" - pep_embed = Embed( - title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", - description="[Link](https://www.python.org/dev/peps/)" - ) - pep_embed.set_thumbnail(url=ICON_URL) - pep_embed.add_field(name="Status", value="Active") - pep_embed.add_field(name="Created", value="13-Jul-2000") - pep_embed.add_field(name="Type", value="Informational") - - await ctx.send(embed=pep_embed) - - -def setup(bot: Bot) -> None: - """Load the Utils cog.""" - bot.add_cog(Utils(bot)) diff --git a/bot/exts/__init__.py b/bot/exts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/alias.py b/bot/exts/alias.py new file mode 100644 index 000000000..77867b933 --- /dev/null +++ b/bot/exts/alias.py @@ -0,0 +1,153 @@ +import inspect +import logging + +from discord import Colour, Embed +from discord.ext.commands import ( + Cog, Command, Context, Greedy, + clean_content, command, group, +) + +from bot.bot import Bot +from bot.converters import FetchedMember, TagNameConverter +from bot.exts.utils.extensions import Extension +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +class Alias (Cog): + """Aliases for commonly used commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: + """Invokes a command with args and kwargs.""" + log.debug(f"{cmd_name} was invoked through an alias") + cmd = self.bot.get_command(cmd_name) + if not cmd: + return log.info(f'Did not find command "{cmd_name}" to invoke.') + elif not await cmd.can_run(ctx): + return log.info( + f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' + ) + + await ctx.invoke(cmd, *args, **kwargs) + + @command(name='aliases') + async def aliases_command(self, ctx: Context) -> None: + """Show configured aliases on the bot.""" + embed = Embed( + title='Configured aliases', + colour=Colour.blue() + ) + await LinePaginator.paginate( + ( + f"• `{ctx.prefix}{value.name}` " + f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" + for name, value in inspect.getmembers(self) + if isinstance(value, Command) and name.endswith('_alias') + ), + ctx, embed, empty=False, max_lines=20 + ) + + @command(name="resources", aliases=("resource",), hidden=True) + async def site_resources_alias(self, ctx: Context) -> None: + """Alias for invoking site resources.""" + await self.invoke(ctx, "site resources") + + @command(name="tools", hidden=True) + async def site_tools_alias(self, ctx: Context) -> None: + """Alias for invoking site tools.""" + await self.invoke(ctx, "site tools") + + @command(name="watch", hidden=True) + async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother watch [user] [reason].""" + await self.invoke(ctx, "bigbrother watch", user, reason=reason) + + @command(name="unwatch", hidden=True) + async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother unwatch [user] [reason].""" + await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) + + @command(name="home", hidden=True) + async def site_home_alias(self, ctx: Context) -> None: + """Alias for invoking site home.""" + await self.invoke(ctx, "site home") + + @command(name="faq", hidden=True) + async def site_faq_alias(self, ctx: Context) -> None: + """Alias for invoking site faq.""" + await self.invoke(ctx, "site faq") + + @command(name="rules", aliases=("rule",), hidden=True) + async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: + """Alias for invoking site rules.""" + await self.invoke(ctx, "site rules", *rules) + + @command(name="reload", hidden=True) + async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: + """Alias for invoking extensions reload [extensions...].""" + await self.invoke(ctx, "extensions reload", *extensions) + + @command(name="defon", hidden=True) + async def defcon_enable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon enable.""" + await self.invoke(ctx, "defcon enable") + + @command(name="defoff", hidden=True) + async def defcon_disable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon disable.""" + await self.invoke(ctx, "defcon disable") + + @command(name="exception", hidden=True) + async def tags_get_traceback_alias(self, ctx: Context) -> None: + """Alias for invoking tags get traceback.""" + await self.invoke(ctx, "tags get", tag_name="traceback") + + @group(name="get", + aliases=("show", "g"), + hidden=True, + invoke_without_command=True) + async def get_group_alias(self, ctx: Context) -> None: + """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" + pass + + @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) + async def tags_get_alias( + self, ctx: Context, *, tag_name: TagNameConverter = None + ) -> None: + """ + Alias for invoking tags get [tag_name]. + + tag_name: str - tag to be viewed. + """ + await self.invoke(ctx, "tags get", tag_name=tag_name) + + @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) + async def docs_get_alias( + self, ctx: Context, symbol: clean_content = None + ) -> None: + """Alias for invoking docs get [symbol].""" + await self.invoke(ctx, "docs get", symbol) + + @command(name="nominate", hidden=True) + async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking talentpool add [user] [reason].""" + await self.invoke(ctx, "talentpool add", user, reason=reason) + + @command(name="unnominate", hidden=True) + async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking nomination end [user] [reason].""" + await self.invoke(ctx, "nomination end", user, reason=reason) + + @command(name="nominees", hidden=True) + async def nominees_alias(self, ctx: Context) -> None: + """Alias for invoking tp watched.""" + await self.invoke(ctx, "talentpool watched") + + +def setup(bot: Bot) -> None: + """Load the Alias cog.""" + bot.add_cog(Alias(bot)) diff --git a/bot/exts/backend/__init__.py b/bot/exts/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/backend/config_verifier.py b/bot/exts/backend/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/exts/backend/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py new file mode 100644 index 000000000..f9d4de638 --- /dev/null +++ b/bot/exts/backend/error_handler.py @@ -0,0 +1,287 @@ +import contextlib +import logging +import typing as t + +from discord import Embed +from discord.ext.commands import Cog, Context, errors +from sentry_sdk import push_scope + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Colours +from bot.converters import TagNameConverter +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + + +class ErrorHandler(Cog): + """Handles errors emitted from commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_error_embed(self, title: str, body: str) -> Embed: + """Return an embed that contains the exception.""" + return Embed( + title=title, + colour=Colours.soft_red, + description=body + ) + + @Cog.listener() + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: + """ + Provide generic command error handling. + + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command: + * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. + Otherwise if it matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` + """ + command = ctx.command + + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if await self.try_silence(ctx): + return + if ctx.channel.id != Channels.verification: + # Try to look for a tag with the command's name + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + @staticmethod + def get_help_command(ctx: Context) -> t.Coroutine: + """Return a prepared `help` command invocation coroutine.""" + if ctx.command: + return ctx.send_help(ctx.command) + + return ctx.send_help() + + async def try_silence(self, ctx: Context) -> bool: + """ + Attempt to invoke the silence or unsilence command if invoke with matches a pattern. + + Respecting the checks if: + * invoked with `shh+` silence channel for amount of h's*2 with max of 15. + * invoked with `unshh+` unsilence channel + Return bool depending on success of command. + """ + command = ctx.invoked_with.lower() + silence_command = self.bot.get_command("silence") + ctx.invoked_from_error_handler = True + try: + if not await silence_command.can_run(ctx): + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + except errors.CommandError: + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + if command.startswith("shh"): + await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) + return True + elif command.startswith("unshh"): + await ctx.invoke(self.bot.get_command("unsilence")) + return True + return False + + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + + try: + tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) + except errors.BadArgument: + log.debug( + f"{ctx.author} tried to use an invalid command " + f"and the fallback tag failed validation in TagNameConverter." + ) + else: + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=tag_name) + # Return to not raise the exception + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + prepared_help_command = self.get_help_command(ctx) + + if isinstance(e, errors.MissingRequiredArgument): + embed = self._get_error_embed("Missing required argument", e.param.name) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.missing_required_argument") + elif isinstance(e, errors.TooManyArguments): + embed = self._get_error_embed("Too many arguments", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.too_many_arguments") + elif isinstance(e, errors.BadArgument): + embed = self._get_error_embed("Bad argument", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.bad_argument") + elif isinstance(e, errors.BadUnionArgument): + embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") + await ctx.send(embed=embed) + self.bot.stats.incr("errors.bad_union_argument") + elif isinstance(e, errors.ArgumentParsingError): + embed = self._get_error_embed("Argument parsing error", str(e)) + await ctx.send(embed=embed) + self.bot.stats.incr("errors.argument_parsing_error") + else: + embed = self._get_error_embed( + "Input error", + "Something about your input seems off. Check the arguments and try again." + ) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.other_user_input_error") + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InWhitelistCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + ctx.bot.stats.incr("errors.bot_permission_error") + await ctx.send( + "Sorry, it looks like I don't have the permissions or roles I need to do that." + ) + elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): + ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") + await ctx.send(e) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + ctx.bot.stats.incr("errors.api_error_404") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + ctx.bot.stats.incr("errors.api_error_400") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") + ctx.bot.stats.incr("errors.api_internal_server_error") + else: + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") + ctx.bot.stats.incr(f"errors.api_error_{e.status}") + + @staticmethod + async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n" + f"```{e.__class__.__name__}: {e}```" + ) + + ctx.bot.stats.incr("errors.unexpected") + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) + + +def setup(bot: Bot) -> None: + """Load the ErrorHandler cog.""" + bot.add_cog(ErrorHandler(bot)) diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py new file mode 100644 index 000000000..94fa2b139 --- /dev/null +++ b/bot/exts/backend/logging.py @@ -0,0 +1,42 @@ +import logging + +from discord import Embed +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, DEBUG_MODE + + +log = logging.getLogger(__name__) + + +class Logging(Cog): + """Debug logging module.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.bot.loop.create_task(self.startup_greeting()) + + async def startup_greeting(self) -> None: + """Announce our presence to the configured devlog channel.""" + await self.bot.wait_until_guild_available() + log.info("Bot connected!") + + embed = Embed(description="Connected!") + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=( + "https://raw.githubusercontent.com/" + "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" + ) + ) + + if not DEBUG_MODE: + await self.bot.get_channel(Channels.dev_log).send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the Logging cog.""" + bot.add_cog(Logging(bot)) diff --git a/bot/exts/backend/sync/__init__.py b/bot/exts/backend/sync/__init__.py new file mode 100644 index 000000000..2541beaa8 --- /dev/null +++ b/bot/exts/backend/sync/__init__.py @@ -0,0 +1,7 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: + """Load the Sync cog.""" + from ._cog import Sync + bot.add_cog(Sync(bot)) diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py new file mode 100644 index 000000000..b6068f328 --- /dev/null +++ b/bot/exts/backend/sync/_cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from . import _syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = _syncers.RoleSyncer(self.bot) + self.user_syncer = _syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/exts/backend/sync/_syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/exts/dm_relay.py b/bot/exts/dm_relay.py new file mode 100644 index 000000000..0d8f340b4 --- /dev/null +++ b/bot/exts/dm_relay.py @@ -0,0 +1,124 @@ +import logging +from typing import Optional + +import discord +from discord import Color +from discord.ext import commands +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.converters import UserMentionOrID +from bot.utils import RedisCache +from bot.utils.checks import in_whitelist_check, with_role_check +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DMRelay(Cog): + """Relay direct messages to and from the bot.""" + + # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] + dm_cache = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.dm_log + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + @commands.command(aliases=("reply",)) + async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: + """ + Allows you to send a DM to a user from the bot. + + If `member` is not provided, it will send to the last user who DM'd the bot. + + This feature should be used extremely sparingly. Use ModMail if you need to have a serious + conversation with a user. This is just for responding to extraordinary DMs, having a little + fun with users, and telling people they are DMing the wrong bot. + + NOTE: This feature will be removed if it is overused. + """ + if not member: + user_id = await self.dm_cache.get("last_user") + member = ctx.guild.get_member(user_id) if user_id else None + + # If we still don't have a Member at this point, give up + if not member: + log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") + await ctx.message.add_reaction("❌") + return + + try: + await member.send(message) + except discord.errors.Forbidden: + log.debug("User has disabled DMs.") + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("✅") + self.bot.stats.incr("dm_relay.dm_sent") + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Relays the message's content and attachments to the dm_log channel.""" + # Only relay DMs from humans + if message.author.bot or message.guild or self.webhook is None: + return + + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + await self.dm_cache.set("last_user", message.author.id) + self.bot.stats.incr("dm_relay.dm_received") + + # Handle any attachments + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (discord.errors.Forbidden, discord.errors.NotFound): + e = discord.Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + def cog_check(self, ctx: commands.Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=[constants.Channels.dm_log], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + +def setup(bot: Bot) -> None: + """Load the DMRelay cog.""" + bot.add_cog(DMRelay(bot)) diff --git a/bot/exts/duck_pond.py b/bot/exts/duck_pond.py new file mode 100644 index 000000000..7021069fa --- /dev/null +++ b/bot/exts/duck_pond.py @@ -0,0 +1,166 @@ +import logging +from typing import Union + +import discord +from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DuckPond(Cog): + """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.duck_pond + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @staticmethod + def is_staff(member: Union[User, Member]) -> bool: + """Check if a specific member or user is staff.""" + if hasattr(member, "roles"): + for role in member.roles: + if role.id in constants.STAFF_ROLES: + return True + return False + + async def has_green_checkmark(self, message: Message) -> bool: + """Check if the message has a green checkmark reaction.""" + for reaction in message.reactions: + if reaction.emoji == "✅": + async for user in reaction.users(): + if user == self.bot.user: + return True + return False + + async def count_ducks(self, message: Message) -> int: + """ + Count the number of ducks in the reactions of a specific message. + + Only counts ducks added by staff members. + """ + duck_count = 0 + duck_reactors = [] + + for reaction in message.reactions: + async for user in reaction.users(): + + # Is the user a staff member and not already counted as reactor? + if not self.is_staff(user) or user.id in duck_reactors: + continue + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + return duck_count + + async def relay_message(self, message: Message) -> None: + """Relays the message's content and attachments to the duck pond channel.""" + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + await message.add_reaction("✅") + + @staticmethod + def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" + if payload.emoji.is_custom_emoji(): + if payload.emoji.id in constants.DuckPond.custom_emojis: + return True + elif payload.emoji.name == "🦆": + return True + + return False + + @Cog.listener() + async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: + """ + Determine if a message should be sent to the duck pond. + + This will count the number of duck reactions on the message, and if this amount meets the + amount of ducks specified in the config under duck_pond/threshold, it will + send the message off to the duck pond. + """ + # Is the emoji in the reaction a duck? + if not self._payload_has_duckpond_emoji(payload): + return + + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) + + # Is the member a human and a staff member? + if not self.is_staff(member) or member.bot: + return + + # Does the message already have a green checkmark? + if await self.has_green_checkmark(message): + return + + # Time to count our ducks! + duck_count = await self.count_ducks(message) + + # If we've got more than the required amount of ducks, send the message to the duck_pond. + if duck_count >= constants.DuckPond.threshold: + await self.relay_message(message) + + @Cog.listener() + async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: + """Ensure that people don't remove the green checkmark from duck ponded messages.""" + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + + # Prevent the green checkmark from being removed + if payload.emoji.name == "✅": + message = await channel.fetch_message(payload.message_id) + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.threshold: + await message.add_reaction("✅") + + +def setup(bot: Bot) -> None: + """Load the DuckPond cog.""" + bot.add_cog(DuckPond(bot)) diff --git a/bot/exts/filters/__init__.py b/bot/exts/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py new file mode 100644 index 000000000..c76bd2c60 --- /dev/null +++ b/bot/exts/filters/antimalware.py @@ -0,0 +1,98 @@ +import logging +import typing as t +from os.path import splitext + +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, STAFF_ROLES, URLs + +log = logging.getLogger(__name__) + +PY_EMBED_DESCRIPTION = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + "{cmd_channel_mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( + "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + "We currently allow the following file types: **{joined_whitelist}**.\n\n" + "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + + +class AntiMalware(Cog): + """Delete messages which contain attachments with non-whitelisted file extensions.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_whitelisted_file_formats(self) -> list: + """Get the file formats currently on the whitelist.""" + return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() + + def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) + return extensions_blocked + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Identify messages with prohibited attachments.""" + # Return when message don't have attachment and don't moderate DMs + if not message.attachments or not message.guild: + return + + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): + return + + embed = Embed() + extensions_blocked = self._get_disallowed_extensions(message) + blocked_extensions_str = ', '.join(extensions_blocked) + if ".py" in extensions_blocked: + # Short-circuit on *.py files to provide a pastebin link + embed.description = PY_EMBED_DESCRIPTION + elif ".txt" in extensions_blocked: + # Work around Discord AutoConversion of messages longer than 2000 chars to .txt + cmd_channel = self.bot.get_channel(Channels.bot_commands) + embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) + elif extensions_blocked: + meta_channel = self.bot.get_channel(Channels.meta) + embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=', '.join(self._get_whitelisted_file_formats()), + blocked_extensions_str=blocked_extensions_str, + meta_channel_mention=meta_channel.mention, + ) + + if embed.description: + log.info( + f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", + extra={"attachment_list": [attachment.filename for attachment in message.attachments]} + ) + + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) + + # Delete the offending message: + try: + await message.delete() + except NotFound: + log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + + +def setup(bot: Bot) -> None: + """Load the AntiMalware cog.""" + bot.add_cog(AntiMalware(bot)) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py new file mode 100644 index 000000000..3c5f13ebf --- /dev/null +++ b/bot/exts/filters/antispam.py @@ -0,0 +1,288 @@ +import asyncio +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from operator import itemgetter +from typing import Dict, Iterable, List, Set + +from discord import Colour, Member, Message, NotFound, Object, TextChannel +from discord.ext.commands import Cog + +from bot import rules +from bot.bot import Bot +from bot.constants import ( + AntiSpam as AntiSpamConfig, Channels, + Colours, DEBUG_MODE, Event, Filter, + Guild as GuildConfig, Icons, + STAFF_ROLES, +) +from bot.converters import Duration +from bot.exts.moderation.modlog import ModLog +from bot.utils.messages import send_attachments + + +log = logging.getLogger(__name__) + +RULE_FUNCTION_MAPPING = { + 'attachments': rules.apply_attachments, + 'burst': rules.apply_burst, + 'burst_shared': rules.apply_burst_shared, + 'chars': rules.apply_chars, + 'discord_emojis': rules.apply_discord_emojis, + 'duplicates': rules.apply_duplicates, + 'links': rules.apply_links, + 'mentions': rules.apply_mentions, + 'newlines': rules.apply_newlines, + 'role_mentions': rules.apply_role_mentions +} + + +@dataclass +class DeletionContext: + """Represents a Deletion Context for a single spam event.""" + + channel: TextChannel + members: Dict[int, Member] = field(default_factory=dict) + rules: Set[str] = field(default_factory=set) + messages: Dict[int, Message] = field(default_factory=dict) + attachments: List[List[str]] = field(default_factory=list) + + async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: + """Adds new rule violation events to the deletion context.""" + self.rules.add(rule_name) + + for member in members: + if member.id not in self.members: + self.members[member.id] = member + + for message in messages: + if message.id not in self.messages: + self.messages[message.id] = message + + # Re-upload attachments + destination = message.guild.get_channel(Channels.attachment_log) + urls = await send_attachments(message, destination, link_large=False) + self.attachments.append(urls) + + async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: + """Method that takes care of uploading the queue and posting modlog alert.""" + triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values()) + + mod_alert_message = ( + f"**Triggered by:** {triggered_by_users}\n" + f"**Channel:** {self.channel.mention}\n" + f"**Rules:** {', '.join(rule for rule in self.rules)}\n" + ) + + # For multiple messages or those with excessive newlines, use the logs API + if len(self.messages) > 1 or 'newlines' in self.rules: + url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) + mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" + else: + mod_alert_message += "Message:\n" + [message] = self.messages.values() + content = message.clean_content + remaining_chars = 2040 - len(mod_alert_message) + + if len(content) > remaining_chars: + content = content[:remaining_chars] + "..." + + mod_alert_message += f"{content}" + + *_, last_message = self.messages.values() + await modlog.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title="Spam detected!", + text=mod_alert_message, + thumbnail=last_message.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=AntiSpamConfig.ping_everyone + ) + + +class AntiSpam(Cog): + """Cog that controls our anti-spam measures.""" + + def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: + self.bot = bot + self.validation_errors = validation_errors + role_id = AntiSpamConfig.punishment['role_id'] + self.muted_role = Object(role_id) + self.expiration_date_converter = Duration() + + self.message_deletion_queue = dict() + + self.bot.loop.create_task(self.alert_on_validation_error()) + + @property + def mod_log(self) -> ModLog: + """Allows for easy access of the ModLog cog.""" + return self.bot.get_cog("ModLog") + + async def alert_on_validation_error(self) -> None: + """Unloads the cog and alerts admins if configuration validation failed.""" + await self.bot.wait_until_guild_available() + if self.validation_errors: + body = "**The following errors were encountered:**\n" + body += "\n".join(f"- {error}" for error in self.validation_errors.values()) + body += "\n\n**The cog has been unloaded.**" + + await self.mod_log.send_log_message( + title="Error: AntiSpam configuration validation failed!", + text=body, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Colour.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Applies the antispam rules to each received message.""" + if ( + not message.guild + or message.guild.id != GuildConfig.id + or message.author.bot + or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) + or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) + ): + return + + # Fetch the rule configuration with the highest rule interval. + max_interval_config = max( + AntiSpamConfig.rules.values(), + key=itemgetter('interval') + ) + max_interval = max_interval_config['interval'] + + # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. + earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) + relevant_messages = [ + msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) + if not msg.author.bot + ] + + for rule_name in AntiSpamConfig.rules: + rule_config = AntiSpamConfig.rules[rule_name] + rule_function = RULE_FUNCTION_MAPPING[rule_name] + + # Create a list of messages that were sent in the interval that the rule cares about. + latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) + messages_for_rule = [ + msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp + ] + result = await rule_function(message, messages_for_rule, rule_config) + + # If the rule returns `None`, that means the message didn't violate it. + # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` + # which contains the reason for why the message violated the rule and + # an iterable of all members that violated the rule. + if result is not None: + self.bot.stats.incr(f"mod_alerts.{rule_name}") + reason, members, relevant_messages = result + full_reason = f"`{rule_name}` rule: {reason}" + + # If there's no spam event going on for this channel, start a new Message Deletion Context + channel = message.channel + if channel.id not in self.message_deletion_queue: + log.trace(f"Creating queue for channel `{channel.id}`") + self.message_deletion_queue[message.channel.id] = DeletionContext(channel) + self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) + + # Add the relevant of this trigger to the Deletion Context + await self.message_deletion_queue[message.channel.id].add( + rule_name=rule_name, + members=members, + messages=relevant_messages + ) + + for member in members: + + # Fire it off as a background task to ensure + # that the sleep doesn't block further tasks + self.bot.loop.create_task( + self.punish(message, member, full_reason) + ) + + await self.maybe_delete_messages(channel, relevant_messages) + break + + async def punish(self, msg: Message, member: Member, reason: str) -> None: + """Punishes the given member for triggering an antispam rule.""" + if not any(role.id == self.muted_role.id for role in member.roles): + remove_role_after = AntiSpamConfig.punishment['remove_after'] + + # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes + context = await self.bot.get_context(msg) + context.author = self.bot.user + context.message.author = self.bot.user + + # Since we're going to invoke the tempmute command directly, we need to manually call the converter. + dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S") + await context.invoke( + self.bot.get_command('tempmute'), + member, + dt_remove_role_after, + reason=reason + ) + + async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: + """Cleans the messages if cleaning is configured.""" + if AntiSpamConfig.clean_offending: + # If we have more than one message, we can use bulk delete. + if len(messages) > 1: + message_ids = [message.id for message in messages] + self.mod_log.ignore(Event.message_delete, *message_ids) + await channel.delete_messages(messages) + + # Otherwise, the bulk delete endpoint will throw up. + # Delete the message directly instead. + else: + self.mod_log.ignore(Event.message_delete, messages[0].id) + try: + await messages[0].delete() + except NotFound: + log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") + + async def _process_deletion_context(self, context_id: int) -> None: + """Processes the Deletion Context queue.""" + log.trace("Sleeping before processing message deletion queue.") + await asyncio.sleep(10) + + if context_id not in self.message_deletion_queue: + log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") + return + + deletion_context = self.message_deletion_queue.pop(context_id) + await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + + +def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: + """Validates the antispam configs.""" + validation_errors = {} + for name, config in rules_.items(): + if name not in RULE_FUNCTION_MAPPING: + log.error( + f"Unrecognized antispam rule `{name}`. " + f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" + ) + validation_errors[name] = f"`{name}` is not recognized as an antispam rule." + continue + for required_key in ('interval', 'max'): + if required_key not in config: + log.error( + f"`{required_key}` is required but was not " + f"set in rule `{name}`'s configuration." + ) + validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" + return validation_errors + + +def setup(bot: Bot) -> None: + """Validate the AntiSpam configs and load the AntiSpam cog.""" + validation_errors = validate_config() + bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py new file mode 100644 index 000000000..c15adc461 --- /dev/null +++ b/bot/exts/filters/filter_lists.py @@ -0,0 +1,273 @@ +import logging +from typing import Optional + +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.converters import ValidDiscordServerInvite, ValidFilterListType +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +class FilterLists(Cog): + """Commands for blacklisting and whitelisting things.""" + + methods_with_filterlist_types = [ + "allow_add", + "allow_delete", + "allow_get", + "deny_add", + "deny_delete", + "deny_get", + ] + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.bot.loop.create_task(self._amend_docstrings()) + + async def _amend_docstrings(self) -> None: + """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" + await self.bot.wait_until_guild_available() + + # Add valid filterlist types to the docstrings + valid_types = await ValidFilterListType.get_valid_types(self.bot) + valid_types = [f"`{type_.lower()}`" for type_ in valid_types] + + for method_name in self.methods_with_filterlist_types: + command = getattr(self, method_name) + command.help = ( + f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." + ) + + async def _add_data( + self, + ctx: Context, + allowed: bool, + list_type: ValidFilterListType, + content: str, + comment: Optional[str] = None, + ) -> None: + """Add an item to a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we gotta validate it. + if list_type == "GUILD_INVITE": + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # Unless the user has specified another comment, let's + # use the server name as the comment so that the list + # of guild IDs will be more easily readable when we + # display it. + if not comment: + comment = guild_data.get("name") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Try to add the item to the database + log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") + payload = { + "allowed": allowed, + "type": list_type, + "content": content, + "comment": comment, + } + + try: + item = await self.bot.api_client.post( + "bot/filter-lists", + json=payload + ) + except ResponseCodeError as e: + if e.status == 400: + await ctx.message.add_reaction("❌") + log.debug( + f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " + "probably because the request violated the UniqueConstraint." + ) + raise BadArgument( + f"Unable to add the item to the {allow_type}. " + "The item probably already exists. Keep in mind that a " + "blacklist and a whitelist for the same item cannot co-exist, " + "and we do not permit any duplicates." + ) + raise + + # Insert the item into the cache + self.bot.insert_item_into_filter_list_cache(item) + await ctx.message.add_reaction("✅") + + async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we need to convert it. + if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): + guild_data = await self._validate_guild_invite(ctx, content) + content = guild_data.get("id") + + # If it's a file format, let's make sure it has a leading dot. + elif list_type == "FILE_FORMAT" and not content.startswith("."): + content = f".{content}" + + # Find the content and delete it. + log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") + item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) + + if item is not None: + try: + await self.bot.api_client.delete( + f"bot/filter-lists/{item['id']}" + ) + del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to delete an item with the id {item['id']}, but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("❌") + + async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: + """Paginate and display all items in a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] + + # Build a list of lines we want to show in the paginator + lines = [] + for content, metadata in result.items(): + line = f"• `{content}`" + + if comment := metadata.get("comment"): + line += f" - {comment}" + + lines.append(line) + lines = sorted(lines) + + # Build the embed + list_type_plural = list_type.lower().replace("_", " ").title() + "s" + embed = Embed( + title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", + colour=Colour.blue() + ) + log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") + + if result: + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + await ctx.message.add_reaction("❌") + + async def _sync_data(self, ctx: Context) -> None: + """Syncs the filterlists with the API.""" + try: + log.trace("Attempting to sync FilterList cache with data from the API.") + await self.bot.cache_filter_list_data() + await ctx.message.add_reaction("✅") + except ResponseCodeError as e: + log.debug( + f"{ctx.author} tried to sync FilterList cache data but " + f"the API raised an unexpected error: {e}" + ) + await ctx.message.add_reaction("❌") + + @staticmethod + async def _validate_guild_invite(ctx: Context, invite: str) -> dict: + """ + Validates a guild invite, and returns the guild info as a dict. + + Will raise a BadArgument if the guild invite is invalid. + """ + log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, invite) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's return a dict of guild information. + log.trace(f"{invite} validated as server invite. Converting to ID.") + return guild_data + + @group(aliases=("allowlist", "allow", "al", "wl")) + async def whitelist(self, ctx: Context) -> None: + """Group for whitelisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @group(aliases=("denylist", "deny", "bl", "dl")) + async def blacklist(self, ctx: Context) -> None: + """Group for blacklisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @whitelist.command(name="add", aliases=("a", "set")) + async def allow_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified allowlist.""" + await self._add_data(ctx, True, list_type, content, comment) + + @blacklist.command(name="add", aliases=("a", "set")) + async def deny_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified denylist.""" + await self._add_data(ctx, False, list_type, content, comment) + + @whitelist.command(name="remove", aliases=("delete", "rm",)) + async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified allowlist.""" + await self._delete_data(ctx, True, list_type, content) + + @blacklist.command(name="remove", aliases=("delete", "rm",)) + async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified denylist.""" + await self._delete_data(ctx, False, list_type, content) + + @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified allowlist.""" + await self._list_all_data(ctx, True, list_type) + + @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified denylist.""" + await self._list_all_data(ctx, False, list_type) + + @whitelist.command(name="sync", aliases=("s",)) + async def allow_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + @blacklist.command(name="sync", aliases=("s",)) + async def deny_sync(self, ctx: Context) -> None: + """Syncs both allowlists and denylists with the API.""" + await self._sync_data(ctx) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the FilterLists cog.""" + bot.add_cog(FilterLists(bot)) diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py new file mode 100644 index 000000000..2ae476d8a --- /dev/null +++ b/bot/exts/filters/filtering.py @@ -0,0 +1,575 @@ +import asyncio +import logging +import re +from datetime import datetime, timedelta +from typing import List, Mapping, Optional, Tuple, Union + +import dateutil +import discord.errors +from dateutil.relativedelta import relativedelta +from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel +from discord.ext.commands import Cog +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.constants import ( + Channels, Colours, + Filter, Icons, URLs +) +from bot.exts.moderation.modlog import ModLog +from bot.utils.redis_cache import RedisCache +from bot.utils.regex import INVITE_RE +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + +# Regular expressions +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) +URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) +ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") + +# Other constants. +DAYS_BETWEEN_ALERTS = 3 +OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) + + +class Filtering(Cog): + """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" + + # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent + name_alerts = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.name_lock = asyncio.Lock() + + staff_mistake_str = "If you believe this was a mistake, please let staff know!" + self.filters = { + "filter_zalgo": { + "enabled": Filter.filter_zalgo, + "function": self._has_zalgo, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_zalgo, + "notification_msg": ( + "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " + f"{staff_mistake_str}" + ), + "schedule_deletion": False + }, + "filter_invites": { + "enabled": Filter.filter_invites, + "function": self._has_invites, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_invites, + "notification_msg": ( + f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" + r"Our server rules can be found here: " + ), + "schedule_deletion": False + }, + "filter_domains": { + "enabled": Filter.filter_domains, + "function": self._has_urls, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_domains, + "notification_msg": ( + f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" + ), + "schedule_deletion": False + }, + "watch_regex": { + "enabled": Filter.watch_regex, + "function": self._has_watch_regex_match, + "type": "watchlist", + "content_only": True, + "schedule_deletion": True + }, + "watch_rich_embeds": { + "enabled": Filter.watch_rich_embeds, + "function": self._has_rich_embed, + "type": "watchlist", + "content_only": False, + "schedule_deletion": False + } + } + + self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: + """Fetch items from the filter_list_cache.""" + return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() + + @staticmethod + def _expand_spoilers(text: str) -> str: + """Return a string containing all interpretations of a spoilered message.""" + split_text = SPOILER_RE.split(text) + return ''.join( + split_text[0::2] + split_text[1::2] + split_text + ) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Invoke message filter for new messages.""" + await self._filter_message(msg) + + # Ignore webhook messages. + if msg.webhook_id is None: + await self.check_bad_words_in_name(msg.author) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Invoke message filter for message edits. + + If there have been multiple edits, calculate the time delta from the previous edit. + """ + if not before.edited_at: + delta = relativedelta(after.edited_at, before.created_at).microseconds + else: + delta = relativedelta(after.edited_at, before.edited_at).microseconds + await self._filter_message(after, delta) + + def get_name_matches(self, name: str) -> List[re.Match]: + """Check bad words from passed string (name). Return list of matches.""" + matches = [] + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + if match := re.search(pattern, name, flags=re.IGNORECASE): + matches.append(match) + return matches + + async def check_send_alert(self, member: Member) -> bool: + """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" + if last_alert := await self.name_alerts.get(member.id): + last_alert = datetime.utcfromtimestamp(last_alert) + if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: + log.trace(f"Last alert was too recent for {member}'s nickname.") + return False + + return True + + async def check_bad_words_in_name(self, member: Member) -> None: + """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" + # Use lock to avoid race conditions + async with self.name_lock: + # Check whether the users display name contains any words in our blacklist + matches = self.get_name_matches(member.display_name) + + if not matches or not await self.check_send_alert(member): + return + + log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") + + log_string = ( + f"**User:** {member.mention} (`{member.id}`)\n" + f"**Display Name:** {member.display_name}\n" + f"**Bad Matches:** {', '.join(match.group() for match in matches)}" + ) + + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colours.soft_red, + title="Username filtering alert", + text=log_string, + channel_id=Channels.mod_alerts, + thumbnail=member.avatar_url + ) + + # Update time when alert sent + await self.name_alerts.set(member.id, datetime.utcnow().timestamp()) + + async def filter_eval(self, result: str, msg: Message) -> bool: + """ + Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + + Also requires the original message, to check whether to filter and for mod logs. + Returns whether a filter was triggered or not. + """ + filter_triggered = False + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + # We also do not need to worry about filters that take the full message, + # since all we have is an arbitrary string. + if _filter["enabled"] and _filter["content_only"]: + match = await _filter["function"](result) + + if match: + # If this is a filter (not a watchlist), we set the variable so we know + # that it has been triggered + if _filter["type"] == "filter": + filter_triggered = True + + # We do not have to check against DM channels since !eval cannot be used there. + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, result + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} using !eval with " + f"[the following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + return filter_triggered + + async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: + """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" + # Should we filter this message? + if self._check_filter(msg): + for filter_name, _filter in self.filters.items(): + # Is this specific filter enabled in the config? + if _filter["enabled"]: + # Double trigger check for the embeds filter + if filter_name == "watch_rich_embeds": + # If the edit delta is less than 0.001 seconds, then we're probably dealing + # with a double filter trigger. + if delta is not None and delta < 100: + continue + + # Does the filter only need the message content or the full message? + if _filter["content_only"]: + match = await _filter["function"](msg.content) + else: + match = await _filter["function"](msg) + + if match: + is_private = msg.channel.type is discord.ChannelType.private + + # If this is a filter (not a watchlist) and not in a DM, delete the message. + if _filter["type"] == "filter" and not is_private: + try: + # Embeds (can?) trigger both the `on_message` and `on_message_edit` + # event handlers, triggering filtering twice for the same message. + # + # If `on_message`-triggered filtering already deleted the message + # then `on_message_edit`-triggered filtering will raise exception + # since the message no longer exists. + # + # In addition, to avoid sending two notifications to the user, the + # logs, and mod_alert, we return if the message no longer exists. + await msg.delete() + except discord.errors.NotFound: + return + + # Notify the user if the filter specifies + if _filter["user_notification"]: + await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) + + # If the message is classed as offensive, we store it in the site db and + # it will be deleted it after one week. + if _filter["schedule_deletion"] and not is_private: + delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() + data = { + 'id': msg.id, + 'channel_id': msg.channel.id, + 'delete_date': delete_date + } + + await self.bot.api_client.post('bot/offensive-messages', json=data) + self.schedule_msg_delete(data) + log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") + + if is_private: + channel_str = "via DM" + else: + channel_str = f"in {msg.channel.mention}" + + message_content, additional_embeds, additional_embeds_msg = self._add_stats( + filter_name, match, msg.content + ) + + message = ( + f"The {filter_name} {_filter['type']} was triggered " + f"by **{msg.author}** " + f"(`{msg.author.id}`) {channel_str} with [the " + f"following message]({msg.jump_url}):\n\n" + f"{message_content}" + ) + + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.filtering, + colour=Colour(Colours.soft_red), + title=f"{_filter['type'].title()} triggered!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ping_everyone=Filter.ping_everyone if not is_private else False, + additional_embeds=additional_embeds, + additional_embeds_msg=additional_embeds_msg + ) + + break # We don't want multiple filters to trigger + + def _add_stats(self, name: str, match: Union[re.Match, dict, bool, List[discord.Embed]], content: str) -> Tuple[ + str, Optional[List[discord.Embed]], Optional[str] + ]: + """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" + # Word and match stats for watch_regex + if name == "watch_regex": + surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] + message_content = ( + f"**Match:** '{match[0]}'\n" + f"**Location:** '...{escape_markdown(surroundings)}...'\n" + f"\n**Original Message:**\n{escape_markdown(content)}" + ) + else: # Use original content + message_content = content + + additional_embeds = None + additional_embeds_msg = None + + self.bot.stats.incr(f"filters.{name}") + + # The function returns True for invalid invites. + # They have no data so additional embeds can't be created for them. + if name == "filter_invites" and match is not True: + additional_embeds = [] + for _, data in match.items(): + embed = discord.Embed(description=( + f"**Members:**\n{data['members']}\n" + f"**Active:**\n{data['active']}" + )) + embed.set_author(name=data["name"]) + embed.set_thumbnail(url=data["icon"]) + embed.set_footer(text=f"Guild ID: {data['id']}") + additional_embeds.append(embed) + additional_embeds_msg = "For the following guild(s):" + + elif name == "watch_rich_embeds": + additional_embeds = match + additional_embeds_msg = "With the following embed(s):" + + return message_content, additional_embeds, additional_embeds_msg + + @staticmethod + def _check_filter(msg: Message) -> bool: + """Check whitelists to see if we should filter this message.""" + role_whitelisted = False + + if type(msg.author) is Member: # Only Member has roles, not User. + for role in msg.author.roles: + if role.id in Filter.role_whitelist: + role_whitelisted = True + + return ( + msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist + and not role_whitelisted # Role not in whitelist + and not msg.author.bot # Author not a bot + ) + + async def _has_watch_regex_match(self, text: str) -> Union[bool, re.Match]: + """ + Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. + + `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is + matched as-is. Spoilers are expanded, if any, and URLs are ignored. + """ + if SPOILER_RE.search(text): + text = self._expand_spoilers(text) + + # Make sure it's not a URL + if URL_RE.search(text): + return False + + watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) + for pattern in watchlist_patterns: + match = re.search(pattern, text, flags=re.IGNORECASE) + if match: + return match + + async def _has_urls(self, text: str) -> bool: + """Returns True if the text contains one of the blacklisted URLs from the config file.""" + if not URL_RE.search(text): + return False + + text = text.lower() + domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) + + for url in domain_blacklist: + if url.lower() in text: + return True + + return False + + @staticmethod + async def _has_zalgo(text: str) -> bool: + """ + Returns True if the text contains zalgo characters. + + Zalgo range is \u0300 – \u036F and \u0489. + """ + return bool(ZALGO_RE.search(text)) + + async def _has_invites(self, text: str) -> Union[dict, bool]: + """ + Checks if there's any invites in the text content that aren't in the guild whitelist. + + If any are detected, a dictionary of invite data is returned, with a key per invite. + If none are detected, False is returned. + + Attempts to catch some of common ways to try to cheat the system. + """ + # Remove backslashes to prevent escape character aroundfuckery like + # discord\.gg/gdudes-pony-farm + text = text.replace("\\", "") + + invites = INVITE_RE.findall(text) + invite_data = dict() + for invite in invites: + if invite in invite_data: + continue + + response = await self.bot.http_session.get( + f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} + ) + response = await response.json() + guild = response.get("guild") + if guild is None: + # Lack of a "guild" key in the JSON response indicates either an group DM invite, an + # expired invite, or an invalid invite. The API does not currently differentiate + # between invalid and expired invites + return True + + guild_id = guild.get("id") + guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) + guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) + + # Is this invite allowed? + guild_partnered_or_verified = ( + 'PARTNERED' in guild.get("features", []) + or 'VERIFIED' in guild.get("features", []) + ) + invite_not_allowed = ( + guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted. + or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted. + and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered. + ) + + if invite_not_allowed: + guild_icon_hash = guild["icon"] + guild_icon = ( + "https://cdn.discordapp.com/icons/" + f"{guild_id}/{guild_icon_hash}.png?size=512" + ) + + invite_data[invite] = { + "name": guild["name"], + "id": guild['id'], + "icon": guild_icon, + "members": response["approximate_member_count"], + "active": response["approximate_presence_count"] + } + + return invite_data if invite_data else False + + @staticmethod + async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: + """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" + if msg.embeds: + for embed in msg.embeds: + if embed.type == "rich": + urls = URL_RE.findall(msg.content) + if not embed.url or embed.url not in urls: + # If `embed.url` does not exist or if `embed.url` is not part of the content + # of the message, it's unlikely to be an auto-generated embed by Discord. + return msg.embeds + else: + log.trace( + "Found a rich embed sent by a regular user account, " + "but it was likely just an automatic URL embed." + ) + return False + return False + + async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: + """ + Notify filtered_member about a moderation action with the reason str. + + First attempts to DM the user, fall back to in-channel notification if user has DMs disabled + """ + try: + await filtered_member.send(reason) + except discord.errors.Forbidden: + await channel.send(f"{filtered_member.mention} {reason}") + + def schedule_msg_delete(self, msg: dict) -> None: + """Delete an offensive message once its deletion date is reached.""" + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) + + async def reschedule_offensive_msg_deletion(self) -> None: + """Get all the pending message deletion from the API and reschedule them.""" + await self.bot.wait_until_ready() + response = await self.bot.api_client.get('bot/offensive-messages',) + + now = datetime.utcnow() + + for msg in response: + delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) + + if delete_at < now: + await self.delete_offensive_msg(msg) + else: + self.schedule_msg_delete(msg) + + async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: + """Delete an offensive message, and then delete it from the db.""" + try: + channel = self.bot.get_channel(msg['channel_id']) + if channel: + msg_obj = await channel.fetch_message(msg['id']) + await msg_obj.delete() + except NotFound: + log.info( + f"Tried to delete message {msg['id']}, but the message can't be found " + f"(it has been probably already deleted)." + ) + except HTTPException as e: + log.warning(f"Failed to delete message {msg['id']}: status {e.status}") + + await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') + log.info(f"Deleted the offensive message with id {msg['id']}.") + + +def setup(bot: Bot) -> None: + """Load the Filtering cog.""" + bot.add_cog(Filtering(bot)) diff --git a/bot/exts/filters/security.py b/bot/exts/filters/security.py new file mode 100644 index 000000000..c680c5e27 --- /dev/null +++ b/bot/exts/filters/security.py @@ -0,0 +1,31 @@ +import logging + +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot + +log = logging.getLogger(__name__) + + +class Security(Cog): + """Security-related helpers.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.bot.check(self.check_not_bot) # Global commands check - no bots can run any commands at all + self.bot.check(self.check_on_guild) # Global commands check - commands can't be run in a DM + + def check_not_bot(self, ctx: Context) -> bool: + """Check if the context is a bot user.""" + return not ctx.author.bot + + def check_on_guild(self, ctx: Context) -> bool: + """Check if the context is in a guild.""" + if ctx.guild is None: + raise NoPrivateMessage("This command cannot be used in private messages.") + return True + + +def setup(bot: Bot) -> None: + """Load the Security cog.""" + bot.add_cog(Security(bot)) diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py new file mode 100644 index 000000000..0eda3dc6a --- /dev/null +++ b/bot/exts/filters/token_remover.py @@ -0,0 +1,182 @@ +import base64 +import binascii +import logging +import re +import typing as t + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot import utils +from bot.bot import Bot +from bot.constants import Channels, Colours, Event, Icons +from bot.exts.moderation.modlog import ModLog + +log = logging.getLogger(__name__) + +LOG_MESSAGE = ( + "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, " + "token was `{user_id}.{timestamp}.{hmac}`" +) +DELETION_MESSAGE_TEMPLATE = ( + "Hey {mention}! I noticed you posted a seemingly valid Discord API " + "token in your message and have removed your message. " + "This means that your token has been **compromised**. " + "Please change your token **immediately** at: " + "\n\n" + "Feel free to re-post it with the token removed. " + "If you believe this was a mistake, please let us know!" +) +DISCORD_EPOCH = 1_420_070_400 +TOKEN_EPOCH = 1_293_840_000 + +# Three parts delimited by dots: user ID, creation timestamp, HMAC. +# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string. +# Each part only matches base64 URL-safe characters. +# Padding has never been observed, but the padding character '=' is matched just in case. +TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII) + + +class Token(t.NamedTuple): + """A Discord Bot token.""" + + user_id: str + timestamp: str + hmac: str + + +class TokenRemover(Cog): + """Scans messages for potential discord.py bot tokens and removes them.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Check each message for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + found_token = self.find_token_in_message(msg) + if found_token: + await self.take_action(msg, found_token) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + await self.on_message(after) + + async def take_action(self, msg: Message, found_token: Token) -> None: + """Remove the `msg` containing the `found_token` and send a mod log message.""" + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") + return + + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + log_message = self.format_log_message(msg, found_token) + log.debug(log_message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=log_message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + self.bot.stats.incr("tokens.removed_tokens") + + @staticmethod + def format_log_message(msg: Message, token: Token) -> str: + """Return the log message to send for `token` being censored in `msg`.""" + return LOG_MESSAGE.format( + author=msg.author, + author_id=msg.author.id, + channel=msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac='x' * len(token.hmac), + ) + + @classmethod + def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: + """Return a seemingly valid token found in `msg` or `None` if no token is found.""" + # Use finditer rather than search to guard against method calls prematurely returning the + # token check (e.g. `message.channel.send` also matches our token pattern) + for match in TOKEN_RE.finditer(msg.content): + token = Token(*match.groups()) + if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp): + # Short-circuit on first match + return token + + # No matching substring + return + + @staticmethod + def is_valid_user_id(b64_content: str) -> bool: + """ + Check potential token to see if it contains a valid Discord user ID. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + string = decoded_bytes.decode('utf-8') + + # isdigit on its own would match a lot of other Unicode characters, hence the isascii. + return string.isascii() and string.isdigit() + except (binascii.Error, ValueError): + return False + + @staticmethod + def is_valid_timestamp(b64_content: str) -> bool: + """ + Return True if `b64_content` decodes to a valid timestamp. + + If the timestamp is greater than the Discord epoch, it's probably valid. + See: https://i.imgur.com/7WdehGn.png + """ + b64_content = utils.pad_base64(b64_content) + + try: + decoded_bytes = base64.urlsafe_b64decode(b64_content) + timestamp = int.from_bytes(decoded_bytes, byteorder="big") + except (binascii.Error, ValueError) as e: + log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") + return False + + # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound + # is not checked. + if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH: + return True + else: + log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch") + return False + + +def setup(bot: Bot) -> None: + """Load the TokenRemover cog.""" + bot.add_cog(TokenRemover(bot)) diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py new file mode 100644 index 000000000..ca126ebf5 --- /dev/null +++ b/bot/exts/filters/webhook_remover.py @@ -0,0 +1,84 @@ +import logging +import re + +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, Colours, Event, Icons +from bot.exts.moderation.modlog import ModLog + +WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) + +ALERT_MESSAGE_TEMPLATE = ( + "{user}, looks like you posted a Discord webhook URL. Therefore, your " + "message has been removed. Your webhook may have been **compromised** so " + "please re-create the webhook **immediately**. If you believe this was " + "mistake, please let us know." +) + +log = logging.getLogger(__name__) + + +class WebhookRemover(Cog): + """Scan messages to detect Discord webhooks links.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get current instance of `ModLog`.""" + return self.bot.get_cog("ModLog") + + async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: + """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" + # Don't log this, due internal delete, not by user. Will make different entry. + self.mod_log.ignore(Event.message_delete, msg.id) + + try: + await msg.delete() + except NotFound: + log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") + return + + await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) + + message = ( + f"{msg.author} (`{msg.author.id}`) posted a Discord webhook URL " + f"to #{msg.channel}. Webhook URL was `{redacted_url}`" + ) + log.debug(message) + + # Send entry to moderation alerts. + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Discord webhook URL removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts + ) + + self.bot.stats.incr("tokens.removed_webhooks") + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Check if a Discord webhook URL is in `message`.""" + # Ignore DMs; can't delete messages in there anyway. + if not msg.guild or msg.author.bot: + return + + matches = WEBHOOK_URL_RE.search(msg.content) + if matches: + await self.delete_and_respond(msg, matches[1] + "xxx") + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Check if a Discord webhook URL is in the edited message `after`.""" + await self.on_message(after) + + +def setup(bot: Bot) -> None: + """Load `WebhookRemover` cog.""" + bot.add_cog(WebhookRemover(bot)) diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py new file mode 100644 index 000000000..57094751e --- /dev/null +++ b/bot/exts/help_channels.py @@ -0,0 +1,944 @@ +import asyncio +import json +import logging +import random +import typing as t +from collections import deque +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import discord +import discord.abc +from discord.ext import commands + +from bot import constants +from bot.bot import Bot +from bot.utils import RedisCache +from bot.utils.checks import with_role_check +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + +ASKING_GUIDE_URL = "https://pythondiscord.com/pages/asking-good-questions/" +MAX_CHANNELS_PER_CATEGORY = 50 +EXCLUDED_CHANNELS = (constants.Channels.how_to_get_help, constants.Channels.cooldown) + +HELP_CHANNEL_TOPIC = """ +This is a Python help channel. You can claim your own help channel in the Python Help: Available category. +""" + +AVAILABLE_MSG = f""" +This help channel is now **available**, which means that you can claim it by simply typing your \ +question into it. Once claimed, the channel will move into the **Python Help: Occupied** category, \ +and will be yours until it has been inactive for {constants.HelpChannels.idle_minutes} minutes or \ +is closed manually with `!close`. When that happens, it will be set to **dormant** and moved into \ +the **Help: Dormant** category. + +Try to write the best question you can by providing a detailed description and telling us what \ +you've tried already. For more information on asking a good question, \ +check out our guide on [asking good questions]({ASKING_GUIDE_URL}). +""" + +DORMANT_MSG = f""" +This help channel has been marked as **dormant**, and has been moved into the **Help: Dormant** \ +category at the bottom of the channel list. It is no longer possible to send messages in this \ +channel until it becomes available again. + +If your question wasn't answered yet, you can claim a new help channel from the \ +**Help: Available** category by simply asking your question again. Consider rephrasing the \ +question to maximize your chance of getting a good answer. If you're not sure how, have a look \ +through our guide for [asking a good question]({ASKING_GUIDE_URL}). +""" + +CoroutineFunc = t.Callable[..., t.Coroutine] + + +class HelpChannels(commands.Cog): + """ + Manage the help channel system of the guild. + + The system is based on a 3-category system: + + Available Category + + * Contains channels which are ready to be occupied by someone who needs help + * Will always contain `constants.HelpChannels.max_available` channels; refilled automatically + from the pool of dormant channels + * Prioritise using the channels which have been dormant for the longest amount of time + * If there are no more dormant channels, the bot will automatically create a new one + * If there are no dormant channels to move, helpers will be notified (see `notify()`) + * When a channel becomes available, the dormant embed will be edited to show `AVAILABLE_MSG` + * User can only claim a channel at an interval `constants.HelpChannels.claim_minutes` + * To keep track of cooldowns, user which claimed a channel will have a temporary role + + In Use Category + + * Contains all channels which are occupied by someone needing help + * Channel moves to dormant category after `constants.HelpChannels.idle_minutes` of being idle + * Command can prematurely mark a channel as dormant + * Channel claimant is allowed to use the command + * Allowed roles for the command are configurable with `constants.HelpChannels.cmd_whitelist` + * When a channel becomes dormant, an embed with `DORMANT_MSG` will be sent + + Dormant Category + + * Contains channels which aren't in use + * Channels are used to refill the Available category + + Help channels are named after the chemical elements in `bot/resources/elements.json`. + """ + + # This cache tracks which channels are claimed by which members. + # RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] + help_channel_claimants = RedisCache() + + # This cache maps a help channel to whether it has had any + # activity other than the original claimant. True being no other + # activity and False being other activity. + # RedisCache[discord.TextChannel.id, bool] + unanswered = RedisCache() + + # This dictionary maps a help channel to the time it was claimed + # RedisCache[discord.TextChannel.id, UtcPosixTimestamp] + claim_times = RedisCache() + + # This cache maps a help channel to original question message in same channel. + # RedisCache[discord.TextChannel.id, discord.Message.id] + question_messages = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + # Categories + self.available_category: discord.CategoryChannel = None + self.in_use_category: discord.CategoryChannel = None + self.dormant_category: discord.CategoryChannel = None + + # Queues + self.channel_queue: asyncio.Queue[discord.TextChannel] = None + self.name_queue: t.Deque[str] = None + + self.name_positions = self.get_names() + self.last_notification: t.Optional[datetime] = None + + # Asyncio stuff + self.queue_tasks: t.List[asyncio.Task] = [] + self.ready = asyncio.Event() + self.on_message_lock = asyncio.Lock() + self.init_task = self.bot.loop.create_task(self.init_cog()) + + def cog_unload(self) -> None: + """Cancel the init task and scheduled tasks when the cog unloads.""" + log.trace("Cog unload: cancelling the init_cog task") + self.init_task.cancel() + + log.trace("Cog unload: cancelling the channel queue tasks") + for task in self.queue_tasks: + task.cancel() + + self.scheduler.cancel_all() + + def create_channel_queue(self) -> asyncio.Queue: + """ + Return a queue of dormant channels to use for getting the next available channel. + + The channels are added to the queue in a random order. + """ + log.trace("Creating the channel queue.") + + channels = list(self.get_category_channels(self.dormant_category)) + random.shuffle(channels) + + log.trace("Populating the channel queue with channels.") + queue = asyncio.Queue() + for channel in channels: + queue.put_nowait(channel) + + return queue + + async def create_dormant(self) -> t.Optional[discord.TextChannel]: + """ + Create and return a new channel in the Dormant category. + + The new channel will sync its permission overwrites with the category. + + Return None if no more channel names are available. + """ + log.trace("Getting a name for a new dormant channel.") + + try: + name = self.name_queue.popleft() + except IndexError: + log.debug("No more names available for new dormant channels.") + return None + + log.debug(f"Creating a new dormant channel named {name}.") + return await self.dormant_category.create_text_channel(name, topic=HELP_CHANNEL_TOPIC) + + def create_name_queue(self) -> deque: + """Return a queue of element names to use for creating new channels.""" + log.trace("Creating the chemical element name queue.") + + used_names = self.get_used_names() + + log.trace("Determining the available names.") + available_names = (name for name in self.name_positions if name not in used_names) + + log.trace("Populating the name queue with names.") + return deque(available_names) + + async def dormant_check(self, ctx: commands.Context) -> bool: + """Return True if the user is the help channel claimant or passes the role check.""" + if await self.help_channel_claimants.get(ctx.channel.id) == ctx.author.id: + log.trace(f"{ctx.author} is the help channel claimant, passing the check for dormant.") + self.bot.stats.incr("help.dormant_invoke.claimant") + return True + + log.trace(f"{ctx.author} is not the help channel claimant, checking roles.") + role_check = with_role_check(ctx, *constants.HelpChannels.cmd_whitelist) + + if role_check: + self.bot.stats.incr("help.dormant_invoke.staff") + + return role_check + + @commands.command(name="close", aliases=["dormant", "solved"], enabled=False) + async def close_command(self, ctx: commands.Context) -> None: + """ + Make the current in-use help channel dormant. + + Make the channel dormant if the user passes the `dormant_check`, + delete the message that invoked this, + and reset the send permissions cooldown for the user who started the session. + """ + log.trace("close command invoked; checking if the channel is in-use.") + if ctx.channel.category == self.in_use_category: + if await self.dormant_check(ctx): + await self.remove_cooldown_role(ctx.author) + + # Ignore missing task when cooldown has passed but the channel still isn't dormant. + if ctx.author.id in self.scheduler: + self.scheduler.cancel(ctx.author.id) + + await self.move_to_dormant(ctx.channel, "command") + self.scheduler.cancel(ctx.channel.id) + else: + log.debug(f"{ctx.author} invoked command 'dormant' outside an in-use help channel") + + async def get_available_candidate(self) -> discord.TextChannel: + """ + Return a dormant channel to turn into an available channel. + + If no channel is available, wait indefinitely until one becomes available. + """ + log.trace("Getting an available channel candidate.") + + try: + channel = self.channel_queue.get_nowait() + except asyncio.QueueEmpty: + log.info("No candidate channels in the queue; creating a new channel.") + channel = await self.create_dormant() + + if not channel: + log.info("Couldn't create a candidate channel; waiting to get one from the queue.") + await self.notify() + channel = await self.wait_for_dormant_channel() + + return channel + + @staticmethod + def get_clean_channel_name(channel: discord.TextChannel) -> str: + """Return a clean channel name without status emojis prefix.""" + prefix = constants.HelpChannels.name_prefix + try: + # Try to remove the status prefix using the index of the channel prefix + name = channel.name[channel.name.index(prefix):] + log.trace(f"The clean name for `{channel}` is `{name}`") + except ValueError: + # If, for some reason, the channel name does not contain "help-" fall back gracefully + log.info(f"Can't get clean name because `{channel}` isn't prefixed by `{prefix}`.") + name = channel.name + + return name + + @staticmethod + def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: + """Check if a channel should be excluded from the help channel system.""" + return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS + + def get_category_channels(self, category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: + """Yield the text channels of the `category` in an unsorted manner.""" + log.trace(f"Getting text channels in the category '{category}' ({category.id}).") + + # This is faster than using category.channels because the latter sorts them. + for channel in self.bot.get_guild(constants.Guild.id).channels: + if channel.category_id == category.id and not self.is_excluded_channel(channel): + yield channel + + async def get_in_use_time(self, channel_id: int) -> t.Optional[timedelta]: + """Return the duration `channel_id` has been in use. Return None if it's not in use.""" + log.trace(f"Calculating in use time for channel {channel_id}.") + + claimed_timestamp = await self.claim_times.get(channel_id) + if claimed_timestamp: + claimed = datetime.utcfromtimestamp(claimed_timestamp) + return datetime.utcnow() - claimed + + @staticmethod + def get_names() -> t.List[str]: + """ + Return a truncated list of prefixed element names. + + The amount of names is configured with `HelpChannels.max_total_channels`. + The prefix is configured with `HelpChannels.name_prefix`. + """ + count = constants.HelpChannels.max_total_channels + prefix = constants.HelpChannels.name_prefix + + log.trace(f"Getting the first {count} element names from JSON.") + + with Path("bot/resources/elements.json").open(encoding="utf-8") as elements_file: + all_names = json.load(elements_file) + + if prefix: + return [prefix + name for name in all_names[:count]] + else: + return all_names[:count] + + def get_used_names(self) -> t.Set[str]: + """Return channel names which are already being used.""" + log.trace("Getting channel names which are already being used.") + + names = set() + for cat in (self.available_category, self.in_use_category, self.dormant_category): + for channel in self.get_category_channels(cat): + names.add(self.get_clean_channel_name(channel)) + + if len(names) > MAX_CHANNELS_PER_CATEGORY: + log.warning( + f"Too many help channels ({len(names)}) already exist! " + f"Discord only supports {MAX_CHANNELS_PER_CATEGORY} in a category." + ) + + log.trace(f"Got {len(names)} used names: {names}") + return names + + @classmethod + async def get_idle_time(cls, channel: discord.TextChannel) -> t.Optional[int]: + """ + Return the time elapsed, in seconds, since the last message sent in the `channel`. + + Return None if the channel has no messages. + """ + log.trace(f"Getting the idle time for #{channel} ({channel.id}).") + + msg = await cls.get_last_message(channel) + if not msg: + log.debug(f"No idle time available; #{channel} ({channel.id}) has no messages.") + return None + + idle_time = (datetime.utcnow() - msg.created_at).seconds + + log.trace(f"#{channel} ({channel.id}) has been idle for {idle_time} seconds.") + return idle_time + + @staticmethod + async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: + """Return the last message sent in the channel or None if no messages exist.""" + log.trace(f"Getting the last message in #{channel} ({channel.id}).") + + try: + return await channel.history(limit=1).next() # noqa: B305 + except discord.NoMoreItems: + log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") + return None + + async def init_available(self) -> None: + """Initialise the Available category with channels.""" + log.trace("Initialising the Available category with channels.") + + channels = list(self.get_category_channels(self.available_category)) + missing = constants.HelpChannels.max_available - len(channels) + + # If we've got less than `max_available` channel available, we should add some. + if missing > 0: + log.trace(f"Moving {missing} missing channels to the Available category.") + for _ in range(missing): + await self.move_to_available() + + # If for some reason we have more than `max_available` channels available, + # we should move the superfluous ones over to dormant. + elif missing < 0: + log.trace(f"Moving {abs(missing)} superfluous available channels over to the Dormant category.") + for channel in channels[:abs(missing)]: + await self.move_to_dormant(channel, "auto") + + async def init_categories(self) -> None: + """Get the help category objects. Remove the cog if retrieval fails.""" + log.trace("Getting the CategoryChannel objects for the help categories.") + + try: + self.available_category = await self.try_get_channel( + constants.Categories.help_available + ) + self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use) + self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant) + except discord.HTTPException: + log.exception("Failed to get a category; cog will be removed") + self.bot.remove_cog(self.qualified_name) + + async def init_cog(self) -> None: + """Initialise the help channel system.""" + log.trace("Waiting for the guild to be available before initialisation.") + await self.bot.wait_until_guild_available() + + log.trace("Initialising the cog.") + await self.init_categories() + await self.check_cooldowns() + + self.channel_queue = self.create_channel_queue() + self.name_queue = self.create_name_queue() + + log.trace("Moving or rescheduling in-use channels.") + for channel in self.get_category_channels(self.in_use_category): + await self.move_idle_channel(channel, has_task=False) + + # Prevent the command from being used until ready. + # The ready event wasn't used because channels could change categories between the time + # the command is invoked and the cog is ready (e.g. if move_idle_channel wasn't called yet). + # This may confuse users. So would potentially long delays for the cog to become ready. + self.close_command.enabled = True + + await self.init_available() + + log.info("Cog is ready!") + self.ready.set() + + self.report_stats() + + def report_stats(self) -> None: + """Report the channel count stats.""" + total_in_use = sum(1 for _ in self.get_category_channels(self.in_use_category)) + total_available = sum(1 for _ in self.get_category_channels(self.available_category)) + total_dormant = sum(1 for _ in self.get_category_channels(self.dormant_category)) + + self.bot.stats.gauge("help.total.in_use", total_in_use) + self.bot.stats.gauge("help.total.available", total_available) + self.bot.stats.gauge("help.total.dormant", total_dormant) + + @staticmethod + def is_claimant(member: discord.Member) -> bool: + """Return True if `member` has the 'Help Cooldown' role.""" + return any(constants.Roles.help_cooldown == role.id for role in member.roles) + + def match_bot_embed(self, message: t.Optional[discord.Message], description: str) -> bool: + """Return `True` if the bot's `message`'s embed description matches `description`.""" + if not message or not message.embeds: + return False + + bot_msg_desc = message.embeds[0].description + if bot_msg_desc is discord.Embed.Empty: + log.trace("Last message was a bot embed but it was empty.") + return False + return message.author == self.bot.user and bot_msg_desc.strip() == description.strip() + + @staticmethod + def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: + """Return True if `channel` is within a category with `category_id`.""" + actual_category = getattr(channel, "category", None) + return actual_category is not None and actual_category.id == category_id + + async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: + """ + Make the `channel` dormant if idle or schedule the move if still active. + + If `has_task` is True and rescheduling is required, the extant task to make the channel + dormant will first be cancelled. + """ + log.trace(f"Handling in-use channel #{channel} ({channel.id}).") + + if not await self.is_empty(channel): + idle_seconds = constants.HelpChannels.idle_minutes * 60 + else: + idle_seconds = constants.HelpChannels.deleted_idle_minutes * 60 + + time_elapsed = await self.get_idle_time(channel) + + if time_elapsed is None or time_elapsed >= idle_seconds: + log.info( + f"#{channel} ({channel.id}) is idle longer than {idle_seconds} seconds " + f"and will be made dormant." + ) + + await self.move_to_dormant(channel, "auto") + else: + # Cancel the existing task, if any. + if has_task: + self.scheduler.cancel(channel.id) + + delay = idle_seconds - time_elapsed + log.info( + f"#{channel} ({channel.id}) is still active; " + f"scheduling it to be moved after {delay} seconds." + ) + + self.scheduler.schedule_later(delay, channel.id, self.move_idle_channel(channel)) + + async def move_to_bottom_position(self, channel: discord.TextChannel, category_id: int, **options) -> None: + """ + Move the `channel` to the bottom position of `category` and edit channel attributes. + + To ensure "stable sorting", we use the `bulk_channel_update` endpoint and provide the current + positions of the other channels in the category as-is. This should make sure that the channel + really ends up at the bottom of the category. + + If `options` are provided, the channel will be edited after the move is completed. This is the + same order of operations that `discord.TextChannel.edit` uses. For information on available + options, see the documention on `discord.TextChannel.edit`. While possible, position-related + options should be avoided, as it may interfere with the category move we perform. + """ + # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. + category = await self.try_get_channel(category_id) + + payload = [{"id": c.id, "position": c.position} for c in category.channels] + + # Calculate the bottom position based on the current highest position in the category. If the + # category is currently empty, we simply use the current position of the channel to avoid making + # unnecessary changes to positions in the guild. + bottom_position = payload[-1]["position"] + 1 if payload else channel.position + + payload.append( + { + "id": channel.id, + "position": bottom_position, + "parent_id": category.id, + "lock_permissions": True, + } + ) + + # We use d.py's method to ensure our request is processed by d.py's rate limit manager + await self.bot.http.bulk_channel_update(category.guild.id, payload) + + # Now that the channel is moved, we can edit the other attributes + if options: + await channel.edit(**options) + + async def move_to_available(self) -> None: + """Make a channel available.""" + log.trace("Making a channel available.") + + channel = await self.get_available_candidate() + log.info(f"Making #{channel} ({channel.id}) available.") + + await self.send_available_message(channel) + + log.trace(f"Moving #{channel} ({channel.id}) to the Available category.") + + await self.move_to_bottom_position( + channel=channel, + category_id=constants.Categories.help_available, + ) + + self.report_stats() + + async def move_to_dormant(self, channel: discord.TextChannel, caller: str) -> None: + """ + Make the `channel` dormant. + + A caller argument is provided for metrics. + """ + log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") + + await self.help_channel_claimants.delete(channel.id) + await self.move_to_bottom_position( + channel=channel, + category_id=constants.Categories.help_dormant, + ) + + self.bot.stats.incr(f"help.dormant_calls.{caller}") + + in_use_time = await self.get_in_use_time(channel.id) + if in_use_time: + self.bot.stats.timing("help.in_use_time", in_use_time) + + unanswered = await self.unanswered.get(channel.id) + if unanswered: + self.bot.stats.incr("help.sessions.unanswered") + elif unanswered is not None: + self.bot.stats.incr("help.sessions.answered") + + log.trace(f"Position of #{channel} ({channel.id}) is actually {channel.position}.") + log.trace(f"Sending dormant message for #{channel} ({channel.id}).") + embed = discord.Embed(description=DORMANT_MSG) + await channel.send(embed=embed) + + await self.unpin(channel) + + log.trace(f"Pushing #{channel} ({channel.id}) into the channel queue.") + self.channel_queue.put_nowait(channel) + self.report_stats() + + async def move_to_in_use(self, channel: discord.TextChannel) -> None: + """Make a channel in-use and schedule it to be made dormant.""" + log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") + + await self.move_to_bottom_position( + channel=channel, + category_id=constants.Categories.help_in_use, + ) + + timeout = constants.HelpChannels.idle_minutes * 60 + + log.trace(f"Scheduling #{channel} ({channel.id}) to become dormant in {timeout} sec.") + self.scheduler.schedule_later(timeout, channel.id, self.move_idle_channel(channel)) + self.report_stats() + + async def notify(self) -> None: + """ + Send a message notifying about a lack of available help channels. + + Configuration: + + * `HelpChannels.notify` - toggle notifications + * `HelpChannels.notify_channel` - destination channel for notifications + * `HelpChannels.notify_minutes` - minimum interval between notifications + * `HelpChannels.notify_roles` - roles mentioned in notifications + """ + if not constants.HelpChannels.notify: + return + + log.trace("Notifying about lack of channels.") + + if self.last_notification: + elapsed = (datetime.utcnow() - self.last_notification).seconds + minimum_interval = constants.HelpChannels.notify_minutes * 60 + should_send = elapsed >= minimum_interval + else: + should_send = True + + if not should_send: + log.trace("Notification not sent because it's too recent since the previous one.") + return + + try: + log.trace("Sending notification message.") + + channel = self.bot.get_channel(constants.HelpChannels.notify_channel) + mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_roles) + allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_roles] + + message = await channel.send( + f"{mentions} A new available help channel is needed but there " + f"are no more dormant ones. Consider freeing up some in-use channels manually by " + f"using the `{constants.Bot.prefix}dormant` command within the channels.", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + + self.bot.stats.incr("help.out_of_channel_alerts") + + self.last_notification = message.created_at + except Exception: + # Handle it here cause this feature isn't critical for the functionality of the system. + log.exception("Failed to send notification about lack of dormant channels!") + + async def check_for_answer(self, message: discord.Message) -> None: + """Checks for whether new content in a help channel comes from non-claimants.""" + channel = message.channel + + # Confirm the channel is an in use help channel + if self.is_in_category(channel, constants.Categories.help_in_use): + log.trace(f"Checking if #{channel} ({channel.id}) has been answered.") + + # Check if there is an entry in unanswered + if await self.unanswered.contains(channel.id): + claimant_id = await self.help_channel_claimants.get(channel.id) + if not claimant_id: + # The mapping for this channel doesn't exist, we can't do anything. + return + + # Check the message did not come from the claimant + if claimant_id != message.author.id: + # Mark the channel as answered + await self.unanswered.set(channel.id, False) + + @commands.Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Move an available channel to the In Use category and replace it with a dormant one.""" + if message.author.bot: + return # Ignore messages sent by bots. + + channel = message.channel + + await self.check_for_answer(message) + + if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel): + return # Ignore messages outside the Available category or in excluded channels. + + log.trace("Waiting for the cog to be ready before processing messages.") + await self.ready.wait() + + log.trace("Acquiring lock to prevent a channel from being processed twice...") + async with self.on_message_lock: + log.trace(f"on_message lock acquired for {message.id}.") + + if not self.is_in_category(channel, constants.Categories.help_available): + log.debug( + f"Message {message.id} will not make #{channel} ({channel.id}) in-use " + f"because another message in the channel already triggered that." + ) + return + + log.info(f"Channel #{channel} was claimed by `{message.author.id}`.") + await self.move_to_in_use(channel) + await self.revoke_send_permissions(message.author) + + await self.pin(message) + + # Add user with channel for dormant check. + await self.help_channel_claimants.set(channel.id, message.author.id) + + self.bot.stats.incr("help.claimed") + + # Must use a timezone-aware datetime to ensure a correct POSIX timestamp. + timestamp = datetime.now(timezone.utc).timestamp() + await self.claim_times.set(channel.id, timestamp) + + await self.unanswered.set(channel.id, True) + + log.trace(f"Releasing on_message lock for {message.id}.") + + # Move a dormant channel to the Available category to fill in the gap. + # This is done last and outside the lock because it may wait indefinitely for a channel to + # be put in the queue. + await self.move_to_available() + + @commands.Cog.listener() + async def on_message_delete(self, msg: discord.Message) -> None: + """ + Reschedule an in-use channel to become dormant sooner if the channel is empty. + + The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`. + """ + if not self.is_in_category(msg.channel, constants.Categories.help_in_use): + return + + if not await self.is_empty(msg.channel): + return + + log.info(f"Claimant of #{msg.channel} ({msg.author}) deleted message, channel is empty now. Rescheduling task.") + + # Cancel existing dormant task before scheduling new. + self.scheduler.cancel(msg.channel.id) + + delay = constants.HelpChannels.deleted_idle_minutes * 60 + self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) + + async def is_empty(self, channel: discord.TextChannel) -> bool: + """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" + log.trace(f"Checking if #{channel} ({channel.id}) is empty.") + + # A limit of 100 results in a single API call. + # If AVAILABLE_MSG isn't found within 100 messages, then assume the channel is not empty. + # Not gonna do an extensive search for it cause it's too expensive. + async for msg in channel.history(limit=100): + if not msg.author.bot: + log.trace(f"#{channel} ({channel.id}) has a non-bot message.") + return False + + if self.match_bot_embed(msg, AVAILABLE_MSG): + log.trace(f"#{channel} ({channel.id}) has the available message embed.") + return True + + return False + + async def check_cooldowns(self) -> None: + """Remove expired cooldowns and re-schedule active ones.""" + log.trace("Checking all cooldowns to remove or re-schedule them.") + guild = self.bot.get_guild(constants.Guild.id) + cooldown = constants.HelpChannels.claim_minutes * 60 + + for channel_id, member_id in await self.help_channel_claimants.items(): + member = guild.get_member(member_id) + if not member: + continue # Member probably left the guild. + + in_use_time = await self.get_in_use_time(channel_id) + + if not in_use_time or in_use_time.seconds > cooldown: + # Remove the role if no claim time could be retrieved or if the cooldown expired. + # Since the channel is in the claimants cache, it is definitely strange for a time + # to not exist. However, it isn't a reason to keep the user stuck with a cooldown. + await self.remove_cooldown_role(member) + else: + # The member is still on a cooldown; re-schedule it for the remaining time. + delay = cooldown - in_use_time.seconds + self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) + + async def add_cooldown_role(self, member: discord.Member) -> None: + """Add the help cooldown role to `member`.""" + log.trace(f"Adding cooldown role for {member} ({member.id}).") + await self._change_cooldown_role(member, member.add_roles) + + async def remove_cooldown_role(self, member: discord.Member) -> None: + """Remove the help cooldown role from `member`.""" + log.trace(f"Removing cooldown role for {member} ({member.id}).") + await self._change_cooldown_role(member, member.remove_roles) + + async def _change_cooldown_role(self, member: discord.Member, coro_func: CoroutineFunc) -> None: + """ + Change `member`'s cooldown role via awaiting `coro_func` and handle errors. + + `coro_func` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. + """ + guild = self.bot.get_guild(constants.Guild.id) + role = guild.get_role(constants.Roles.help_cooldown) + if role is None: + log.warning(f"Help cooldown role ({constants.Roles.help_cooldown}) could not be found!") + return + + try: + await coro_func(role) + except discord.NotFound: + log.debug(f"Failed to change role for {member} ({member.id}): member not found") + except discord.Forbidden: + log.debug( + f"Forbidden to change role for {member} ({member.id}); " + f"possibly due to role hierarchy" + ) + except discord.HTTPException as e: + log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") + + async def revoke_send_permissions(self, member: discord.Member) -> None: + """ + Disallow `member` to send messages in the Available category for a certain time. + + The time until permissions are reinstated can be configured with + `HelpChannels.claim_minutes`. + """ + log.trace( + f"Revoking {member}'s ({member.id}) send message permissions in the Available category." + ) + + await self.add_cooldown_role(member) + + # Cancel the existing task, if any. + # Would mean the user somehow bypassed the lack of permissions (e.g. user is guild owner). + if member.id in self.scheduler: + self.scheduler.cancel(member.id) + + delay = constants.HelpChannels.claim_minutes * 60 + self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) + + async def send_available_message(self, channel: discord.TextChannel) -> None: + """Send the available message by editing a dormant message or sending a new message.""" + channel_info = f"#{channel} ({channel.id})" + log.trace(f"Sending available message in {channel_info}.") + + embed = discord.Embed(description=AVAILABLE_MSG) + + msg = await self.get_last_message(channel) + if self.match_bot_embed(msg, DORMANT_MSG): + log.trace(f"Found dormant message {msg.id} in {channel_info}; editing it.") + await msg.edit(embed=embed) + else: + log.trace(f"Dormant message not found in {channel_info}; sending a new message.") + await channel.send(embed=embed) + + async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel: + """Attempt to get or fetch a channel and return it.""" + log.trace(f"Getting the channel {channel_id}.") + + channel = self.bot.get_channel(channel_id) + if not channel: + log.debug(f"Channel {channel_id} is not in cache; fetching from API.") + channel = await self.bot.fetch_channel(channel_id) + + log.trace(f"Channel #{channel} ({channel_id}) retrieved.") + return channel + + async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: + """ + Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. + + Return True if successful and False otherwise. + """ + channel_str = f"#{channel} ({channel.id})" + if pin: + func = self.bot.http.pin_message + verb = "pin" + else: + func = self.bot.http.unpin_message + verb = "unpin" + + try: + await func(channel.id, msg_id) + except discord.HTTPException as e: + if e.code == 10008: + log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") + else: + log.exception( + f"Error {verb}ning message {msg_id} in {channel_str}: {e.status} ({e.code})" + ) + return False + else: + log.trace(f"{verb.capitalize()}ned message {msg_id} in {channel_str}.") + return True + + async def pin(self, message: discord.Message) -> None: + """Pin an initial question `message` and store it in a cache.""" + if await self.pin_wrapper(message.id, message.channel, pin=True): + await self.question_messages.set(message.channel.id, message.id) + + async def unpin(self, channel: discord.TextChannel) -> None: + """Unpin the initial question message sent in `channel`.""" + msg_id = await self.question_messages.pop(channel.id) + if msg_id is None: + log.debug(f"#{channel} ({channel.id}) doesn't have a message pinned.") + else: + await self.pin_wrapper(msg_id, channel, pin=False) + + async def wait_for_dormant_channel(self) -> discord.TextChannel: + """Wait for a dormant channel to become available in the queue and return it.""" + log.trace("Waiting for a dormant channel.") + + task = asyncio.create_task(self.channel_queue.get()) + self.queue_tasks.append(task) + channel = await task + + log.trace(f"Channel #{channel} ({channel.id}) finally retrieved from the queue.") + self.queue_tasks.remove(task) + + return channel + + +def validate_config() -> None: + """Raise a ValueError if the cog's config is invalid.""" + log.trace("Validating config.") + total = constants.HelpChannels.max_total_channels + available = constants.HelpChannels.max_available + + if total == 0 or available == 0: + raise ValueError("max_total_channels and max_available and must be greater than 0.") + + if total < available: + raise ValueError( + f"max_total_channels ({total}) must be greater than or equal to max_available " + f"({available})." + ) + + if total > MAX_CHANNELS_PER_CATEGORY: + raise ValueError( + f"max_total_channels ({total}) must be less than or equal to " + f"{MAX_CHANNELS_PER_CATEGORY} due to Discord's limit on channels per category." + ) + + +def setup(bot: Bot) -> None: + """Load the HelpChannels cog.""" + try: + validate_config() + except ValueError as e: + log.error(f"HelpChannels cog will not be loaded due to misconfiguration: {e}") + else: + bot.add_cog(HelpChannels(bot)) diff --git a/bot/exts/info/__init__.py b/bot/exts/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py new file mode 100644 index 000000000..204cffb37 --- /dev/null +++ b/bot/exts/info/doc.py @@ -0,0 +1,511 @@ +import asyncio +import functools +import logging +import re +import textwrap +from collections import OrderedDict +from contextlib import suppress +from types import SimpleNamespace +from typing import Any, Callable, Optional, Tuple + +import discord +from bs4 import BeautifulSoup +from bs4.element import PageElement, Tag +from discord.errors import NotFound +from discord.ext import commands +from markdownify import MarkdownConverter +from requests import ConnectTimeout, ConnectionError, HTTPError +from sphinx.ext import intersphinx +from urllib3.exceptions import ProtocolError + +from bot.bot import Bot +from bot.constants import MODERATION_ROLES, RedirectOutput +from bot.converters import ValidPythonIdentifier, ValidURL +from bot.decorators import with_role +from bot.pagination import LinePaginator + + +log = logging.getLogger(__name__) +logging.getLogger('urllib3').setLevel(logging.WARNING) + +# Since Intersphinx is intended to be used with Sphinx, +# we need to mock its configuration. +SPHINX_MOCK_APP = SimpleNamespace( + config=SimpleNamespace( + intersphinx_timeout=3, + tls_verify=True, + user_agent="python3:python-discord/bot:1.0.0" + ) +) + +NO_OVERRIDE_GROUPS = ( + "2to3fixer", + "token", + "label", + "pdbcommand", + "term", +) +NO_OVERRIDE_PACKAGES = ( + "python", +) + +SEARCH_END_TAG_ATTRS = ( + "data", + "function", + "class", + "exception", + "seealso", + "section", + "rubric", + "sphinxsidebar", +) +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") +WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") + +FAILED_REQUEST_RETRY_AMOUNT = 3 +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay + + +def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + # Assign the cache to the function itself so we can clear it from outside. + async_cache.cache = OrderedDict() + + def decorator(function: Callable) -> Callable: + """Define the async_cache decorator.""" + @functools.wraps(function) + async def wrapper(*args) -> Any: + """Decorator wrapper for the caching logic.""" + key = ':'.join(args[arg_offset:]) + + value = async_cache.cache.get(key) + if value is None: + if len(async_cache.cache) > max_size: + async_cache.cache.popitem(last=False) + + async_cache.cache[key] = await function(*args) + return async_cache.cache[key] + return wrapper + return decorator + + +class DocMarkdownConverter(MarkdownConverter): + """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" + + def convert_code(self, el: PageElement, text: str) -> str: + """Undo `markdownify`s underscore escaping.""" + return f"`{text}`".replace('\\', '') + + def convert_pre(self, el: PageElement, text: str) -> str: + """Wrap any codeblocks in `py` for syntax highlighting.""" + code = ''.join(el.strings) + return f"```py\n{code}```" + + +def markdownify(html: str) -> DocMarkdownConverter: + """Create a DocMarkdownConverter object from the input html.""" + return DocMarkdownConverter(bullets='•').convert(html) + + +class InventoryURL(commands.Converter): + """ + Represents an Intersphinx inventory URL. + + This converter checks whether intersphinx accepts the given inventory URL, and raises + `BadArgument` if that is not the case. + + Otherwise, it simply passes through the given URL. + """ + + @staticmethod + async def convert(ctx: commands.Context, url: str) -> str: + """Convert url to Intersphinx inventory URL.""" + try: + intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) + except AttributeError: + raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") + except ConnectionError: + if url.startswith('https'): + raise commands.BadArgument( + f"Cannot establish a connection to `{url}`. Does it support HTTPS?" + ) + raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") + except ValueError: + raise commands.BadArgument( + f"Failed to read Intersphinx inventory from URL `{url}`. " + "Are you sure that it's a valid inventory file?" + ) + return url + + +class Doc(commands.Cog): + """A set of commands for querying & displaying documentation.""" + + def __init__(self, bot: Bot): + self.base_urls = {} + self.bot = bot + self.inventories = {} + self.renamed_symbols = set() + + self.bot.loop.create_task(self.init_refresh_inventory()) + + async def init_refresh_inventory(self) -> None: + """Refresh documentation inventory on cog initialization.""" + await self.bot.wait_until_guild_available() + await self.refresh_inventory() + + async def update_single( + self, package_name: str, base_url: str, inventory_url: str + ) -> None: + """ + Rebuild the inventory for a single package. + + Where: + * `package_name` is the package name to use, appears in the log + * `base_url` is the root documentation URL for the specified package, used to build + absolute paths that link to specific symbols + * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running + `intersphinx.fetch_inventory` in an executor on the bot's event loop + """ + self.base_urls[package_name] = base_url + + package = await self._fetch_inventory(inventory_url) + if not package: + return None + + for group, value in package.items(): + for symbol, (package_name, _version, relative_doc_url, _) in value.items(): + absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + group_name = group.split(":")[1] + symbol_base_url = self.inventories[symbol].split("/", 3)[2] + if ( + group_name in NO_OVERRIDE_GROUPS + or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) + ): + + symbol = f"{group_name}.{symbol}" + # If renamed `symbol` already exists, add library name in front to differentiate between them. + if symbol in self.renamed_symbols: + # Split `package_name` because of packages like Pillow that have spaces in them. + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + + self.inventories[symbol] = absolute_doc_url + + log.trace(f"Fetched inventory for {package_name}.") + + async def refresh_inventory(self) -> None: + """Refresh internal documentation inventory.""" + log.debug("Refreshing documentation inventory...") + + # Clear the old base URLS and inventories to ensure + # that we start from a fresh local dataset. + # Also, reset the cache used for fetching documentation. + self.base_urls.clear() + self.inventories.clear() + self.renamed_symbols.clear() + async_cache.cache = OrderedDict() + + # Run all coroutines concurrently - since each of them performs a HTTP + # request, this speeds up fetching the inventory data heavily. + coros = [ + self.update_single( + package["package"], package["base_url"], package["inventory_url"] + ) for package in await self.bot.api_client.get('bot/documentation-links') + ] + await asyncio.gather(*coros) + + async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: + """ + Given a Python symbol, return its signature and description. + + The first tuple element is the signature of the given symbol as a markup-free string, and + the second tuple element is the description of the given symbol with HTML markup included. + + If the given symbol is a module, returns a tuple `(None, str)` + else if the symbol could not be found, returns `None`. + """ + url = self.inventories.get(symbol) + if url is None: + return None + + async with self.bot.http_session.get(url) as response: + html = await response.text(encoding='utf-8') + + # Find the signature header and parse the relevant parts. + symbol_id = url.split('#')[-1] + soup = BeautifulSoup(html, 'lxml') + symbol_heading = soup.find(id=symbol_id) + search_html = str(soup) + + if symbol_heading is None: + return None + + if symbol_id == f"module-{symbol}": + # Get page content from the module headerlink to the + # first tag that has its class in `SEARCH_END_TAG_ATTRS` + start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) + if start_tag is None: + return [], "" + + end_tag = start_tag.find_next(self._match_end_tag) + if end_tag is None: + return [], "" + + description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) + description_end_index = search_html.find(str(end_tag)) + description = search_html[description_start_index:description_end_index] + signatures = None + + else: + signatures = [] + description = str(symbol_heading.find_next_sibling("dd")) + description_pos = search_html.find(description) + # Get text of up to 3 signatures, remove unwanted symbols + for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): + signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + if signature and search_html.find(str(element)) < description_pos: + signatures.append(signature) + + return signatures, description.replace('¶', '') + + @async_cache(arg_offset=1) + async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: + """ + Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. + + If the symbol is known, an Embed with documentation about it is returned. + """ + scraped_html = await self.get_symbol_html(symbol) + if scraped_html is None: + return None + + signatures = scraped_html[0] + permalink = self.inventories[symbol] + description = markdownify(scraped_html[1]) + + # Truncate the description of the embed to the last occurrence + # of a double newline (interpreted as a paragraph) before index 1000. + if len(description) > 1000: + shortened = description[:1000] + description_cutoff = shortened.rfind('\n\n', 100) + if description_cutoff == -1: + # Search the shortened version for cutoff points in decreasing desirability, + # cutoff at 1000 if none are found. + for string in (". ", ", ", ",", " "): + description_cutoff = shortened.rfind(string) + if description_cutoff != -1: + break + else: + description_cutoff = 1000 + description = description[:description_cutoff] + + # If there is an incomplete code block, cut it out + if description.count("```") % 2: + codeblock_start = description.rfind('```py') + description = description[:codeblock_start].rstrip() + description += f"... [read more]({permalink})" + + description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) + if signatures is None: + # If symbol is a module, don't show signature. + embed_description = description + + elif not signatures: + # It's some "meta-page", for example: + # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views + embed_description = "This appears to be a generic page not tied to a specific symbol." + + else: + embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) + embed_description += f"\n{description}" + + embed = discord.Embed( + title=f'`{symbol}`', + url=permalink, + description=embed_description + ) + # Show all symbols with the same name that were renamed in the footer. + embed.set_footer( + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) + ) + return embed + + @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) + async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """Lookup documentation for Python symbols.""" + await ctx.invoke(self.get_command, symbol) + + @docs_group.command(name='get', aliases=('g',)) + async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: + """ + Return a documentation embed for a given symbol. + + If no symbol is given, return a list of all available inventories. + + Examples: + !docs + !docs aiohttp + !docs aiohttp.ClientSession + !docs get aiohttp.ClientSession + """ + if symbol is None: + inventory_embed = discord.Embed( + title=f"All inventories (`{len(self.base_urls)}` total)", + colour=discord.Colour.blue() + ) + + lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) + if self.base_urls: + await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) + + else: + inventory_embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=inventory_embed) + + else: + # Fetching documentation for a symbol (at least for the first time, since + # caching is used) takes quite some time, so let's send typing to indicate + # that we got the command, but are still working on it. + async with ctx.typing(): + doc_embed = await self.get_symbol_embed(symbol) + + if doc_embed is None: + error_embed = discord.Embed( + description=f"Sorry, I could not find any documentation for `{symbol}`.", + colour=discord.Colour.red() + ) + error_message = await ctx.send(embed=error_embed) + with suppress(NotFound): + await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) + await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) + else: + await ctx.send(embed=doc_embed) + + @docs_group.command(name='set', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def set_command( + self, ctx: commands.Context, package_name: ValidPythonIdentifier, + base_url: ValidURL, inventory_url: InventoryURL + ) -> None: + """ + Adds a new documentation metadata object to the site's database. + + The database will update the object, should an existing item with the specified `package_name` already exist. + + Example: + !docs set \ + python \ + https://docs.python.org/3/ \ + https://docs.python.org/3/objects.inv + """ + body = { + 'package': package_name, + 'base_url': base_url, + 'inventory_url': inventory_url + } + await self.bot.api_client.post('bot/documentation-links', json=body) + + log.info( + f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" + f"Package name: {package_name}\n" + f"Base url: {base_url}\n" + f"Inventory URL: {inventory_url}" + ) + + # Rebuilding the inventory can take some time, so lets send out a + # typing event to show that the Bot is still working. + async with ctx.typing(): + await self.refresh_inventory() + await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") + + @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: + """ + Removes the specified package from the database. + + Examples: + !docs delete aiohttp + """ + await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') + + async with ctx.typing(): + # Rebuild the inventory to ensure that everything + # that was from this package is properly deleted. + await self.refresh_inventory() + await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + + @docs_group.command(name="refresh", aliases=("rfsh", "r")) + @with_role(*MODERATION_ROLES) + async def refresh_command(self, ctx: commands.Context) -> None: + """Refresh inventories and send differences to channel.""" + old_inventories = set(self.base_urls) + with ctx.typing(): + await self.refresh_inventory() + # Get differences of added and removed inventories + added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) + if added: + added = f"+ {added}" + + removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) + if removed: + removed = f"- {removed}" + + embed = discord.Embed( + title="Inventories refreshed", + description=f"```diff\n{added}\n{removed}```" if added or removed else "" + ) + await ctx.send(embed=embed) + + async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: + """Get and return inventory from `inventory_url`. If fetching fails, return None.""" + fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) + for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): + try: + package = await self.bot.loop.run_in_executor(None, fetch_func) + except ConnectTimeout: + log.error( + f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except ProtocolError: + log.error( + f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except HTTPError as e: + log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") + return None + except ConnectionError: + log.error(f"Couldn't establish connection to inventory {inventory_url}.") + return None + else: + return package + log.error(f"Fetching of inventory {inventory_url} failed.") + return None + + @staticmethod + def _match_end_tag(tag: Tag) -> bool: + """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" + for attr in SEARCH_END_TAG_ATTRS: + if attr in tag.get("class", ()): + return True + + return tag.name == "table" + + +def setup(bot: Bot) -> None: + """Load the Doc cog.""" + bot.add_cog(Doc(bot)) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py new file mode 100644 index 000000000..3d1d6fd10 --- /dev/null +++ b/bot/exts/info/help.py @@ -0,0 +1,375 @@ +import itertools +import logging +from asyncio import TimeoutError +from collections import namedtuple +from contextlib import suppress +from typing import List, Union + +from discord import Colour, Embed, Member, Message, NotFound, Reaction, User +from discord.ext.commands import Bot, Cog, Command, Context, Group, HelpCommand +from fuzzywuzzy import fuzz, process +from fuzzywuzzy.utils import full_process + +from bot import constants +from bot.constants import Channels, Emojis, STAFF_ROLES +from bot.decorators import redirect_output +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +COMMANDS_PER_PAGE = 8 +DELETE_EMOJI = Emojis.trashcan +PREFIX = constants.Bot.prefix + +Category = namedtuple("Category", ["name", "description", "cogs"]) + + +async def help_cleanup(bot: Bot, author: Member, message: Message) -> None: + """ + Runs the cleanup for the help command. + + Adds the :trashcan: reaction that, when clicked, will delete the help message. + After a 300 second timeout, the reaction will be removed. + """ + def check(reaction: Reaction, user: User) -> bool: + """Checks the reaction is :trashcan:, the author is original author and messages are the same.""" + return str(reaction) == DELETE_EMOJI and user.id == author.id and reaction.message.id == message.id + + await message.add_reaction(DELETE_EMOJI) + + with suppress(NotFound): + try: + await bot.wait_for("reaction_add", check=check, timeout=300) + await message.delete() + except TimeoutError: + await message.remove_reaction(DELETE_EMOJI, bot.user) + + +class HelpQueryNotFound(ValueError): + """ + Raised when a HelpSession Query doesn't match a command or cog. + + Contains the custom attribute of ``possible_matches``. + + Instances of this object contain a dictionary of any command(s) that were close to matching the + query, where keys are the possible matched command names and values are the likeness match scores. + """ + + def __init__(self, arg: str, possible_matches: dict = None): + super().__init__(arg) + self.possible_matches = possible_matches + + +class CustomHelpCommand(HelpCommand): + """ + An interactive instance for the bot help command. + + Cogs can be grouped into custom categories. All cogs with the same category will be displayed + under a single category name in the help output. Custom categories are defined inside the cogs + as a class attribute named `category`. A description can also be specified with the attribute + `category_description`. If a description is not found in at least one cog, the default will be + the regular description (class docstring) of the first cog found in the category. + """ + + def __init__(self): + super().__init__(command_attrs={"help": "Shows help for bot commands"}) + + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) + async def command_callback(self, ctx: Context, *, command: str = None) -> None: + """Attempts to match the provided query with a valid command or cog.""" + # the only reason we need to tamper with this is because d.py does not support "categories", + # so we need to deal with them ourselves. + + bot = ctx.bot + + if command is None: + # quick and easy, send bot help if command is none + mapping = self.get_bot_mapping() + await self.send_bot_help(mapping) + return + + cog_matches = [] + description = None + for cog in bot.cogs.values(): + if hasattr(cog, "category") and cog.category == command: + cog_matches.append(cog) + if hasattr(cog, "category_description"): + description = cog.category_description + + if cog_matches: + category = Category(name=command, description=description, cogs=cog_matches) + await self.send_category_help(category) + return + + # it's either a cog, group, command or subcommand; let the parent class deal with it + await super().command_callback(ctx, command=command) + + async def get_all_help_choices(self) -> set: + """ + Get all the possible options for getting help in the bot. + + This will only display commands the author has permission to run. + + These include: + - Category names + - Cog names + - Group command names (and aliases) + - Command names (and aliases) + - Subcommand names (with parent group and aliases for subcommand, but not including aliases for group) + + Options and choices are case sensitive. + """ + # first get all commands including subcommands and full command name aliases + choices = set() + for command in await self.filter_commands(self.context.bot.walk_commands()): + # the the command or group name + choices.add(str(command)) + + if isinstance(command, Command): + # all aliases if it's just a command + choices.update(command.aliases) + else: + # otherwise we need to add the parent name in + choices.update(f"{command.full_parent_name} {alias}" for alias in command.aliases) + + # all cog names + choices.update(self.context.bot.cogs) + + # all category names + choices.update(cog.category for cog in self.context.bot.cogs.values() if hasattr(cog, "category")) + return choices + + async def command_not_found(self, string: str) -> "HelpQueryNotFound": + """ + Handles when a query does not match a valid command, group, cog or category. + + Will return an instance of the `HelpQueryNotFound` exception with the error message and possible matches. + """ + choices = await self.get_all_help_choices() + + # Run fuzzywuzzy's processor beforehand, and avoid matching if processed string is empty + # This avoids fuzzywuzzy from raising a warning on inputs with only non-alphanumeric characters + if (processed := full_process(string)): + result = process.extractBests(processed, choices, scorer=fuzz.ratio, score_cutoff=60, processor=None) + else: + result = [] + + return HelpQueryNotFound(f'Query "{string}" not found.', dict(result)) + + async def subcommand_not_found(self, command: Command, string: str) -> "HelpQueryNotFound": + """ + Redirects the error to `command_not_found`. + + `command_not_found` deals with searching and getting best choices for both commands and subcommands. + """ + return await self.command_not_found(f"{command.qualified_name} {string}") + + async def send_error_message(self, error: HelpQueryNotFound) -> None: + """Send the error message to the channel.""" + embed = Embed(colour=Colour.red(), title=str(error)) + + if getattr(error, "possible_matches", None): + matches = "\n".join(f"`{match}`" for match in error.possible_matches) + embed.description = f"**Did you mean:**\n{matches}" + + await self.context.send(embed=embed) + + async def command_formatting(self, command: Command) -> Embed: + """ + Takes a command and turns it into an embed. + + It will add an author, command signature + help, aliases and a note if the user can't run the command. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + parent = command.full_parent_name + + name = str(command) if not parent else f"{parent} {command.name}" + command_details = f"**```{PREFIX}{name} {command.signature}```**\n" + + # show command aliases + aliases = ", ".join(f"`{alias}`" if not parent else f"`{parent} {alias}`" for alias in command.aliases) + if aliases: + command_details += f"**Can also use:** {aliases}\n\n" + + # check if the user is allowed to run this command + if not await command.can_run(self.context): + command_details += "***You cannot run this command.***\n\n" + + command_details += f"*{command.help or 'No details provided.'}*\n" + embed.description = command_details + + return embed + + async def send_command_help(self, command: Command) -> None: + """Send help for a single command.""" + embed = await self.command_formatting(command) + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: + """ + Formats the prefix, command name and signature, and short doc for an iterable of commands. + + return_as_list is helpful for passing these command details into the paginator as a list of command details. + """ + details = [] + for command in commands_: + signature = f" {command.signature}" if command.signature else "" + details.append( + f"\n**`{PREFIX}{command.qualified_name}{signature}`**\n*{command.short_doc or 'No details provided'}*" + ) + if return_as_list: + return details + else: + return "".join(details) + + async def send_group_help(self, group: Group) -> None: + """Sends help for a group command.""" + subcommands = group.commands + + if len(subcommands) == 0: + # no subcommands, just treat it like a regular command + await self.send_command_help(group) + return + + # remove commands that the user can't run and are hidden, and sort by name + commands_ = await self.filter_commands(subcommands, sort=True) + + embed = await self.command_formatting(group) + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n**Subcommands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + async def send_cog_help(self, cog: Cog) -> None: + """Send help for a cog.""" + # sort commands by name, and remove any the user cant run or are hidden. + commands_ = await self.filter_commands(cog.get_commands(), sort=True) + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + embed.description = f"**{cog.qualified_name}**\n*{cog.description}*" + + command_details = self.get_commands_brief_details(commands_) + if command_details: + embed.description += f"\n\n**Commands:**\n{command_details}" + + message = await self.context.send(embed=embed) + await help_cleanup(self.context.bot, self.context.author, message) + + @staticmethod + def _category_key(command: Command) -> str: + """ + Returns a cog name of a given command for use as a key for `sorted` and `groupby`. + + A zero width space is used as a prefix for results with no cogs to force them last in ordering. + """ + if command.cog: + with suppress(AttributeError): + if command.cog.category: + return f"**{command.cog.category}**" + return f"**{command.cog_name}**" + else: + return "**\u200bNo Category:**" + + async def send_category_help(self, category: Category) -> None: + """ + Sends help for a bot category. + + This sends a brief help for all commands in all cogs registered to the category. + """ + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + all_commands = [] + for cog in category.cogs: + all_commands.extend(cog.get_commands()) + + filtered_commands = await self.filter_commands(all_commands, sort=True) + + command_detail_lines = self.get_commands_brief_details(filtered_commands, return_as_list=True) + description = f"**{category.name}**\n*{category.description}*" + + if command_detail_lines: + description += "\n\n**Commands:**" + + await LinePaginator.paginate( + command_detail_lines, + self.context, + embed, + prefix=description, + max_lines=COMMANDS_PER_PAGE, + max_size=2000, + ) + + async def send_bot_help(self, mapping: dict) -> None: + """Sends help for all bot commands and cogs.""" + bot = self.context.bot + + embed = Embed() + embed.set_author(name="Command Help", icon_url=constants.Icons.questionmark) + + filter_commands = await self.filter_commands(bot.commands, sort=True, key=self._category_key) + + cog_or_category_pages = [] + + for cog_or_category, _commands in itertools.groupby(filter_commands, key=self._category_key): + sorted_commands = sorted(_commands, key=lambda c: c.name) + + if len(sorted_commands) == 0: + continue + + command_detail_lines = self.get_commands_brief_details(sorted_commands, return_as_list=True) + + # Split cogs or categories which have too many commands to fit in one page. + # The length of commands is included for later use when aggregating into pages for the paginator. + for index in range(0, len(sorted_commands), COMMANDS_PER_PAGE): + truncated_lines = command_detail_lines[index:index + COMMANDS_PER_PAGE] + joined_lines = "".join(truncated_lines) + cog_or_category_pages.append((f"**{cog_or_category}**{joined_lines}", len(truncated_lines))) + + pages = [] + counter = 0 + page = "" + for page_details, length in cog_or_category_pages: + counter += length + if counter > COMMANDS_PER_PAGE: + # force a new page on paginator even if it falls short of the max pages + # since we still want to group categories/cogs. + counter = length + pages.append(page) + page = f"{page_details}\n\n" + else: + page += f"{page_details}\n\n" + + if page: + # add any remaining command help that didn't get added in the last iteration above. + pages.append(page) + + await LinePaginator.paginate(pages, self.context, embed=embed, max_lines=1, max_size=2000) + + +class Help(Cog): + """Custom Embed Pagination Help feature.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.old_help_command = bot.help_command + bot.help_command = CustomHelpCommand() + bot.help_command.cog = self + + def cog_unload(self) -> None: + """Reset the help command when the cog is unloaded.""" + self.bot.help_command = self.old_help_command + + +def setup(bot: Bot) -> None: + """Load the Help cog.""" + bot.add_cog(Help(bot)) + log.info("Cog loaded: Help") diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py new file mode 100644 index 000000000..8982196d1 --- /dev/null +++ b/bot/exts/info/information.py @@ -0,0 +1,422 @@ +import colorsys +import logging +import pprint +import textwrap +from collections import Counter, defaultdict +from string import Template +from typing import Any, Mapping, Optional, Union + +from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils +from discord.abc import GuildChannel +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group +from discord.utils import escape_markdown + +from bot import constants +from bot.bot import Bot +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + + +class Information(Cog): + """A cog with commands for generating embeds with server info, such as server stats and user info.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @staticmethod + def role_can_read(channel: GuildChannel, role: Role) -> bool: + """Return True if `role` can read messages in `channel`.""" + overwrites = channel.overwrites_for(role) + return overwrites.read_messages is True + + def get_staff_channel_count(self, guild: Guild) -> int: + """ + Get the number of channels that are staff-only. + + We need to know two things about a channel: + - Does the @everyone role have explicit read deny permissions? + - Do staff roles have explicit read allow permissions? + + If the answer to both of these questions is yes, it's a staff channel. + """ + channel_ids = set() + for channel in guild.channels: + if channel.type is ChannelType.category: + continue + + everyone_can_read = self.role_can_read(channel, guild.default_role) + + for role in constants.STAFF_ROLES: + role_can_read = self.role_can_read(channel, guild.get_role(role)) + if role_can_read and not everyone_can_read: + channel_ids.add(channel.id) + break + + return len(channel_ids) + + @staticmethod + def get_channel_type_counts(guild: Guild) -> str: + """Return the total amounts of the various types of channels in `guild`.""" + channel_counter = Counter(c.type for c in guild.channels) + channel_type_list = [] + for channel, count in channel_counter.items(): + channel_type = str(channel).title() + channel_type_list.append(f"{channel_type} channels: {count}") + + channel_type_list = sorted(channel_type_list) + return "\n".join(channel_type_list) + + @with_role(*constants.MODERATION_ROLES) + @command(name="roles") + async def roles_info(self, ctx: Context) -> None: + """Returns a list of all roles and their corresponding IDs.""" + # Sort the roles alphabetically and remove the @everyone role + roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) + + # Build a list + role_list = [] + for role in roles: + role_list.append(f"`{role.id}` - {role.mention}") + + # Build an embed + embed = Embed( + title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", + colour=Colour.blurple() + ) + + await LinePaginator.paginate(role_list, ctx, embed, empty=False) + + @with_role(*constants.MODERATION_ROLES) + @command(name="role") + async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: + """ + Return information on a role or list of roles. + + To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. + """ + parsed_roles = [] + failed_roles = [] + + for role_name in roles: + if isinstance(role_name, Role): + # Role conversion has already succeeded + parsed_roles.append(role_name) + continue + + role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) + + if not role: + failed_roles.append(role_name) + continue + + parsed_roles.append(role) + + if failed_roles: + await ctx.send(f":x: Could not retrieve the following roles: {', '.join(failed_roles)}") + + for role in parsed_roles: + h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) + + embed = Embed( + title=f"{role.name} info", + colour=role.colour, + ) + embed.add_field(name="ID", value=role.id, inline=True) + embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) + embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) + embed.add_field(name="Member count", value=len(role.members), inline=True) + embed.add_field(name="Position", value=role.position) + embed.add_field(name="Permission code", value=role.permissions.value, inline=True) + + await ctx.send(embed=embed) + + @command(name="server", aliases=["server_info", "guild", "guild_info"]) + async def server_info(self, ctx: Context) -> None: + """Returns an embed full of server information.""" + created = time_since(ctx.guild.created_at, precision="days") + features = ", ".join(ctx.guild.features) + region = ctx.guild.region + + roles = len(ctx.guild.roles) + member_count = ctx.guild.member_count + channel_counts = self.get_channel_type_counts(ctx.guild) + + # How many of each user status? + statuses = Counter(member.status for member in ctx.guild.members) + embed = Embed(colour=Colour.blurple()) + + # How many staff members and staff channels do we have? + staff_member_count = len(ctx.guild.get_role(constants.Roles.helpers).members) + staff_channel_count = self.get_staff_channel_count(ctx.guild) + + # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the + # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting + # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts + # after the dedent is made. + embed.description = Template( + textwrap.dedent(f""" + **Server information** + Created: {created} + Voice region: {region} + Features: {features} + + **Channel counts** + $channel_counts + Staff channels: {staff_channel_count} + + **Member counts** + Members: {member_count:,} + Staff members: {staff_member_count} + Roles: {roles} + + **Member statuses** + {constants.Emojis.status_online} {statuses[Status.online]:,} + {constants.Emojis.status_idle} {statuses[Status.idle]:,} + {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} + {constants.Emojis.status_offline} {statuses[Status.offline]:,} + """) + ).substitute({"channel_counts": channel_counts}) + embed.set_thumbnail(url=ctx.guild.icon_url) + + await ctx.send(embed=embed) + + @command(name="user", aliases=["user_info", "member", "member_info"]) + async def user_info(self, ctx: Context, user: Member = None) -> None: + """Returns info about a user.""" + if user is None: + user = ctx.author + + # Do a role check if this is being executed on someone other than the caller + elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): + await ctx.send("You may not use this command on users other than yourself.") + return + + # Non-staff may only do this in #bot-commands + if not with_role_check(ctx, *constants.STAFF_ROLES): + if not ctx.channel.id == constants.Channels.bot_commands: + raise InWhitelistCheckFailure(constants.Channels.bot_commands) + + embed = await self.create_user_embed(ctx, user) + + await ctx.send(embed=embed) + + async def create_user_embed(self, ctx: Context, user: Member) -> Embed: + """Creates an embed containing information on the `user`.""" + created = time_since(user.created_at, max_units=3) + + # Custom status + custom_status = '' + for activity in user.activities: + # Check activity.state for None value if user has a custom status set + # This guards against a custom status with an emoji but no text, which will cause + # escape_markdown to raise an exception + # This can be reworked after a move to d.py 1.3.0+, which adds a CustomActivity class + if activity.name == 'Custom Status' and activity.state: + state = escape_markdown(activity.state) + custom_status = f'Status: {state}\n' + + name = str(user) + if user.nick: + name = f"{user.nick} ({name})" + + joined = time_since(user.joined_at, max_units=3) + roles = ", ".join(role.mention for role in user.roles[1:]) + + description = [ + textwrap.dedent(f""" + **User Information** + Created: {created} + Profile: {user.mention} + ID: {user.id} + {custom_status} + **Member Information** + Joined: {joined} + Roles: {roles or None} + """).strip() + ] + + # Show more verbose output in moderation channels for infractions and nominations + if ctx.channel.id in constants.MODERATION_CHANNELS: + description.append(await self.expanded_user_infraction_counts(user)) + description.append(await self.user_nomination_counts(user)) + else: + description.append(await self.basic_user_infraction_counts(user)) + + # Let's build the embed now + embed = Embed( + title=name, + description="\n\n".join(description) + ) + + embed.set_thumbnail(url=user.avatar_url_as(static_format="png")) + embed.colour = user.top_role.colour if roles else Colour.blurple() + + return embed + + async def basic_user_infraction_counts(self, member: Member) -> str: + """Gets the total and active infraction counts for the given `member`.""" + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'hidden': 'False', + 'user__id': str(member.id) + } + ) + + total_infractions = len(infractions) + active_infractions = sum(infraction['active'] for infraction in infractions) + + infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}" + + return infraction_output + + async def expanded_user_infraction_counts(self, member: Member) -> str: + """ + Gets expanded infraction counts for the given `member`. + + The counts will be split by infraction type and the number of active infractions for each type will indicated + in the output as well. + """ + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'user__id': str(member.id) + } + ) + + infraction_output = ["**Infractions**"] + if not infractions: + infraction_output.append("This user has never received an infraction.") + else: + # Count infractions split by `type` and `active` status for this user + infraction_types = set() + infraction_counter = defaultdict(int) + for infraction in infractions: + infraction_type = infraction["type"] + infraction_active = 'active' if infraction["active"] else 'inactive' + + infraction_types.add(infraction_type) + infraction_counter[f"{infraction_active} {infraction_type}"] += 1 + + # Format the output of the infraction counts + for infraction_type in sorted(infraction_types): + active_count = infraction_counter[f"active {infraction_type}"] + total_count = active_count + infraction_counter[f"inactive {infraction_type}"] + + line = f"{infraction_type.capitalize()}s: {total_count}" + if active_count: + line += f" ({active_count} active)" + + infraction_output.append(line) + + return "\n".join(infraction_output) + + async def user_nomination_counts(self, member: Member) -> str: + """Gets the active and historical nomination counts for the given `member`.""" + nominations = await self.bot.api_client.get( + 'bot/nominations', + params={ + 'user__id': str(member.id) + } + ) + + output = ["**Nominations**"] + + if not nominations: + output.append("This user has never been nominated.") + else: + count = len(nominations) + is_currently_nominated = any(nomination["active"] for nomination in nominations) + nomination_noun = "nomination" if count == 1 else "nominations" + + if is_currently_nominated: + output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).") + else: + output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.") + + return "\n".join(output) + + def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str: + """Format a mapping to be readable to a human.""" + # sorting is technically superfluous but nice if you want to look for a specific field + fields = sorted(mapping.items(), key=lambda item: item[0]) + + if field_width is None: + field_width = len(max(mapping.keys(), key=len)) + + out = '' + + for key, val in fields: + if isinstance(val, dict): + # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries + inner_width = int(field_width * 1.6) + val = '\n' + self.format_fields(val, field_width=inner_width) + + elif isinstance(val, str): + # split up text since it might be long + text = textwrap.fill(val, width=100, replace_whitespace=False) + + # indent it, I guess you could do this with `wrap` and `join` but this is nicer + val = textwrap.indent(text, ' ' * (field_width + len(': '))) + + # the first line is already indented so we `str.lstrip` it + val = val.lstrip() + + if key == 'color': + # makes the base 10 representation of a hex number readable to humans + val = hex(val) + + out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width) + + # remove trailing whitespace + return out.rstrip() + + @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) + @group(invoke_without_command=True) + @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) + async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: + """Shows information about the raw API response.""" + # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling + # doing this extra request is also much easier than trying to convert everything back into a dictionary again + raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) + + paginator = Paginator() + + def add_content(title: str, content: str) -> None: + paginator.add_line(f'== {title} ==\n') + # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution. + # we hope it's not close to 2000 + paginator.add_line(content.replace('```', '`` `')) + paginator.close_page() + + if message.content: + add_content('Raw message', message.content) + + transformer = pprint.pformat if json else self.format_fields + for field_name in ('embeds', 'attachments'): + data = raw_data[field_name] + + if not data: + continue + + total = len(data) + for current, item in enumerate(data, start=1): + title = f'Raw {field_name} ({current}/{total})' + add_content(title, transformer(item)) + + for page in paginator.pages: + await ctx.send(page) + + @raw.command() + async def json(self, ctx: Context, message: Message) -> None: + """Shows information about the raw API response in a copy-pasteable Python format.""" + await ctx.invoke(self.raw, message=message, json=True) + + +def setup(bot: Bot) -> None: + """Load the Information cog.""" + bot.add_cog(Information(bot)) diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py new file mode 100644 index 000000000..0ab5738a4 --- /dev/null +++ b/bot/exts/info/python_news.py @@ -0,0 +1,232 @@ +import logging +import typing as t +from datetime import date, datetime + +import discord +import feedparser +from bs4 import BeautifulSoup +from discord.ext.commands import Cog +from discord.ext.tasks import loop + +from bot import constants +from bot.bot import Bot +from bot.utils.webhooks import send_webhook + +PEPS_RSS_URL = "https://www.python.org/dev/peps/peps.rss/" + +RECENT_THREADS_TEMPLATE = "https://mail.python.org/archives/list/{name}@python.org/recent-threads" +THREAD_TEMPLATE_URL = "https://mail.python.org/archives/api/list/{name}@python.org/thread/{id}/" +MAILMAN_PROFILE_URL = "https://mail.python.org/archives/users/{id}/" +THREAD_URL = "https://mail.python.org/archives/list/{list}@python.org/thread/{id}/" + +AVATAR_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + +log = logging.getLogger(__name__) + + +class PythonNews(Cog): + """Post new PEPs and Python News to `#python-news`.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_names = {} + self.webhook: t.Optional[discord.Webhook] = None + + self.bot.loop.create_task(self.get_webhook_names()) + self.bot.loop.create_task(self.get_webhook_and_channel()) + + async def start_tasks(self) -> None: + """Start the tasks for fetching new PEPs and mailing list messages.""" + self.fetch_new_media.start() + + @loop(minutes=20) + async def fetch_new_media(self) -> None: + """Fetch new mailing list messages and then new PEPs.""" + await self.post_maillist_news() + await self.post_pep_news() + + async def sync_maillists(self) -> None: + """Sync currently in-use maillists with API.""" + # Wait until guild is available to avoid running before everything is ready + await self.bot.wait_until_guild_available() + + response = await self.bot.api_client.get("bot/bot-settings/news") + for mail in constants.PythonNews.mail_lists: + if mail not in response["data"]: + response["data"][mail] = [] + + # Because we are handling PEPs differently, we don't include it to mail lists + if "pep" not in response["data"]: + response["data"]["pep"] = [] + + await self.bot.api_client.put("bot/bot-settings/news", json=response) + + async def get_webhook_names(self) -> None: + """Get webhook author names from maillist API.""" + await self.bot.wait_until_guild_available() + + async with self.bot.http_session.get("https://mail.python.org/archives/api/lists") as resp: + lists = await resp.json() + + for mail in lists: + if mail["name"].split("@")[0] in constants.PythonNews.mail_lists: + self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] + + async def post_pep_news(self) -> None: + """Fetch new PEPs and when they don't have announcement in #python-news, create it.""" + # Wait until everything is ready and http_session available + await self.bot.wait_until_guild_available() + await self.sync_maillists() + + async with self.bot.http_session.get(PEPS_RSS_URL) as resp: + data = feedparser.parse(await resp.text("utf-8")) + + news_listing = await self.bot.api_client.get("bot/bot-settings/news") + payload = news_listing.copy() + pep_numbers = news_listing["data"]["pep"] + + # Reverse entries to send oldest first + data["entries"].reverse() + for new in data["entries"]: + try: + new_datetime = datetime.strptime(new["published"], "%a, %d %b %Y %X %Z") + except ValueError: + log.warning(f"Wrong datetime format passed in PEP new: {new['published']}") + continue + pep_nr = new["title"].split(":")[0].split()[1] + if ( + pep_nr in pep_numbers + or new_datetime.date() < date.today() + ): + continue + + # Build an embed and send a webhook + embed = discord.Embed( + title=new["title"], + description=new["summary"], + timestamp=new_datetime, + url=new["link"], + colour=constants.Colours.soft_green + ) + embed.set_footer(text=data["feed"]["title"], icon_url=AVATAR_URL) + msg = await send_webhook( + webhook=self.webhook, + username=data["feed"]["title"], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"]["pep"].append(pep_nr) + + # Increase overall PEP new stat + self.bot.stats.incr("python_news.posted.pep") + + if msg.channel.is_news(): + log.trace("Publishing PEP annnouncement because it was in a news channel") + await msg.publish() + + # Apply new sent news to DB to avoid duplicate sending + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def post_maillist_news(self) -> None: + """Send new maillist threads to #python-news that is listed in configuration.""" + await self.bot.wait_until_guild_available() + await self.sync_maillists() + existing_news = await self.bot.api_client.get("bot/bot-settings/news") + payload = existing_news.copy() + + for maillist in constants.PythonNews.mail_lists: + async with self.bot.http_session.get(RECENT_THREADS_TEMPLATE.format(name=maillist)) as resp: + recents = BeautifulSoup(await resp.text(), features="lxml") + + # When a

element is present in the response then the mailing list + # has not had any activity during the current month, so therefore it + # can be ignored. + if recents.p: + continue + + for thread in recents.html.body.div.find_all("a", href=True): + # We want only these threads that have identifiers + if "latest" in thread["href"]: + continue + + thread_information, email_information = await self.get_thread_and_first_mail( + maillist, thread["href"].split("/")[-2] + ) + + try: + new_date = datetime.strptime(email_information["date"], "%Y-%m-%dT%X%z") + except ValueError: + log.warning(f"Invalid datetime from Thread email: {email_information['date']}") + continue + + if ( + thread_information["thread_id"] in existing_news["data"][maillist] + or 'Re: ' in thread_information["subject"] + or new_date.date() < date.today() + ): + continue + + content = email_information["content"] + link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) + + # Build an embed and send a message to the webhook + embed = discord.Embed( + title=thread_information["subject"], + description=content[:500] + f"... [continue reading]({link})" if len(content) > 500 else content, + timestamp=new_date, + url=link, + colour=constants.Colours.soft_green + ) + embed.set_author( + name=f"{email_information['sender_name']} ({email_information['sender']['address']})", + url=MAILMAN_PROFILE_URL.format(id=email_information["sender"]["mailman_id"]), + ) + embed.set_footer( + text=f"Posted to {self.webhook_names[maillist]}", + icon_url=AVATAR_URL, + ) + msg = await send_webhook( + webhook=self.webhook, + username=self.webhook_names[maillist], + embed=embed, + avatar_url=AVATAR_URL, + wait=True, + ) + payload["data"][maillist].append(thread_information["thread_id"]) + + # Increase this specific maillist counter in stats + self.bot.stats.incr(f"python_news.posted.{maillist.replace('-', '_')}") + + if msg.channel.is_news(): + log.trace("Publishing mailing list message because it was in a news channel") + await msg.publish() + + await self.bot.api_client.put("bot/bot-settings/news", json=payload) + + async def get_thread_and_first_mail(self, maillist: str, thread_identifier: str) -> t.Tuple[t.Any, t.Any]: + """Get mail thread and first mail from mail.python.org based on `maillist` and `thread_identifier`.""" + async with self.bot.http_session.get( + THREAD_TEMPLATE_URL.format(name=maillist, id=thread_identifier) + ) as resp: + thread_information = await resp.json() + + async with self.bot.http_session.get(thread_information["starting_email"]) as resp: + email_information = await resp.json() + return thread_information, email_information + + async def get_webhook_and_channel(self) -> None: + """Storage #python-news channel Webhook and `TextChannel` to `News.webhook` and `channel`.""" + await self.bot.wait_until_guild_available() + self.webhook = await self.bot.fetch_webhook(constants.PythonNews.webhook) + + await self.start_tasks() + + def cog_unload(self) -> None: + """Stop news posting tasks on cog unload.""" + self.fetch_new_media.cancel() + + +def setup(bot: Bot) -> None: + """Add `News` cog.""" + bot.add_cog(PythonNews(bot)) diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py new file mode 100644 index 000000000..d853ab2ea --- /dev/null +++ b/bot/exts/info/reddit.py @@ -0,0 +1,304 @@ +import asyncio +import logging +import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List + +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks +from bot.converters import Subreddit +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils.messages import sub_clyde + +log = logging.getLogger(__name__) + +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + + +class Reddit(Cog): + """Track subreddit posts and show detailed statistics about them.""" + + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 + + def __init__(self, bot: Bot): + self.bot = bot + + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() + + def cog_unload(self) -> None: + """Stop the loop task and revoke the access token when the cog is unloaded.""" + self.auto_poster_loop.cancel() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) + + async def init_reddit_ready(self) -> None: + """Sets the reddit webhook when the cog is loaded.""" + await self.bot.wait_until_guild_available() + if not self.webhook: + self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) + + @property + def channel(self) -> TextChannel: + """Get the #reddit channel object from the bot's cache.""" + return self.bot.get_channel(Channels.reddit) + + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: + """A helper method to fetch a certain amount of Reddit posts at a given route.""" + # Reddit's JSON responses only provide 25 posts at most. + if not 25 >= amount > 0: + raise ValueError("Invalid amount of subreddit posts requested.") + + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.get( + url=url, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + params=params + ) + if response.status == 200 and response.content_type == 'application/json': + # Got appropriate response - process and return. + content = await response.json() + posts = content["data"]["children"] + return posts[:amount] + + await asyncio.sleep(3) + + log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") + return list() # Failed to get appropriate response within allowed number of retries. + + async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: + """ + Get the top amount of posts for a given subreddit within a specified timeframe. + + A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top + weekly posts. + + The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. + """ + embed = Embed(description="") + + posts = await self.fetch_posts( + route=f"{subreddit}/top", + amount=amount, + params={"t": time} + ) + + if not posts: + embed.title = random.choice(ERROR_REPLIES) + embed.colour = Colour.red() + embed.description = ( + "Sorry! We couldn't find any posts from that subreddit. " + "If this problem persists, please let us know." + ) + + return embed + + for post in posts: + data = post["data"] + + text = data["selftext"] + if text: + text = textwrap.shorten(text, width=128, placeholder="...") + text += "\n" # Add newline to separate embed info + + ups = data["ups"] + comments = data["num_comments"] + author = data["author"] + + title = textwrap.shorten(data["title"], width=64, placeholder="...") + link = self.URL + data["permalink"] + + embed.description += ( + f"**[{title}]({link})**\n" + f"{text}" + f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" + ) + + embed.colour = Colour.blurple() + return embed + + @loop() + async def auto_poster_loop(self) -> None: + """Post the top 5 posts daily, and the top 5 posts weekly.""" + # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter + now = datetime.utcnow() + tomorrow = now + timedelta(days=1) + midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + seconds_until = (midnight_tomorrow - now).total_seconds() + + await asyncio.sleep(seconds_until) + + await self.bot.wait_until_guild_available() + if not self.webhook: + await self.bot.fetch_webhook(Webhooks.reddit) + + if datetime.utcnow().weekday() == 0: + await self.top_weekly_posts() + # if it's a monday send the top weekly posts + + for subreddit in RedditConfig.subreddits: + top_posts = await self.get_top_posts(subreddit=subreddit, time="day") + username = sub_clyde(f"{subreddit} Top Daily Posts") + message = await self.webhook.send(username=username, embed=top_posts, wait=True) + + if message.channel.is_news(): + await message.publish() + + async def top_weekly_posts(self) -> None: + """Post a summary of the top posts.""" + for subreddit in RedditConfig.subreddits: + # Send and pin the new weekly posts. + top_posts = await self.get_top_posts(subreddit=subreddit, time="week") + username = sub_clyde(f"{subreddit} Top Weekly Posts") + message = await self.webhook.send(wait=True, username=username, embed=top_posts) + + if subreddit.lower() == "r/python": + if not self.channel: + log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") + return + + # Remove the oldest pins so that only 12 remain at most. + pins = await self.channel.pins() + + while len(pins) >= 12: + await pins[-1].unpin() + del pins[-1] + + await message.pin() + + if message.channel.is_news(): + await message.publish() + + @group(name="reddit", invoke_without_command=True) + async def reddit_group(self, ctx: Context) -> None: + """View the top posts from various subreddits.""" + await ctx.send_help(ctx.command) + + @reddit_group.command(name="top") + async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of all time from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="all") + + await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) + + @reddit_group.command(name="daily") + async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of today from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="day") + + await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) + + @reddit_group.command(name="weekly") + async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of this week from a given subreddit.""" + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="week") + + await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) + + @with_role(*STAFF_ROLES) + @reddit_group.command(name="subreddits", aliases=("subs",)) + async def subreddits_command(self, ctx: Context) -> None: + """Send a paginated embed of all the subreddits we're relaying.""" + embed = Embed() + embed.title = "Relayed subreddits." + embed.colour = Colour.blurple() + + await LinePaginator.paginate( + RedditConfig.subreddits, + ctx, embed, + footer_text="Use the reddit commands along with these to view their posts.", + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return + bot.add_cog(Reddit(bot)) diff --git a/bot/exts/info/site.py b/bot/exts/info/site.py new file mode 100644 index 000000000..ac29daa1d --- /dev/null +++ b/bot/exts/info/site.py @@ -0,0 +1,146 @@ +import logging + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import URLs +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + +PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" + + +class Site(Cog): + """Commands for linking to different parts of the site.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="site", aliases=("s",), invoke_without_command=True) + async def site_group(self, ctx: Context) -> None: + """Commands for getting info about our website.""" + await ctx.send_help(ctx.command) + + @site_group.command(name="home", aliases=("about",)) + async def site_main(self, ctx: Context) -> None: + """Info about the website itself.""" + url = f"{URLs.site_schema}{URLs.site}/" + + embed = Embed(title="Python Discord website") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + f"[Our official website]({url}) is an open-source community project " + "created with Python and Django. It contains information about the server " + "itself, lets you sign up for upcoming events, has its own wiki, contains " + "a list of valuable learning resources, and much more." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="resources") + async def site_resources(self, ctx: Context) -> None: + """Info about the site's Resources page.""" + learning_url = f"{PAGES_URL}/resources" + + embed = Embed(title="Resources") + embed.set_footer(text=f"{learning_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Resources page]({learning_url}) on our website contains a " + "list of hand-selected learning resources that we regularly recommend " + f"to both beginners and experts." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="tools") + async def site_tools(self, ctx: Context) -> None: + """Info about the site's Tools page.""" + tools_url = f"{PAGES_URL}/resources/tools" + + embed = Embed(title="Tools") + embed.set_footer(text=f"{tools_url}") + embed.colour = Colour.blurple() + embed.description = ( + f"The [Tools page]({tools_url}) on our website contains a " + f"couple of the most popular tools for programming in Python." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="help") + async def site_help(self, ctx: Context) -> None: + """Info about the site's Getting Help page.""" + url = f"{PAGES_URL}/resources/guides/asking-good-questions" + + embed = Embed(title="Asking Good Questions") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "Asking the right question about something that's new to you can sometimes be tricky. " + f"To help with this, we've created a [guide to asking good questions]({url}) on our website. " + "It contains everything you need to get the very best help from our community." + ) + + await ctx.send(embed=embed) + + @site_group.command(name="faq") + async def site_faq(self, ctx: Context) -> None: + """Info about the site's FAQ page.""" + url = f"{PAGES_URL}/frequently-asked-questions" + + embed = Embed(title="FAQ") + embed.set_footer(text=url) + embed.colour = Colour.blurple() + embed.description = ( + "As the largest Python community on Discord, we get hundreds of questions every day. " + "Many of these questions have been asked before. We've compiled a list of the most " + "frequently asked questions along with their answers, which can be found on " + f"our [FAQ page]({url})." + ) + + await ctx.send(embed=embed) + + @site_group.command(aliases=['r', 'rule'], name='rules') + async def site_rules(self, ctx: Context, *rules: int) -> None: + """Provides a link to all rules or, if specified, displays specific rule(s).""" + rules_embed = Embed(title='Rules', color=Colour.blurple()) + rules_embed.url = f"{PAGES_URL}/rules" + + if not rules: + # Rules were not submitted. Return the default description. + rules_embed.description = ( + "The rules and guidelines that apply to this community can be found on" + f" our [rules page]({PAGES_URL}/rules). We expect" + " all members of the community to have read and understood these." + ) + + await ctx.send(embed=rules_embed) + return + + full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) + invalid_indices = tuple( + pick + for pick in rules + if pick < 1 or pick > len(full_rules) + ) + + if invalid_indices: + indices = ', '.join(map(str, invalid_indices)) + await ctx.send(f":x: Invalid rule indices: {indices}") + return + + for rule in rules: + self.bot.stats.incr(f"rule_uses.{rule}") + + final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) + + await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) + + +def setup(bot: Bot) -> None: + """Load the Site cog.""" + bot.add_cog(Site(bot)) diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py new file mode 100644 index 000000000..205e0ba81 --- /dev/null +++ b/bot/exts/info/source.py @@ -0,0 +1,141 @@ +import inspect +from pathlib import Path +from typing import Optional, Tuple, Union + +from discord import Embed +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import URLs + +SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] + + +class SourceConverter(commands.Converter): + """Convert an argument into a help command, tag, command, or cog.""" + + async def convert(self, ctx: commands.Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower().startswith("help"): + return ctx.bot.help_command + + cog = ctx.bot.get_cog(argument) + if cog: + return cog + + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd + + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + elif argument.lower() in tags_cog._cache: + return argument.lower() + + raise commands.BadArgument( + f"Unable to convert `{argument}` to valid command{', tag,' if show_tag else ''} or Cog." + ) + + +class BotSource(commands.Cog): + """Displays information about the bot's source code.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command(name="source", aliases=("src",)) + async def source_command(self, ctx: commands.Context, *, source_item: SourceConverter = None) -> None: + """Display information and a GitHub link to the source code of a command, tag, or cog.""" + if not source_item: + embed = Embed(title="Bot's GitHub Repository") + embed.add_field(name="Repository", value=f"[Go to GitHub]({URLs.github_bot_repo})") + embed.set_thumbnail(url="https://avatars1.githubusercontent.com/u/9919") + await ctx.send(embed=embed) + return + + embed = await self.build_embed(source_item) + await ctx.send(embed=embed) + + def get_source_link(self, source_item: SourceType) -> Tuple[str, str, Optional[int]]: + """ + Build GitHub link of source item, return this link, file location and first line number. + + Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). + """ + if isinstance(source_item, commands.Command): + if source_item.cog_name == "Alias": + cmd_name = source_item.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + src = cmd.callback.__code__ + filename = src.co_filename + else: + src = source_item.callback.__code__ + filename = src.co_filename + elif isinstance(source_item, str): + tags_cog = self.bot.get_cog("Tags") + filename = tags_cog._cache[source_item]["location"] + else: + src = type(source_item) + try: + filename = inspect.getsourcefile(src) + except TypeError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + if not isinstance(source_item, str): + try: + lines, first_line_no = inspect.getsourcelines(src) + except OSError: + raise commands.BadArgument("Cannot get source for a dynamically-created object.") + + lines_extension = f"#L{first_line_no}-L{first_line_no+len(lines)-1}" + else: + first_line_no = None + lines_extension = "" + + # Handle tag file location differently than others to avoid errors in some cases + if not first_line_no: + file_location = Path(filename).relative_to("/bot/") + else: + file_location = Path(filename).relative_to(Path.cwd()).as_posix() + + url = f"{URLs.github_bot_repo}/blob/master/{file_location}{lines_extension}" + + return url, file_location, first_line_no or None + + async def build_embed(self, source_object: SourceType) -> Optional[Embed]: + """Build embed based on source object.""" + url, location, first_line = self.get_source_link(source_object) + + if isinstance(source_object, commands.HelpCommand): + title = "Help Command" + description = source_object.__doc__.splitlines()[1] + elif isinstance(source_object, commands.Command): + if source_object.cog_name == "Alias": + cmd_name = source_object.callback.__name__.replace("_alias", "") + cmd = self.bot.get_command(cmd_name.replace("_", " ")) + description = cmd.short_doc + else: + description = source_object.short_doc + + title = f"Command: {source_object.qualified_name}" + elif isinstance(source_object, str): + title = f"Tag: {source_object}" + description = "" + else: + title = f"Cog: {source_object.qualified_name}" + description = source_object.description.splitlines()[0] + + embed = Embed(title=title, description=description) + embed.add_field(name="Source Code", value=f"[Go to GitHub]({url})") + line_text = f":{first_line}" if first_line else "" + embed.set_footer(text=f"{location}{line_text}") + + return embed + + +def setup(bot: Bot) -> None: + """Load the BotSource cog.""" + bot.add_cog(BotSource(bot)) diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py new file mode 100644 index 000000000..d42f55466 --- /dev/null +++ b/bot/exts/info/stats.py @@ -0,0 +1,129 @@ +import string +from datetime import datetime + +from discord import Member, Message, Status +from discord.ext.commands import Cog, Context +from discord.ext.tasks import loop + +from bot.bot import Bot +from bot.constants import Categories, Channels, Guild, Stats as StatConf + + +CHANNEL_NAME_OVERRIDES = { + Channels.off_topic_0: "off_topic_0", + Channels.off_topic_1: "off_topic_1", + Channels.off_topic_2: "off_topic_2", + Channels.staff_lounge: "staff_lounge" +} + +ALLOWED_CHARS = string.ascii_letters + string.digits + "_" + + +class Stats(Cog): + """A cog which provides a way to hook onto Discord events and forward to stats.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.last_presence_update = None + self.update_guild_boost.start() + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Report message events in the server to statsd.""" + if message.guild is None: + return + + if message.guild.id != Guild.id: + return + + cat = getattr(message.channel, "category", None) + if cat is not None and cat.id == Categories.modmail: + if message.channel.id != Channels.incidents: + # Do not report modmail channels to stats, there are too many + # of them for interesting statistics to be drawn out of this. + return + + reformatted_name = message.channel.name.replace('-', '_') + + if CHANNEL_NAME_OVERRIDES.get(message.channel.id): + reformatted_name = CHANNEL_NAME_OVERRIDES.get(message.channel.id) + + reformatted_name = "".join(char for char in reformatted_name if char in ALLOWED_CHARS) + + stat_name = f"channels.{reformatted_name}" + self.bot.stats.incr(stat_name) + + # Increment the total message count + self.bot.stats.incr("messages") + + @Cog.listener() + async def on_command_completion(self, ctx: Context) -> None: + """Report completed commands to statsd.""" + command_name = ctx.command.qualified_name.replace(" ", "_") + + self.bot.stats.incr(f"commands.{command_name}") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Update member count stat on member join.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_leave(self, member: Member) -> None: + """Update member count stat on member leave.""" + if member.guild.id != Guild.id: + return + + self.bot.stats.gauge("guild.total_members", len(member.guild.members)) + + @Cog.listener() + async def on_member_update(self, _before: Member, after: Member) -> None: + """Update presence estimates on member update.""" + if after.guild.id != Guild.id: + return + + if self.last_presence_update: + if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: + return + + self.last_presence_update = datetime.now() + + online = 0 + idle = 0 + dnd = 0 + offline = 0 + + for member in after.guild.members: + if member.status is Status.online: + online += 1 + elif member.status is Status.dnd: + dnd += 1 + elif member.status is Status.idle: + idle += 1 + elif member.status is Status.offline: + offline += 1 + + self.bot.stats.gauge("guild.status.online", online) + self.bot.stats.gauge("guild.status.idle", idle) + self.bot.stats.gauge("guild.status.do_not_disturb", dnd) + self.bot.stats.gauge("guild.status.offline", offline) + + @loop(hours=1) + async def update_guild_boost(self) -> None: + """Post the server boost level and tier every hour.""" + await self.bot.wait_until_guild_available() + g = self.bot.get_guild(Guild.id) + self.bot.stats.gauge("boost.amount", g.premium_subscription_count) + self.bot.stats.gauge("boost.tier", g.premium_tier) + + def cog_unload(self) -> None: + """Stop the boost statistic task on unload of the Cog.""" + self.update_guild_boost.stop() + + +def setup(bot: Bot) -> None: + """Load the stats cog.""" + bot.add_cog(Stats(bot)) diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py new file mode 100644 index 000000000..3d76c5c08 --- /dev/null +++ b/bot/exts/info/tags.py @@ -0,0 +1,277 @@ +import logging +import re +import time +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot import constants +from bot.bot import Bot +from bot.converters import TagNameConverter +from bot.pagination import LinePaginator +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +TEST_CHANNELS = ( + constants.Channels.bot_commands, + constants.Channels.helpers +) + +REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) +FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." + + +class Tags(Cog): + """Save new tags and fetch existing tags.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.tag_cooldowns = {} + self._cache = self.get_tags() + + @staticmethod + def get_tags() -> dict: + """Get all tags.""" + cache = {} + + base_path = Path("bot", "resources", "tags") + for file in base_path.glob("**/*"): + if file.is_file(): + tag_title = file.stem + tag = { + "title": tag_title, + "embed": { + "description": file.read_text(encoding="utf8"), + }, + "restricted_to": "developers", + "location": f"/bot/{file}" + } + + # Convert to a list to allow negative indexing. + parents = list(file.relative_to(base_path).parents) + if len(parents) > 1: + # -1 would be '.' hence -2 is used as the index. + tag["restricted_to"] = parents[-2].name + + cache[tag_title] = tag + + return cache + + @staticmethod + def check_accessibility(user: Member, tag: dict) -> bool: + """Check if user can access a tag.""" + return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] + + @staticmethod + def _fuzzy_search(search: str, target: str) -> float: + """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" + current, index = 0, 0 + _search = REGEX_NON_ALPHABET.sub('', search.lower()) + _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + _target = next(_targets) + try: + while True: + while index < len(_target) and _search[current] == _target[index]: + current += 1 + index += 1 + index, _target = 0, next(_targets) + except (StopIteration, IndexError): + pass + return current / len(_search) * 100 + + def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: + """Return a list of suggested tags.""" + scores: Dict[str, int] = { + tag_title: Tags._fuzzy_search(tag_name, tag['title']) + for tag_title, tag in self._cache.items() + } + + thresholds = thresholds or [100, 90, 80, 70, 60] + + for threshold in thresholds: + suggestions = [ + self._cache[tag_title] + for tag_title, matching_score in scores.items() + if matching_score >= threshold + ] + if suggestions: + return suggestions + + return [] + + def _get_tag(self, tag_name: str) -> list: + """Get a specific tag.""" + found = [self._cache.get(tag_name.lower(), None)] + if not found[0]: + return self._get_suggestions(tag_name) + return found + + def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: + """ + Search for tags via contents. + + `predicate` will be the built-in any, all, or a custom callable. Must return a bool. + """ + keywords_processed: List[str] = [] + for keyword in keywords.split(','): + keyword_sanitized = keyword.strip().casefold() + if not keyword_sanitized: + # this happens when there are leading / trailing / consecutive comma. + continue + keywords_processed.append(keyword_sanitized) + + if not keywords_processed: + # after sanitizing, we can end up with an empty list, for example when keywords is ',' + # in that case, we simply want to search for such keywords directly instead. + keywords_processed = [keywords] + + matching_tags = [] + for tag in self._cache.values(): + matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) + if self.check_accessibility(user, tag) and check(matches): + matching_tags.append(tag) + + return matching_tags + + async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: + """Send the result of matching tags to user.""" + if not matching_tags: + pass + elif len(matching_tags) == 1: + await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) + else: + is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 + embed = Embed( + title=f"Here are the tags containing the given keyword{'s' * is_plural}:", + description='\n'.join(tag['title'] for tag in matching_tags[:10]) + ) + await LinePaginator.paginate( + sorted(f"**»** {tag['title']}" for tag in matching_tags), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) + async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Show all known tags, a single tag, or run a subcommand.""" + await ctx.invoke(self.get_command, tag_name=tag_name) + + @tags_group.group(name='search', invoke_without_command=True) + async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Only search for tags that has ALL the keywords. + """ + matching_tags = self._get_tags_via_content(all, keywords, ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @search_tag_content.command(name='any') + async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Search for tags that has ANY of the keywords. + """ + matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @tags_group.command(name='get', aliases=('show', 'g')) + async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + """Get a specified tag, or a list of all tags if no tag is specified.""" + + def _command_on_cooldown(tag_name: str) -> bool: + """ + Check if the command is currently on cooldown, on a per-tag, per-channel basis. + + The cooldown duration is set in constants.py. + """ + now = time.time() + + cooldown_conditions = ( + tag_name + and tag_name in self.tag_cooldowns + and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags + and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id + ) + + if cooldown_conditions: + return True + return False + + if _command_on_cooldown(tag_name): + time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] + time_left = constants.Cooldowns.tags - time_elapsed + log.info( + f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " + f"Cooldown ends in {time_left:.1f} seconds." + ) + return + + if tag_name is not None: + temp_founds = self._get_tag(tag_name) + + founds = [] + + for found_tag in temp_founds: + if self.check_accessibility(ctx.author, found_tag): + founds.append(found_tag) + + if len(founds) == 1: + tag = founds[0] + if ctx.channel.id not in TEST_CHANNELS: + self.tag_cooldowns[tag_name] = { + "time": time.time(), + "channel": ctx.channel.id + } + + self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") + + await wait_for_deletion( + await ctx.send(embed=Embed.from_dict(tag['embed'])), + [ctx.author.id], + client=self.bot + ) + elif founds and len(tag_name) >= 3: + await wait_for_deletion( + await ctx.send( + embed=Embed( + title='Did you mean ...', + description='\n'.join(tag['title'] for tag in founds[:10]) + ) + ), + [ctx.author.id], + client=self.bot + ) + + else: + tags = self._cache.values() + if not tags: + await ctx.send(embed=Embed( + description="**There are no tags in the database!**", + colour=Colour.red() + )) + else: + embed: Embed = Embed(title="**Current tags**") + await LinePaginator.paginate( + sorted( + f"**»** {tag['title']}" for tag in tags + if self.check_accessibility(ctx.author, tag) + ), + ctx, + embed, + footer_text=FOOTER_TEXT, + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Tags cog.""" + bot.add_cog(Tags(bot)) diff --git a/bot/exts/info/wolfram.py b/bot/exts/info/wolfram.py new file mode 100644 index 000000000..e6cae3bb8 --- /dev/null +++ b/bot/exts/info/wolfram.py @@ -0,0 +1,280 @@ +import logging +from io import BytesIO +from typing import Callable, List, Optional, Tuple +from urllib import parse + +import discord +from dateutil.relativedelta import relativedelta +from discord import Embed +from discord.ext import commands +from discord.ext.commands import BucketType, Cog, Context, check, group + +from bot.bot import Bot +from bot.constants import Colours, STAFF_ROLES, Wolfram +from bot.pagination import ImagePaginator +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +APPID = Wolfram.key +DEFAULT_OUTPUT_FORMAT = "JSON" +QUERY = "http://api.wolframalpha.com/v2/{request}?{data}" +WOLF_IMAGE = "https://www.symbols.com/gi.php?type=1&id=2886&i=1" + +MAX_PODS = 20 + +# Allows for 10 wolfram calls pr user pr day +usercd = commands.CooldownMapping.from_cooldown(Wolfram.user_limit_day, 60*60*24, BucketType.user) + +# Allows for max api requests / days in month per day for the entire guild (Temporary) +guildcd = commands.CooldownMapping.from_cooldown(Wolfram.guild_limit_day, 60*60*24, BucketType.guild) + + +async def send_embed( + ctx: Context, + message_txt: str, + colour: int = Colours.soft_red, + footer: str = None, + img_url: str = None, + f: discord.File = None +) -> None: + """Generate & send a response embed with Wolfram as the author.""" + embed = Embed(colour=colour) + embed.description = message_txt + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + if footer: + embed.set_footer(text=footer) + + if img_url: + embed.set_image(url=img_url) + + await ctx.send(embed=embed, file=f) + + +def custom_cooldown(*ignore: List[int]) -> Callable: + """ + Implement per-user and per-guild cooldowns for requests to the Wolfram API. + + A list of roles may be provided to ignore the per-user cooldown + """ + async def predicate(ctx: Context) -> bool: + if ctx.invoked_with == 'help': + # if the invoked command is help we don't want to increase the ratelimits since it's not actually + # invoking the command/making a request, so instead just check if the user/guild are on cooldown. + guild_cooldown = not guildcd.get_bucket(ctx.message).get_tokens() == 0 # if guild is on cooldown + if not any(r.id in ignore for r in ctx.author.roles): # check user bucket if user is not ignored + return guild_cooldown and not usercd.get_bucket(ctx.message).get_tokens() == 0 + return guild_cooldown + + user_bucket = usercd.get_bucket(ctx.message) + + if all(role.id not in ignore for role in ctx.author.roles): + user_rate = user_bucket.update_rate_limit() + + if user_rate: + # Can't use api; cause: member limit + delta = relativedelta(seconds=int(user_rate)) + cooldown = humanize_delta(delta) + message = ( + "You've used up your limit for Wolfram|Alpha requests.\n" + f"Cooldown: {cooldown}" + ) + await send_embed(ctx, message) + return False + + guild_bucket = guildcd.get_bucket(ctx.message) + guild_rate = guild_bucket.update_rate_limit() + + # Repr has a token attribute to read requests left + log.debug(guild_bucket) + + if guild_rate: + # Can't use api; cause: guild limit + message = ( + "The max limit of requests for the server has been reached for today.\n" + f"Cooldown: {int(guild_rate)}" + ) + await send_embed(ctx, message) + return False + + return True + return check(predicate) + + +async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tuple]]: + """Get the Wolfram API pod pages for the provided query.""" + async with ctx.channel.typing(): + url_str = parse.urlencode({ + "input": query, + "appid": APPID, + "output": DEFAULT_OUTPUT_FORMAT, + "format": "image,plaintext" + }) + request_url = QUERY.format(request="query", data=url_str) + + async with bot.http_session.get(request_url) as response: + json = await response.json(content_type='text/plain') + + result = json["queryresult"] + + if result["error"]: + # API key not set up correctly + if result["error"]["msg"] == "Invalid appid": + message = "Wolfram API key is invalid or missing." + log.warning( + "API key seems to be missing, or invalid when " + f"processing a wolfram request: {url_str}, Response: {json}" + ) + await send_embed(ctx, message) + return + + message = "Something went wrong internally with your request, please notify staff!" + log.warning(f"Something went wrong getting a response from wolfram: {url_str}, Response: {json}") + await send_embed(ctx, message) + return + + if not result["success"]: + message = f"I couldn't find anything for {query}." + await send_embed(ctx, message) + return + + if not result["numpods"]: + message = "Could not find any results." + await send_embed(ctx, message) + return + + pods = result["pods"] + pages = [] + for pod in pods[:MAX_PODS]: + subs = pod.get("subpods") + + for sub in subs: + title = sub.get("title") or sub.get("plaintext") or sub.get("id", "") + img = sub["img"]["src"] + pages.append((title, img)) + return pages + + +class Wolfram(Cog): + """Commands for interacting with the Wolfram|Alpha API.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_command(self, ctx: Context, *, query: str) -> None: + """Requests all answers on a single image, sends an image of all related pods.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="simple", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + image_bytes = await response.read() + + f = discord.File(BytesIO(image_bytes), filename="image.png") + image_url = "attachment://image.png" + + if status == 501: + message = "Failed to get response" + footer = "" + color = Colours.soft_red + elif status == 400: + message = "No input found" + footer = "" + color = Colours.soft_red + elif status == 403: + message = "Wolfram API key is invalid or missing." + footer = "" + color = Colours.soft_red + else: + message = "" + footer = "View original for a bigger picture." + color = Colours.soft_orange + + # Sends a "blank" embed if no request is received, unsure how to fix + await send_embed(ctx, message, color, footer=footer, img_url=image_url, f=f) + + @wolfram_command.command(name="page", aliases=("pa", "p")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_page_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + embed = Embed() + embed.set_author(name="Wolfram Alpha", + icon_url=WOLF_IMAGE, + url="https://www.wolframalpha.com/") + embed.colour = Colours.soft_orange + + await ImagePaginator.paginate(pages, ctx, embed) + + @wolfram_command.command(name="cut", aliases=("c",)) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_cut_command(self, ctx: Context, *, query: str) -> None: + """ + Requests a drawn image of given query. + + Keywords worth noting are, "like curve", "curve", "graph", "pokemon", etc. + """ + pages = await get_pod_pages(ctx, self.bot, query) + + if not pages: + return + + if len(pages) >= 2: + page = pages[1] + else: + page = pages[0] + + await send_embed(ctx, page[0], colour=Colours.soft_orange, img_url=page[1]) + + @wolfram_command.command(name="short", aliases=("sh", "s")) + @custom_cooldown(*STAFF_ROLES) + async def wolfram_short_command(self, ctx: Context, *, query: str) -> None: + """Requests an answer to a simple question.""" + url_str = parse.urlencode({ + "i": query, + "appid": APPID, + }) + query = QUERY.format(request="result", data=url_str) + + # Give feedback that the bot is working. + async with ctx.channel.typing(): + async with self.bot.http_session.get(query) as response: + status = response.status + response_text = await response.text() + + if status == 501: + message = "Failed to get response" + color = Colours.soft_red + elif status == 400: + message = "No input found" + color = Colours.soft_red + elif response_text == "Error 1: Invalid appid": + message = "Wolfram API key is invalid or missing." + color = Colours.soft_red + else: + message = response_text + color = Colours.soft_orange + + await send_embed(ctx, message, color) + + +def setup(bot: Bot) -> None: + """Load the Wolfram cog.""" + bot.add_cog(Wolfram(bot)) diff --git a/bot/exts/moderation/__init__.py b/bot/exts/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py new file mode 100644 index 000000000..b75a4dcfe --- /dev/null +++ b/bot/exts/moderation/defcon.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import logging +from collections import namedtuple +from datetime import datetime, timedelta +from enum import Enum + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles +from bot.decorators import with_role +from bot.exts.moderation.modlog import ModLog + +log = logging.getLogger(__name__) + +REJECTION_MESSAGE = """ +Hi, {user} - Thanks for your interest in our server! + +Due to a current (or detected) cyberattack on our community, we've limited access to the server for new accounts. Since +your account is relatively new, we're unable to provide access to the server at this time. + +Even so, thanks for joining! We're very excited at the possibility of having you here, and we hope that this situation +will be resolved soon. In the meantime, please feel free to peruse the resources on our site at +, and have a nice day! +""" + +BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism" + + +class Action(Enum): + """Defcon Action.""" + + ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template']) + + ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n") + DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "") + UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n") + + +class Defcon(Cog): + """Time-sensitive server defense mechanisms.""" + + days = None # type: timedelta + enabled = False # type: bool + + def __init__(self, bot: Bot): + self.bot = bot + self.channel = None + self.days = timedelta(days=0) + + self.bot.loop.create_task(self.sync_settings()) + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def sync_settings(self) -> None: + """On cog load, try to synchronize DEFCON settings to the API.""" + await self.bot.wait_until_guild_available() + self.channel = await self.bot.fetch_channel(Channels.defcon) + + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + except Exception: # Yikes! + log.exception("Unable to get DEFCON settings!") + await self.bot.get_channel(Channels.dev_log).send( + f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" + ) + + else: + if data["enabled"]: + self.enabled = True + self.days = timedelta(days=data["days"]) + log.info(f"DEFCON enabled: {self.days.days} days") + + else: + self.enabled = False + self.days = timedelta(days=0) + log.info("DEFCON disabled") + + await self.update_channel_topic() + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """If DEFCON is enabled, check newly joining users to see if they meet the account age threshold.""" + if self.enabled and self.days.days > 0: + now = datetime.utcnow() + + if now - member.created_at < self.days: + log.info(f"Rejecting user {member}: Account is too new and DEFCON is enabled") + + message_sent = False + + try: + await member.send(REJECTION_MESSAGE.format(user=member.mention)) + + message_sent = True + except Exception: + log.exception(f"Unable to send rejection message to user: {member}") + + await member.kick(reason="DEFCON active, user is too new") + self.bot.stats.incr("defcon.leaves") + + message = ( + f"{member} (`{member.id}`) was denied entry because their account is too new." + ) + + if not message_sent: + message = f"{message}\n\nUnable to send rejection message via DM; they probably have DMs disabled." + + await self.mod_log.send_log_message( + Icons.defcon_denied, Colours.soft_red, "Entry denied", + message, member.avatar_url_as(static_format="png") + ) + + @group(name='defcon', aliases=('dc',), invoke_without_command=True) + @with_role(Roles.admins, Roles.owners) + async def defcon_group(self, ctx: Context) -> None: + """Check the DEFCON status or run a subcommand.""" + await ctx.send_help(ctx.command) + + async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: + """Providing a structured way to do an defcon action.""" + try: + response = await self.bot.api_client.get('bot/bot-settings/defcon') + data = response['data'] + + if "enable_date" in data and action is Action.DISABLED: + enabled = datetime.fromisoformat(data["enable_date"]) + + delta = datetime.now() - enabled + + self.bot.stats.timing("defcon.enabled", delta) + except Exception: + pass + + error = None + try: + await self.bot.api_client.put( + 'bot/bot-settings/defcon', + json={ + 'name': 'defcon', + 'data': { + # TODO: retrieve old days count + 'days': days, + 'enabled': action is not Action.DISABLED, + 'enable_date': datetime.now().isoformat() + } + } + ) + except Exception as err: + log.exception("Unable to update DEFCON settings.") + error = err + finally: + await ctx.send(self.build_defcon_msg(action, error)) + await self.send_defcon_log(action, ctx.author, error) + + self.bot.stats.gauge("defcon.threshold", days) + + @defcon_group.command(name='enable', aliases=('on', 'e')) + @with_role(Roles.admins, Roles.owners) + async def enable_command(self, ctx: Context) -> None: + """ + Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! + + Currently, this just adds an account age requirement. Use !defcon days to set how old an account must be, + in days. + """ + self.enabled = True + await self._defcon_action(ctx, days=0, action=Action.ENABLED) + await self.update_channel_topic() + + @defcon_group.command(name='disable', aliases=('off', 'd')) + @with_role(Roles.admins, Roles.owners) + async def disable_command(self, ctx: Context) -> None: + """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" + self.enabled = False + await self._defcon_action(ctx, days=0, action=Action.DISABLED) + await self.update_channel_topic() + + @defcon_group.command(name='status', aliases=('s',)) + @with_role(Roles.admins, Roles.owners) + async def status_command(self, ctx: Context) -> None: + """Check the current status of DEFCON mode.""" + embed = Embed( + colour=Colour.blurple(), title="DEFCON Status", + description=f"**Enabled:** {self.enabled}\n" + f"**Days:** {self.days.days}" + ) + + await ctx.send(embed=embed) + + @defcon_group.command(name='days') + @with_role(Roles.admins, Roles.owners) + async def days_command(self, ctx: Context, days: int) -> None: + """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" + self.days = timedelta(days=days) + self.enabled = True + await self._defcon_action(ctx, days=days, action=Action.UPDATED) + await self.update_channel_topic() + + async def update_channel_topic(self) -> None: + """Update the #defcon channel topic with the current DEFCON status.""" + if self.enabled: + day_str = "days" if self.days.days > 1 else "day" + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Enabled, Threshold: {self.days.days} {day_str})" + else: + new_topic = f"{BASE_CHANNEL_TOPIC}\n(Status: Disabled)" + + self.mod_log.ignore(Event.guild_channel_update, Channels.defcon) + await self.channel.edit(topic=new_topic) + + def build_defcon_msg(self, action: Action, e: Exception = None) -> str: + """Build in-channel response string for DEFCON action.""" + if action is Action.ENABLED: + msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n" + elif action is Action.DISABLED: + msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n" + elif action is Action.UPDATED: + msg = ( + f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} " + f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n" + ) + + if e: + msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + return msg + + async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None: + """Send log message for DEFCON action.""" + info = action.value + log_msg: str = ( + f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n" + f"{info.template.format(days=self.days.days)}" + ) + status_msg = f"DEFCON {action.name.lower()}" + + if e: + log_msg += ( + "**There was a problem updating the site** - This setting may be reverted when the bot restarts.\n\n" + f"```py\n{e}\n```" + ) + + await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg) + + +def setup(bot: Bot) -> None: + """Load the Defcon cog.""" + bot.add_cog(Defcon(bot)) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py new file mode 100644 index 000000000..e49913552 --- /dev/null +++ b/bot/exts/moderation/incidents.py @@ -0,0 +1,412 @@ +import asyncio +import logging +import typing as t +from datetime import datetime +from enum import Enum + +import discord +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, Colours, Emojis, Guild, Webhooks +from bot.utils.messages import sub_clyde + +log = logging.getLogger(__name__) + +# Amount of messages for `crawl_task` to process at most on start-up - limited to 50 +# as in practice, there should never be this many messages, and if there are, +# something has likely gone very wrong +CRAWL_LIMIT = 50 + +# Seconds for `crawl_task` to sleep after adding reactions to a message +CRAWL_SLEEP = 2 + + +class Signal(Enum): + """ + Recognized incident status signals. + + This binds emoji to actions. The bot will only react to emoji linked here. + All other signals are seen as invalid. + """ + + ACTIONED = Emojis.incident_actioned + NOT_ACTIONED = Emojis.incident_unactioned + INVESTIGATING = Emojis.incident_investigating + + +# Reactions from non-mod roles will be removed +ALLOWED_ROLES: t.Set[int] = set(Guild.moderation_roles) + +# Message must have all of these emoji to pass the `has_signals` check +ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} + +# An embed coupled with an optional file to be dispatched +# If the file is not None, the embed attempts to show it in its body +FileEmbed = t.Tuple[discord.Embed, t.Optional[discord.File]] + + +async def download_file(attachment: discord.Attachment) -> t.Optional[discord.File]: + """ + Download & return `attachment` file. + + If the download fails, the reason is logged and None will be returned. + 404 and 403 errors are only logged at debug level. + """ + log.debug(f"Attempting to download attachment: {attachment.filename}") + try: + return await attachment.to_file() + except (discord.NotFound, discord.Forbidden) as exc: + log.debug(f"Failed to download attachment: {exc}") + except Exception: + log.exception("Failed to download attachment") + + +async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: + """ + Create an embed representation of `incident` for the #incidents-archive channel. + + The name & discriminator of `actioned_by` and `outcome` will be presented in the + embed footer. Additionally, the embed is coloured based on `outcome`. + + The author of `incident` is not shown in the embed. It is assumed that this piece + of information will be relayed in other ways, e.g. webhook username. + + As mentions in embeds do not ping, we do not need to use `incident.clean_content`. + + If `incident` contains attachments, the first attachment will be downloaded and + returned alongside the embed. The embed attempts to display the attachment. + Should the download fail, we fallback on linking the `proxy_url`, which should + remain functional for some time after the original message is deleted. + """ + log.trace(f"Creating embed for {incident.id=}") + + if outcome is Signal.ACTIONED: + colour = Colours.soft_green + footer = f"Actioned by {actioned_by}" + else: + colour = Colours.soft_red + footer = f"Rejected by {actioned_by}" + + embed = discord.Embed( + description=incident.content, + timestamp=datetime.utcnow(), + colour=colour, + ) + embed.set_footer(text=footer, icon_url=actioned_by.avatar_url) + + if incident.attachments: + attachment = incident.attachments[0] # User-sent messages can only contain one attachment + file = await download_file(attachment) + + if file is not None: + embed.set_image(url=f"attachment://{attachment.filename}") # Embed displays the attached file + else: + embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file + else: + file = None + + return embed, file + + +def is_incident(message: discord.Message) -> bool: + """True if `message` qualifies as an incident, False otherwise.""" + conditions = ( + message.channel.id == Channels.incidents, # Message sent in #incidents + not message.author.bot, # Not by a bot + not message.content.startswith("#"), # Doesn't start with a hash + not message.pinned, # And isn't header + ) + return all(conditions) + + +def own_reactions(message: discord.Message) -> t.Set[str]: + """Get the set of reactions placed on `message` by the bot itself.""" + return {str(reaction.emoji) for reaction in message.reactions if reaction.me} + + +def has_signals(message: discord.Message) -> bool: + """True if `message` already has all `Signal` reactions, False otherwise.""" + return ALL_SIGNALS.issubset(own_reactions(message)) + + +async def add_signals(incident: discord.Message) -> None: + """ + Add `Signal` member emoji to `incident` as reactions. + + If the emoji has already been placed on `incident` by the bot, it will be skipped. + """ + existing_reacts = own_reactions(incident) + + for signal_emoji in Signal: + if signal_emoji.value in existing_reacts: # This would not raise, but it is a superfluous API call + log.trace(f"Skipping emoji as it's already been placed: {signal_emoji}") + else: + log.trace(f"Adding reaction: {signal_emoji}") + await incident.add_reaction(signal_emoji.value) + + +class Incidents(Cog): + """ + Automation for the #incidents channel. + + This cog does not provide a command API, it only reacts to the following events. + + On start-up: + * Crawl #incidents and add missing `Signal` emoji where appropriate + * This is to retro-actively add the available options for messages which + were sent while the bot wasn't listening + * Pinned messages and message starting with # do not qualify as incidents + * See: `crawl_incidents` + + On message: + * Add `Signal` member emoji if message qualifies as an incident + * Ignore messages starting with # + * Use this if verbal communication is necessary + * Each such message must be deleted manually once appropriate + * See: `on_message` + + On reaction: + * Remove reaction if not permitted + * User does not have any of the roles in `ALLOWED_ROLES` + * Used emoji is not a `Signal` member + * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to + relay the incident message to #incidents-archive + * If relay successful, delete original message + * See: `on_raw_reaction_add` + + Please refer to function docstrings for implementation details. + """ + + def __init__(self, bot: Bot) -> None: + """Prepare `event_lock` and schedule `crawl_task` on start-up.""" + self.bot = bot + + self.event_lock = asyncio.Lock() + self.crawl_task = self.bot.loop.create_task(self.crawl_incidents()) + + async def crawl_incidents(self) -> None: + """ + Crawl #incidents and add missing emoji where necessary. + + This is to catch-up should an incident be reported while the bot wasn't listening. + After adding each reaction, we take a short break to avoid drowning in ratelimits. + + Once this task is scheduled, listeners that change messages should await it. + The crawl assumes that the channel history doesn't change as we go over it. + + Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. + """ + await self.bot.wait_until_guild_available() + incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) + + log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") + async for message in incidents.history(limit=CRAWL_LIMIT): + + if not is_incident(message): + log.trace(f"Skipping message {message.id}: not an incident") + continue + + if has_signals(message): + log.trace(f"Skipping message {message.id}: already has all signals") + continue + + await add_signals(message) + await asyncio.sleep(CRAWL_SLEEP) + + log.debug("Crawl task finished!") + + async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: + """ + Relay an embed representation of `incident` to the #incidents-archive channel. + + The following pieces of information are relayed: + * Incident message content (as embed description) + * Incident attachment (if image, shown in archive embed) + * Incident author name (as webhook author) + * Incident author avatar (as webhook avatar) + * Resolution signal `outcome` (as embed colour & footer) + * Moderator `actioned_by` (name & discriminator shown in footer) + + If `incident` contains an attachment, we try to add it to the archive embed. There is + no handing of extensions / file types - we simply dispatch the attachment file with the + webhook, and try to display it in the embed. Testing indicates that if the attachment + cannot be displayed (e.g. a text file), it's invisible in the embed, with no error. + + Return True if the relay finishes successfully. If anything goes wrong, meaning + not all information was relayed, return False. This signals that the original + message is not safe to be deleted, as we will lose some information. + """ + log.debug(f"Archiving incident: {incident.id} (outcome: {outcome}, actioned by: {actioned_by})") + embed, attachment_file = await make_embed(incident, outcome, actioned_by) + + try: + webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive) + await webhook.send( + embed=embed, + username=sub_clyde(incident.author.name), + avatar_url=incident.author.avatar_url, + file=attachment_file, + ) + except Exception: + log.exception(f"Failed to archive incident {incident.id} to #incidents-archive") + return False + else: + log.trace("Message archived successfully!") + return True + + def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: + """ + Create a task to wait `timeout` seconds for `incident` to be deleted. + + If `timeout` passes, this will raise `asyncio.TimeoutError`, signaling that we haven't + been able to confirm that the message was deleted. + """ + log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") + + def check(payload: discord.RawReactionActionEvent) -> bool: + return payload.message_id == incident.id + + coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) + return self.bot.loop.create_task(coroutine) + + async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: + """ + Process a `reaction_add` event in #incidents. + + First, we check that the reaction is a recognized `Signal` member, and that it was sent by + a permitted user (at least one role in `ALLOWED_ROLES`). If not, the reaction is removed. + + If the reaction was either `Signal.ACTIONED` or `Signal.NOT_ACTIONED`, we attempt to relay + the report to #incidents-archive. If successful, the original message is deleted. + + We do not release `event_lock` until we receive the corresponding `message_delete` event. + This ensures that if there is a racing event awaiting the lock, it will fail to find the + message, and will abort. There is a `timeout` to ensure that this doesn't hold the lock + forever should something go wrong. + """ + members_roles: t.Set[int] = {role.id for role in member.roles} + if not members_roles & ALLOWED_ROLES: # Intersection is truthy on at least 1 common element + log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") + await incident.remove_reaction(reaction, member) + return + + try: + signal = Signal(reaction) + except ValueError: + log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") + await incident.remove_reaction(reaction, member) + return + + log.trace(f"Received signal: {signal}") + + if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED): + log.debug("Reaction was valid, but no action is currently defined for it") + return + + relay_successful = await self.archive(incident, signal, actioned_by=member) + if not relay_successful: + log.trace("Original message will not be deleted as we failed to relay it to the archive") + return + + timeout = 5 # Seconds + confirmation_task = self.make_confirmation_task(incident, timeout) + + log.trace("Deleting original message") + await incident.delete() + + log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") + try: + await confirmation_task + except asyncio.TimeoutError: + log.warning(f"Did not receive incident deletion confirmation within {timeout} seconds!") + else: + log.trace("Deletion was confirmed") + + async def resolve_message(self, message_id: int) -> t.Optional[discord.Message]: + """ + Get `discord.Message` for `message_id` from cache, or API. + + We first look into the local cache to see if the message is present. + + If not, we try to fetch the message from the API. This is necessary for messages + which were sent before the bot's current session. + + In an edge-case, it is also possible that the message was already deleted, and + the API will respond with a 404. In such a case, None will be returned. + This signals that the event for `message_id` should be ignored. + """ + await self.bot.wait_until_guild_available() # First make sure that the cache is ready + log.trace(f"Resolving message for: {message_id=}") + message: t.Optional[discord.Message] = self.bot._connection._get_message(message_id) + + if message is not None: + log.trace("Message was found in cache") + return message + + log.trace("Message not found, attempting to fetch") + try: + message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) + except discord.NotFound: + log.trace("Message doesn't exist, it was likely already relayed") + except Exception: + log.exception(f"Failed to fetch message {message_id}!") + else: + log.trace("Message fetched successfully!") + return message + + @Cog.listener() + async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: + """ + Pre-process `payload` and pass it to `process_event` if appropriate. + + We abort instantly if `payload` doesn't relate to a message sent in #incidents, + or if it was sent by a bot. + + If `payload` relates to a message in #incidents, we first ensure that `crawl_task` has + finished, to make sure we don't mutate channel state as we're crawling it. + + Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. + + Once we have the lock, the `discord.Message` object for this event must be resolved. + If the lock was previously held by an event which successfully relayed the incident, + this will fail and we abort the current event. + + Finally, with both the lock and the `discord.Message` instance in our hands, we delegate + to `process_event` to handle the event. + + The justification for using a raw listener is the need to receive events for messages + which were not cached in the current session. As a result, a certain amount of + complexity is introduced, but at the moment this doesn't appear to be avoidable. + """ + if payload.channel_id != Channels.incidents or payload.member.bot: + return + + log.trace(f"Received reaction add event in #incidents, waiting for crawler: {self.crawl_task.done()=}") + await self.crawl_task + + log.trace(f"Acquiring event lock: {self.event_lock.locked()=}") + async with self.event_lock: + message = await self.resolve_message(payload.message_id) + + if message is None: + log.debug("Listener will abort as related message does not exist!") + return + + if not is_incident(message): + log.debug("Ignoring event for a non-incident message") + return + + await self.process_event(str(payload.emoji), message, payload.member) + log.trace("Releasing event lock") + + @Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Pass `message` to `add_signals` if and only if it satisfies `is_incident`.""" + if is_incident(message): + await add_signals(message) + + +def setup(bot: Bot) -> None: + """Load the Incidents cog.""" + bot.add_cog(Incidents(bot)) diff --git a/bot/exts/moderation/infraction/__init__.py b/bot/exts/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py new file mode 100644 index 000000000..1310fd3d9 --- /dev/null +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -0,0 +1,463 @@ +import logging +import textwrap +import typing as t +from abc import abstractmethod +from datetime import datetime +from gettext import ngettext + +import dateutil.parser +import discord +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Colours, STAFF_CHANNELS +from bot.exts.moderation.modlog import ModLog +from bot.utils import time +from bot.utils.scheduling import Scheduler +from . import _utils +from ._utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class InfractionScheduler: + """Handles the application, pardoning, and expiration of infractions.""" + + def __init__(self, bot: Bot, supported_infractions: t.Container[str]): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + @property + def mod_log(self) -> ModLog: + """Get the currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: + """Schedule expiration for previous infractions.""" + await self.bot.wait_until_guild_available() + + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + + infractions = await self.bot.api_client.get( + 'bot/infractions', + params={'active': 'true'} + ) + for infraction in infractions: + if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: + self.schedule_expiration(infraction) + + async def reapply_infraction( + self, + infraction: _utils.Infraction, + apply_coro: t.Optional[t.Awaitable] + ) -> None: + """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" + # Calculate the time remaining, in seconds, for the mute. + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + delta = (expiry - datetime.utcnow()).total_seconds() + + # Mark as inactive if less than a minute remains. + if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) + await self.deactivate_infraction(infraction) + return + + # Allowing mod log since this is a passive action that should be logged. + await apply_coro + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + + async def apply_infraction( + self, + ctx: Context, + infraction: _utils.Infraction, + user: UserSnowflake, + action_coro: t.Optional[t.Awaitable] = None + ) -> None: + """Apply an infraction to the user, log the infraction, and optionally notify the user.""" + infr_type = infraction["type"] + icon = _utils.INFRACTION_ICONS[infr_type][0] + reason = infraction["reason"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") + + # Default values for the confirmation message and mod log. + confirm_msg = ":ok_hand: applied" + + # Specifying an expiry for a note or warning makes no sense. + if infr_type in ("note", "warning"): + expiry_msg = "" + else: + expiry_msg = f" until {expiry}" if expiry else " permanently" + + dm_result = "" + dm_log_text = "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" + log_title = "applied" + log_content = None + failed = False + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # This needs to happen before we apply the infraction, as the bot cannot + # send DMs to user that it doesn't share a guild with. If we were to + # apply kick/ban infractions first, this would mean that we'd make it + # impossible for us to deliver a DM. See python-discord/bot#982. + if not infraction["hidden"]: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + end_msg = "" + if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + else: + log.trace(f"Fetching total infraction count for {user}.") + + infractions = await self.bot.api_client.get( + "bot/infractions", + params={"user__id": str(user.id)} + ) + total = len(infractions) + end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" + + # Execute the necessary actions to apply the infraction on Discord. + if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") + try: + await action_coro + if expiry: + # Schedule the expiration of the infraction. + self.schedule_expiration(infraction) + except discord.HTTPException as e: + # Accordingly display that applying the infraction failed. + confirm_msg = ":x: failed to apply" + expiry_msg = "" + log_content = ctx.author.mention + log_title = "failed to apply" + + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + failed = True + + if failed: + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" + + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") + + # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=icon, + colour=Colours.soft_red, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {user.mention} (`{user.id}`) + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} + Reason: {reason} + """), + content=log_content, + footer=f"ID {infraction['id']}" + ) + + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + + async def pardon_infraction( + self, + ctx: Context, + infr_type: str, + user: UserSnowflake, + send_msg: bool = True + ) -> None: + """ + Prematurely end an infraction for a user and log the action in the mod log. + + If `send_msg` is True, then a pardoning confirmation message will be sent to + the context channel. Otherwise, no such message will be sent. + """ + log.trace(f"Pardoning {infr_type} infraction for {user}.") + + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") + response = await self.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': user.id + } + ) + + if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") + await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") + return + + # Deactivate the infraction and cancel its scheduled expiration task. + log_text = await self.deactivate_infraction(response[0], send_log=False) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["Actor"] = str(ctx.message.author) + log_content = None + id_ = response[0]['id'] + footer = f"ID: {id_}" + + # If multiple active infractions were found, mark them as inactive in the database + # and cancel their expiration tasks. + if len(response) > 1: + log.info( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) + + footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" + + log_note = f"Found multiple **active** {infr_type} infractions in the database." + if "Note" in log_text: + log_text["Note"] = f" {log_note}" + else: + log_text["Note"] = log_note + + # deactivate_infraction() is not called again because: + # 1. Discord cannot store multiple active bans or assign multiples of the same role + # 2. It would send a pardon DM for each active infraction, which is redundant + for infraction in response[1:]: + id_ = infraction['id'] + try: + # Mark infraction as inactive in the database. + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError: + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") + # This is simpler and cleaner than trying to concatenate all the errors. + log_text["Failure"] = "See bot's logs for details." + + # Cancel pending expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Accordingly display whether the user was successfully notified via DM. + dm_emoji = "" + if log_text.get("DM") == "Sent": + dm_emoji = ":incoming_envelope: " + elif "DM" in log_text: + dm_emoji = f"{constants.Emojis.failmail} " + + # Accordingly display whether the pardon failed. + if "Failure" in log_text: + confirm_msg = ":x: failed to pardon" + log_title = "pardon failed" + log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") + else: + confirm_msg = ":ok_hand: pardoned" + log_title = "pardoned" + + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + + # Send a confirmation message to the invoking context. + if send_msg: + log.trace(f"Sending infraction #{id_} pardon confirmation message.") + await ctx.send( + f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " + f"{log_text.get('Failure', '')}" + ) + + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + + # Send a log message to the mod log. + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[infr_type][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {infr_type}", + thumbnail=user.avatar_url_as(static_format="png"), + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=footer, + content=log_content, + ) + + async def deactivate_infraction( + self, + infraction: _utils.Infraction, + send_log: bool = True + ) -> t.Dict[str, str]: + """ + Deactivate an active infraction and return a dictionary of lines to send in a mod log. + + The infraction is removed from Discord, marked as inactive in the database, and has its + expiration task cancelled. If `send_log` is True, a mod log is sent for the + deactivation of the infraction. + + Infractions of unsupported types will raise a ValueError. + """ + guild = self.bot.get_guild(constants.Guild.id) + mod_role = guild.get_role(constants.Roles.moderators) + user_id = infraction["user"] + actor = infraction["actor"] + type_ = infraction["type"] + id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] + + log.info(f"Marking infraction #{id_} as inactive (expired).") + + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + + log_content = None + log_text = { + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, + } + + try: + log.trace("Awaiting the pardon action coroutine.") + returned_log = await self._pardon_action(infraction) + + if returned_log is not None: + log_text = {**log_text, **returned_log} # Merge the logs together + else: + raise ValueError( + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" + ) + except discord.Forbidden: + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") + log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" + log_content = mod_role.mention + except discord.HTTPException as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." + log_content = mod_role.mention + + # Check if the user is currently being watched by Big Brother. + try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + + active_watch = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "watch", + "user__id": user_id + } + ) + + log_text["Watching"] = "Yes" if active_watch else "No" + except ResponseCodeError: + log.exception(f"Failed to fetch watch status for user {user_id}") + log_text["Watching"] = "Unknown - failed to fetch watch status." + + try: + # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") + await self.bot.api_client.patch( + f"bot/infractions/{id_}", + json={"active": False} + ) + except ResponseCodeError as e: + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_line = f"API request failed with code {e.status}." + log_content = mod_role.mention + + # Append to an existing failure message if possible + if "Failure" in log_text: + log_text["Failure"] += f" {log_line}" + else: + log_text["Failure"] = log_line + + # Cancel the expiration task. + if infraction["expires_at"] is not None: + self.scheduler.cancel(infraction["id"]) + + # Send a log message to the mod log. + if send_log: + log_title = "expiration failed" if "Failure" in log_text else "expired" + + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + + log.trace(f"Sending deactivation mod log for infraction #{id_}.") + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS[type_][1], + colour=Colours.soft_green, + title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, + text="\n".join(f"{k}: {v}" for k, v in log_text.items()), + footer=f"ID: {id_}", + content=log_content, + ) + + return log_text + + @abstractmethod + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + raise NotImplementedError + + def schedule_expiration(self, infraction: _utils.Infraction) -> None: + """ + Marks an infraction expired after the delay from time of scheduling to time of expiration. + + At the time of expiration, the infraction is marked as inactive on the website and the + expiration task is cancelled. + """ + expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py new file mode 100644 index 000000000..fb55287b6 --- /dev/null +++ b/bot/exts/moderation/infraction/_utils.py @@ -0,0 +1,201 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext.commands import Context + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons + +log = logging.getLogger(__name__) + +# apply icon, pardon icon +INFRACTION_ICONS = { + "ban": (Icons.user_ban, Icons.user_unban), + "kick": (Icons.sign_out, None), + "mute": (Icons.user_mute, Icons.user_unmute), + "note": (Icons.user_warn, None), + "superstar": (Icons.superstarify, Icons.unsuperstarify), + "warning": (Icons.user_warn, None), +} +RULES_URL = "https://pythondiscord.com/pages/rules" +APPEALABLE_INFRACTIONS = ("ban", "mute") + +# Type aliases +UserObject = t.Union[discord.Member, discord.User] +UserSnowflake = t.Union[UserObject, discord.Object] +Infraction = t.Dict[str, t.Union[str, int, bool]] + + +async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: + """ + Create a new user in the database. + + Used when an infraction needs to be applied on a user absent in the guild. + """ + log.trace(f"Attempting to add user {user.id} to the database.") + + if not isinstance(user, (discord.Member, discord.User)): + log.debug("The user being added to the DB is not a Member or User object.") + + payload = { + 'discriminator': int(getattr(user, 'discriminator', 0)), + 'id': user.id, + 'in_guild': False, + 'name': getattr(user, 'name', 'Name unknown'), + 'roles': [] + } + + try: + response = await ctx.bot.api_client.post('bot/users', json=payload) + log.info(f"User {user.id} added to the DB.") + return response + except ResponseCodeError as e: + log.error(f"Failed to add user {user.id} to the DB. {e}") + await ctx.send(f":x: The attempt to add the user to the DB failed: status {e.status}") + + +async def post_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + reason: str, + expires_at: datetime = None, + hidden: bool = False, + active: bool = True +) -> t.Optional[dict]: + """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + + payload = { + "actor": ctx.message.author.id, + "hidden": hidden, + "reason": reason, + "type": infr_type, + "user": user.id, + "active": active + } + if expires_at: + payload['expires_at'] = expires_at.isoformat() + + # Try to apply the infraction. If it fails because the user doesn't exist, try to add it. + for should_post_user in (True, False): + try: + response = await ctx.bot.api_client.post('bot/infractions', json=payload) + return response + except ResponseCodeError as e: + if e.status == 400 and 'user' in e.response_json: + # Only one attempt to add the user to the database, not two: + if not should_post_user or await post_user(ctx, user) is None: + return + else: + log.exception(f"Unexpected error while adding an infraction for {user}:") + await ctx.send(f":x: There was an error adding the infraction: status {e.status}.") + return + + +async def get_active_infraction( + ctx: Context, + user: UserSnowflake, + infr_type: str, + send_msg: bool = True +) -> t.Optional[dict]: + """ + Retrieves an active infraction of the given type for the user. + + If `send_msg` is True and the user has an active infraction matching the `infr_type` parameter, + then a message for the moderator will be sent to the context channel letting them know. + Otherwise, no message will be sent. + """ + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + + active_infractions = await ctx.bot.api_client.get( + 'bot/infractions', + params={ + 'active': 'true', + 'type': infr_type, + 'user__id': str(user.id) + } + ) + if active_infractions: + # Checks to see if the moderator should be told there is an active infraction + if send_msg: + log.trace(f"{user} has active infractions of type {infr_type}.") + await ctx.send( + f":x: According to my records, this user already has a {infr_type} infraction. " + f"See infraction **#{active_infractions[0]['id']}**." + ) + return active_infractions[0] + else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") + + +async def notify_infraction( + user: UserObject, + infr_type: str, + expires_at: t.Optional[str] = None, + reason: t.Optional[str] = None, + icon_url: str = Icons.token_removed +) -> bool: + """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + + embed = discord.Embed( + description=textwrap.shorten(text, width=2048, placeholder="..."), + colour=Colours.soft_red + ) + + embed.set_author(name="Infraction information", icon_url=icon_url, url=RULES_URL) + embed.title = f"Please review our rules over at {RULES_URL}" + embed.url = RULES_URL + + if infr_type in APPEALABLE_INFRACTIONS: + embed.set_footer( + text="To appeal this infraction, send an e-mail to appeals@pythondiscord.com" + ) + + return await send_private_embed(user, embed) + + +async def notify_pardon( + user: UserObject, + title: str, + content: str, + icon_url: str = Icons.user_verified +) -> bool: + """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + + embed = discord.Embed( + description=content, + colour=Colours.soft_green + ) + + embed.set_author(name=title, icon_url=icon_url) + + return await send_private_embed(user, embed) + + +async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: + """ + A helper method for sending an embed to a user's DMs. + + Returns a boolean indicator of DM success. + """ + try: + await user.send(embed=embed) + return True + except (discord.HTTPException, discord.Forbidden, discord.NotFound): + log.debug( + f"Infraction-related information could not be sent to user {user} ({user.id}). " + "The user either could not be retrieved or probably disabled their DMs." + ) + return False diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py new file mode 100644 index 000000000..cb459b447 --- /dev/null +++ b/bot/exts/moderation/infraction/infractions.py @@ -0,0 +1,375 @@ +import logging +import textwrap +import typing as t + +import discord +from discord import Member +from discord.ext import commands +from discord.ext.commands import Context, command + +from bot import constants +from bot.bot import Bot +from bot.constants import Event +from bot.converters import Expiry, FetchedMember +from bot.decorators import respect_role_hierarchy +from bot.utils.checks import with_role_check +from . import _utils +from ._scheduler import InfractionScheduler +from ._utils import UserSnowflake + +log = logging.getLogger(__name__) + + +class Infractions(InfractionScheduler, commands.Cog): + """Apply and pardon infractions on users for moderation purposes.""" + + category = "Moderation" + category_description = "Server moderation tools." + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) + + self.category = "Moderation" + self._muted_role = discord.Object(constants.Roles.muted) + + @commands.Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active mute infractions for returning members.""" + active_mutes = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "mute", + "user__id": member.id + } + ) + + if active_mutes: + reason = f"Re-applying active mute: {active_mutes[0]['id']}" + action = member.add_roles(self._muted_role, reason=reason) + + await self.reapply_infraction(active_mutes[0], action) + + # region: Permanent infractions + + @command() + async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Warn a user for the given reason.""" + infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command() + async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason.""" + await self.apply_kick(ctx, user, reason) + + @command() + async def ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason and stop watching them with Big Brother.""" + await self.apply_ban(ctx, user, reason) + + # endregion + # region: Temporary infractions + + @command(aliases=["mute"]) + async def tempmute(self, ctx: Context, user: Member, duration: Expiry, *, reason: t.Optional[str] = None) -> None: + """ + Temporarily mute a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration) + + @command() + async def tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration) + + # endregion + # region: Permanent shadow infractions + + @command(hidden=True) + async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Create a private note for a user with the given reason without notifying the user.""" + infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False) + if infraction is None: + return + + await self.apply_infraction(ctx, infraction, user) + + @command(hidden=True, aliases=['shadowkick', 'skick']) + async def shadow_kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + """Kick a user for the given reason without notifying the user.""" + await self.apply_kick(ctx, user, reason, hidden=True) + + @command(hidden=True, aliases=['shadowban', 'sban']) + async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: + """Permanently ban a user for the given reason without notifying the user.""" + await self.apply_ban(ctx, user, reason, hidden=True) + + # endregion + # region: Temporary shadow infractions + + @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) + async def shadow_tempmute( + self, ctx: Context, + user: Member, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily mute a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True) + + @command(hidden=True, aliases=["shadowtempban, stempban"]) + async def shadow_tempban( + self, + ctx: Context, + user: FetchedMember, + duration: Expiry, + *, + reason: t.Optional[str] = None + ) -> None: + """ + Temporarily ban a user for the given reason and duration without notifying the user. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + """ + await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) + + # endregion + # region: Remove infractions (un- commands) + + @command() + async def unmute(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active mute infraction for the user.""" + await self.pardon_infraction(ctx, "mute", user) + + @command() + async def unban(self, ctx: Context, user: FetchedMember) -> None: + """Prematurely end the active ban infraction for the user.""" + await self.pardon_infraction(ctx, "ban", user) + + # endregion + # region: Base apply functions + + async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a mute infraction with kwargs passed to `post_infraction`.""" + if await _utils.get_active_infraction(ctx, user, "mute"): + return + + infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_update, user.id) + + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) + + @respect_role_hierarchy() + async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: + """Apply a kick infraction with kwargs passed to `post_infraction`.""" + infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = user.kick(reason=reason) + await self.apply_infraction(ctx, infraction, user, action) + + @respect_role_hierarchy() + async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: + """ + Apply a ban infraction with kwargs passed to `post_infraction`. + + Will also remove the banned user from the Big Brother watch list if applicable. + """ + # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active + is_temporary = kwargs.get("expires_at") is not None + active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary) + + if active_infraction: + if is_temporary: + log.trace("Tempban ignored as it cannot overwrite an active ban.") + return + + if active_infraction.get('expires_at') is None: + log.trace("Permaban already exists, notify.") + await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).") + return + + log.trace("Old tempban is being replaced by new permaban.") + await self.pardon_infraction(ctx, "ban", user, is_temporary) + + infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) + if infraction is None: + return + + self.mod_log.ignore(Event.member_remove, user.id) + + if reason: + reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = ctx.guild.ban(user, reason=reason, delete_message_days=0) + await self.apply_infraction(ctx, infraction, user, action) + + if infraction.get('expires_at') is not None: + log.trace(f"Ban isn't permanent; user {user} won't be unwatched by Big Brother.") + return + + bb_cog = self.bot.get_cog("Big Brother") + if not bb_cog: + log.error(f"Big Brother cog not loaded; perma-banned user {user} won't be unwatched.") + return + + log.trace(f"Big Brother cog loaded; attempting to unwatch perma-banned user {user}.") + + bb_reason = "User has been permanently banned from the server. Automatically removed." + await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) + + # endregion + # region: Base pardon functions + + async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's muted role, DM them a notification, and return a log dict.""" + user = guild.get_member(user_id) + log_text = {} + + if user: + # Remove the muted role. + self.mod_log.ignore(Event.member_update, user.id) + await user.remove_roles(self._muted_role, reason=reason) + + # DM the user about the expiration. + notified = await _utils.notify_pardon( + user=user, + title="You have been unmuted", + content="You may now send messages in the server.", + icon_url=_utils.INFRACTION_ICONS["mute"][1] + ) + + log_text["Member"] = f"{user.mention}(`{user.id}`)" + log_text["DM"] = "Sent" if notified else "**Failed**" + else: + log.info(f"Failed to unmute user {user_id}: user not found") + log_text["Failure"] = "User was not found in the guild." + + return log_text + + async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + """Remove a user's ban on the Discord guild and return a log dict.""" + user = discord.Object(user_id) + log_text = {} + + self.mod_log.ignore(Event.member_unban, user_id) + + try: + await guild.unban(user, reason=reason) + except discord.NotFound: + log.info(f"Failed to unban user {user_id}: no active ban found on Discord") + log_text["Note"] = "No active ban found on Discord." + + return log_text + + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """ + Execute deactivation steps specific to the infraction's type and return a log dict. + + If an infraction type is unsupported, return None instead. + """ + guild = self.bot.get_guild(constants.Guild.id) + user_id = infraction["user"] + reason = f"Infraction #{infraction['id']} expired or was pardoned." + + if infraction["type"] == "mute": + return await self.pardon_mute(user_id, guild, reason) + elif infraction["type"] == "ban": + return await self.pardon_ban(user_id, guild, reason) + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters or discord.Member in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Infractions cog.""" + bot.add_cog(Infractions(bot)) diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py new file mode 100644 index 000000000..eea6ac9ea --- /dev/null +++ b/bot/exts/moderation/infraction/management.py @@ -0,0 +1,310 @@ +import logging +import textwrap +import typing as t +from datetime import datetime + +import discord +from discord.ext import commands +from discord.ext.commands import Context + +from bot import constants +from bot.bot import Bot +from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user +from bot.exts.moderation.modlog import ModLog +from bot.pagination import LinePaginator +from bot.utils import time +from bot.utils.checks import in_whitelist_check, with_role_check +from . import _utils +from .infractions import Infractions + +log = logging.getLogger(__name__) + + +class ModManagement(commands.Cog): + """Management of infractions.""" + + category = "Moderation" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @property + def infractions_cog(self) -> Infractions: + """Get currently loaded Infractions cog instance.""" + return self.bot.get_cog("Infractions") + + # region: Edit infraction commands + + @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) + async def infraction_group(self, ctx: Context) -> None: + """Infraction manipulation commands.""" + await ctx.send_help(ctx.command) + + @infraction_group.command(name='edit') + async def infraction_edit( + self, + ctx: Context, + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821 + duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821 + *, + reason: str = None + ) -> None: + """ + Edit the duration and/or the reason of an infraction. + + Durations are relative to the time of updating and should be appended with a unit of time. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. + """ + if duration is None and reason is None: + # Unlike UserInputError, the error handler will show a specified message for BadArgument + raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") + + # Retrieve the previous infraction for its information. + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get("bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + ":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") + + request_data = {} + confirm_messages = [] + log_text = "" + + if duration is not None and not old_infraction['active']: + if reason is None: + await ctx.send(":x: Cannot edit the expiration of an expired infraction.") + return + confirm_messages.append("expiry unchanged (infraction already expired)") + elif isinstance(duration, str): + request_data['expires_at'] = None + confirm_messages.append("marked as permanent") + elif duration is not None: + request_data['expires_at'] = duration.isoformat() + expiry = time.format_infraction_with_duration(request_data['expires_at']) + confirm_messages.append(f"set to expire on {expiry}") + else: + confirm_messages.append("expiry unchanged") + + if reason: + request_data['reason'] = reason + confirm_messages.append("set a new reason") + log_text += f""" + Previous reason: {old_infraction['reason']} + New reason: {reason} + """.rstrip() + else: + confirm_messages.append("reason unchanged") + + # Update the infraction + new_infraction = await self.bot.api_client.patch( + f'bot/infractions/{infraction_id}', + json=request_data, + ) + + # Re-schedule infraction if the expiration has been updated + if 'expires_at' in request_data: + # A scheduled task should only exist if the old infraction wasn't permanent + if old_infraction['expires_at']: + self.infractions_cog.scheduler.cancel(new_infraction['id']) + + # If the infraction was not marked as permanent, schedule a new expiration task + if request_data['expires_at']: + self.infractions_cog.schedule_expiration(new_infraction) + + log_text += f""" + Previous expiry: {old_infraction['expires_at'] or "Permanent"} + New expiry: {new_infraction['expires_at'] or "Permanent"} + """.rstrip() + + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") + + # Get information about the infraction's user + user_id = new_infraction['user'] + user = ctx.guild.get_member(user_id) + + if user: + user_text = f"{user.mention} (`{user.id}`)" + thumbnail = user.avatar_url_as(static_format="png") + else: + user_text = f"`{user_id}`" + thumbnail = None + + # The infraction's actor + actor_id = new_infraction['actor'] + actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`" + + await self.mod_log.send_log_message( + icon_url=constants.Icons.pencil, + colour=discord.Colour.blurple(), + title="Infraction edited", + thumbnail=thumbnail, + text=textwrap.dedent(f""" + Member: {user_text} + Actor: {actor} + Edited by: {ctx.message.author}{log_text} + """) + ) + + # endregion + # region: Search infractions + + @infraction_group.group(name="search", invoke_without_command=True) + async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None: + """Searches for infractions in the database.""" + if isinstance(query, discord.User): + await ctx.invoke(self.search_user, query) + else: + await ctx.invoke(self.search_reason, query) + + @infraction_search_group.command(name="user", aliases=("member", "id")) + async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: + """Search for infractions by member.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'user__id': str(user.id)} + ) + embed = discord.Embed( + title=f"Infractions for {user} ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + @infraction_search_group.command(name="reason", aliases=("match", "regex", "re")) + async def search_reason(self, ctx: Context, reason: str) -> None: + """Search for infractions by their reason. Use Re2 for matching.""" + infraction_list = await self.bot.api_client.get( + 'bot/infractions', + params={'search': reason} + ) + embed = discord.Embed( + title=f"Infractions matching `{reason}` ({len(infraction_list)} total)", + colour=discord.Colour.orange() + ) + await self.send_infraction_list(ctx, embed, infraction_list) + + # endregion + # region: Utility functions + + async def send_infraction_list( + self, + ctx: Context, + embed: discord.Embed, + infractions: t.Iterable[_utils.Infraction] + ) -> None: + """Send a paginated embed of infractions for the specified user.""" + if not infractions: + await ctx.send(":warning: No infractions could be found for that query.") + return + + lines = tuple( + self.infraction_to_string(infraction) + for infraction in infractions + ) + + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + def infraction_to_string(self, infraction: _utils.Infraction) -> str: + """Convert the infraction object to a string representation.""" + actor_id = infraction["actor"] + guild = self.bot.get_guild(constants.Guild.id) + actor = guild.get_member(actor_id) + active = infraction["active"] + user_id = infraction["user"] + hidden = infraction["hidden"] + created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + + if infraction["expires_at"] is None: + expires = "*Permanent*" + else: + date_from = datetime.strptime(created, time.INFRACTION_FORMAT) + expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) + + lines = textwrap.dedent(f""" + {"**===============**" if active else "==============="} + Status: {"__**Active**__" if active else "Inactive"} + User: {self.bot.get_user(user_id)} (`{user_id}`) + Type: **{infraction["type"]}** + Shadow: {hidden} + Created: {created} + Expires: {expires} + Remaining: {remaining} + Actor: {actor.mention if actor else actor_id} + ID: `{infraction["id"]}` + Reason: {infraction["reason"] or "*None*"} + {"**===============**" if active else "==============="} + """) + + return lines.strip() + + # endregion + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators inside moderator channels to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=constants.MODERATION_CHANNELS, + categories=[constants.Categories.modmail], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Send a notification to the invoking context on a Union failure.""" + if isinstance(error, commands.BadUnionArgument): + if discord.User in error.converters: + await ctx.send(str(error.errors[0])) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the ModManagement cog.""" + bot.add_cog(ModManagement(bot)) diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py new file mode 100644 index 000000000..7dc5b4691 --- /dev/null +++ b/bot/exts/moderation/infraction/superstarify.py @@ -0,0 +1,244 @@ +import json +import logging +import random +import textwrap +import typing as t +from pathlib import Path + +from discord import Colour, Embed, Member +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.converters import Expiry +from bot.utils.checks import with_role_check +from bot.utils.time import format_infraction +from . import _utils +from ._scheduler import InfractionScheduler + +log = logging.getLogger(__name__) +NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" + +with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file: + STAR_NAMES = json.load(stars_file) + + +class Superstarify(InfractionScheduler, Cog): + """A set of commands to moderate terrible nicknames.""" + + def __init__(self, bot: Bot): + super().__init__(bot, supported_infractions={"superstar"}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Revert nickname edits if the user has an active superstarify infraction.""" + if before.display_name == after.display_name: + return # User didn't change their nickname. Abort! + + log.trace( + f"{before} ({before.display_name}) is trying to change their nickname to " + f"{after.display_name}. Checking if the user is in superstar-prison..." + ) + + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": str(before.id) + } + ) + + if not active_superstarifies: + log.trace(f"{before} has no active superstar infractions.") + return + + infraction = active_superstarifies[0] + forced_nick = self.get_nick(infraction["id"], before.id) + if after.display_name == forced_nick: + return # Nick change was triggered by this event. Ignore. + + log.info( + f"{after.display_name} ({after.id}) tried to escape superstar prison. " + f"Changing the nick back to {before.display_name}." + ) + await after.edit( + nick=forced_nick, + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + notified = await _utils.notify_infraction( + user=after, + infr_type="Superstarify", + expires_at=format_infraction(infraction["expires_at"]), + reason=( + "You have tried to change your nickname on the **Python Discord** server " + f"from **{before.display_name}** to **{after.display_name}**, but as you " + "are currently in superstar-prison, you do not have permission to do so." + ), + icon_url=_utils.INFRACTION_ICONS["superstar"][0] + ) + + if not notified: + log.info("Failed to DM user about why they cannot change their nickname.") + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """Reapply active superstar infractions for returning members.""" + active_superstarifies = await self.bot.api_client.get( + "bot/infractions", + params={ + "active": "true", + "type": "superstar", + "user__id": member.id + } + ) + + if active_superstarifies: + infraction = active_superstarifies[0] + action = member.edit( + nick=self.get_nick(infraction["id"], member.id), + reason=f"Superstarified member tried to escape the prison: {infraction['id']}" + ) + + await self.reapply_infraction(infraction, action) + + @command(name="superstarify", aliases=("force_nick", "star")) + async def superstarify( + self, + ctx: Context, + member: Member, + duration: Expiry, + *, + reason: str = None, + ) -> None: + """ + Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. + + A unit of time should be appended to the duration. + Units (∗case-sensitive): + \u2003`y` - years + \u2003`m` - months∗ + \u2003`w` - weeks + \u2003`d` - days + \u2003`h` - hours + \u2003`M` - minutes∗ + \u2003`s` - seconds + + Alternatively, an ISO 8601 timestamp can be provided for the duration. + + An optional reason can be provided. If no reason is given, the original name will be shown + in a generated reason. + """ + if await _utils.get_active_infraction(ctx, member, "superstar"): + return + + # Post the infraction to the API + reason = reason or f"old nick: {member.display_name}" + infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True) + id_ = infraction["id"] + + old_nick = member.display_name + forced_nick = self.get_nick(id_, member.id) + expiry_str = format_infraction(infraction["expires_at"]) + + # Apply the infraction and schedule the expiration task. + log.debug(f"Changing nickname of {member} to {forced_nick}.") + self.mod_log.ignore(constants.Event.member_update, member.id) + await member.edit(nick=forced_nick, reason=reason) + self.schedule_expiration(infraction) + + # Send a DM to the user to notify them of their new infraction. + await _utils.notify_infraction( + user=member, + infr_type="Superstarify", + expires_at=expiry_str, + icon_url=_utils.INFRACTION_ICONS["superstar"][0], + reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})." + ) + + # Send an embed with the infraction information to the invoking context. + log.trace(f"Sending superstar #{id_} embed.") + embed = Embed( + title="Congratulations!", + colour=constants.Colours.soft_orange, + description=( + f"Your previous nickname, **{old_nick}**, " + f"was so bad that we have decided to change it. " + f"Your new nickname will be **{forced_nick}**.\n\n" + f"You will be unable to change your nickname until **{expiry_str}**.\n\n" + "If you're confused by this, please read our " + f"[official nickname policy]({NICKNAME_POLICY_URL})." + ) + ) + await ctx.send(embed=embed) + + # Log to the mod log channel. + log.trace(f"Sending apply mod log for superstar #{id_}.") + await self.mod_log.send_log_message( + icon_url=_utils.INFRACTION_ICONS["superstar"][0], + colour=Colour.gold(), + title="Member achieved superstardom", + thumbnail=member.avatar_url_as(static_format="png"), + text=textwrap.dedent(f""" + Member: {member.mention} (`{member.id}`) + Actor: {ctx.message.author} + Expires: {expiry_str} + Old nickname: `{old_nick}` + New nickname: `{forced_nick}` + Reason: {reason} + """), + footer=f"ID {id_}" + ) + + @command(name="unsuperstarify", aliases=("release_nick", "unstar")) + async def unsuperstarify(self, ctx: Context, member: Member) -> None: + """Remove the superstarify infraction and allow the user to change their nickname.""" + await self.pardon_infraction(ctx, "superstar", member) + + async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + """Pardon a superstar infraction and return a log dict.""" + if infraction["type"] != "superstar": + return + + guild = self.bot.get_guild(constants.Guild.id) + user = guild.get_member(infraction["user"]) + + # Don't bother sending a notification if the user left the guild. + if not user: + log.debug( + "User left the guild and therefore won't be notified about superstar " + f"{infraction['id']} pardon." + ) + return {} + + # DM the user about the expiration. + notified = await _utils.notify_pardon( + user=user, + title="You are no longer superstarified", + content="You may now change your nickname on the server.", + icon_url=_utils.INFRACTION_ICONS["superstar"][1] + ) + + return { + "Member": f"{user.mention}(`{user.id}`)", + "DM": "Sent" if notified else "**Failed**" + } + + @staticmethod + def get_nick(infraction_id: int, member_id: int) -> str: + """Randomly select a nickname from the Superstarify nickname list.""" + log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") + + rng = random.Random(str(infraction_id) + str(member_id)) + return rng.choice(STAR_NAMES) + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Superstarify cog.""" + bot.add_cog(Superstarify(bot)) diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py new file mode 100644 index 000000000..c86f04b9d --- /dev/null +++ b/bot/exts/moderation/modlog.py @@ -0,0 +1,837 @@ +import asyncio +import difflib +import itertools +import logging +import typing as t +from datetime import datetime +from itertools import zip_longest + +import discord +from dateutil.relativedelta import relativedelta +from deepdiff import DeepDiff +from discord import Colour +from discord.abc import GuildChannel +from discord.ext.commands import Cog, Context +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] + +CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) +CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") +ROLE_CHANGES_UNSUPPORTED = ("colour", "permissions") + +VOICE_STATE_ATTRIBUTES = { + "channel.name": "Channel", + "self_stream": "Streaming", + "self_video": "Broadcasting", +} + + +class ModLog(Cog, name="ModLog"): + """Logging for server events and staff actions.""" + + def __init__(self, bot: Bot): + self.bot = bot + self._ignored = {event: [] for event in Event} + + self._cached_deletes = [] + self._cached_edits = [] + + async def upload_log( + self, + messages: t.Iterable[discord.Message], + actor_id: int, + attachments: t.Iterable[t.List[str]] = None + ) -> str: + """Upload message logs to the database and return a URL to a page for viewing the logs.""" + if attachments is None: + attachments = [] + + response = await self.bot.api_client.post( + 'bot/deleted-messages', + json={ + 'actor': actor_id, + 'creation': datetime.utcnow().isoformat(), + 'deletedmessage_set': [ + { + 'id': message.id, + 'author': message.author.id, + 'channel_id': message.channel.id, + 'content': message.content, + 'embeds': [embed.to_dict() for embed in message.embeds], + 'attachments': attachment, + } + for message, attachment in zip_longest(messages, attachments, fillvalue=[]) + ] + } + ) + + return f"{URLs.site_logs_view}/{response['id']}" + + def ignore(self, event: Event, *items: int) -> None: + """Add event to ignored events to suppress log emission.""" + for item in items: + if item not in self._ignored[event]: + self._ignored[event].append(item) + + async def send_log_message( + self, + icon_url: t.Optional[str], + colour: t.Union[discord.Colour, int], + title: t.Optional[str], + text: str, + thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, + channel_id: int = Channels.mod_log, + ping_everyone: bool = False, + files: t.Optional[t.List[discord.File]] = None, + content: t.Optional[str] = None, + additional_embeds: t.Optional[t.List[discord.Embed]] = None, + additional_embeds_msg: t.Optional[str] = None, + timestamp_override: t.Optional[datetime] = None, + footer: t.Optional[str] = None, + ) -> Context: + """Generate log embed and send to logging channel.""" + # Truncate string directly here to avoid removing newlines + embed = discord.Embed( + description=text[:2045] + "..." if len(text) > 2048 else text + ) + + if title and icon_url: + embed.set_author(name=title, icon_url=icon_url) + + embed.colour = colour + embed.timestamp = timestamp_override or datetime.utcnow() + + if footer: + embed.set_footer(text=footer) + + if thumbnail: + embed.set_thumbnail(url=thumbnail) + + if ping_everyone: + if content: + content = f"@everyone\n{content}" + else: + content = "@everyone" + + channel = self.bot.get_channel(channel_id) + log_message = await channel.send( + content=content, + embed=embed, + files=files, + allowed_mentions=discord.AllowedMentions(everyone=True) + ) + + if additional_embeds: + if additional_embeds_msg: + await channel.send(additional_embeds_msg) + for additional_embed in additional_embeds: + await channel.send(embed=additional_embed) + + return await self.bot.get_context(log_message) # Optionally return for use with antispam + + @Cog.listener() + async def on_guild_channel_create(self, channel: GUILD_CHANNEL) -> None: + """Log channel create event to mod log.""" + if channel.guild.id != GuildConstant.id: + return + + if isinstance(channel, discord.CategoryChannel): + title = "Category created" + message = f"{channel.name} (`{channel.id}`)" + elif isinstance(channel, discord.VoiceChannel): + title = "Voice channel created" + + if channel.category: + message = f"{channel.category}/{channel.name} (`{channel.id}`)" + else: + message = f"{channel.name} (`{channel.id}`)" + else: + title = "Text channel created" + + if channel.category: + message = f"{channel.category}/{channel.name} (`{channel.id}`)" + else: + message = f"{channel.name} (`{channel.id}`)" + + await self.send_log_message(Icons.hash_green, Colours.soft_green, title, message) + + @Cog.listener() + async def on_guild_channel_delete(self, channel: GUILD_CHANNEL) -> None: + """Log channel delete event to mod log.""" + if channel.guild.id != GuildConstant.id: + return + + if isinstance(channel, discord.CategoryChannel): + title = "Category deleted" + elif isinstance(channel, discord.VoiceChannel): + title = "Voice channel deleted" + else: + title = "Text channel deleted" + + if channel.category and not isinstance(channel, discord.CategoryChannel): + message = f"{channel.category}/{channel.name} (`{channel.id}`)" + else: + message = f"{channel.name} (`{channel.id}`)" + + await self.send_log_message( + Icons.hash_red, Colours.soft_red, + title, message + ) + + @Cog.listener() + async def on_guild_channel_update(self, before: GUILD_CHANNEL, after: GuildChannel) -> None: + """Log channel update event to mod log.""" + if before.guild.id != GuildConstant.id: + return + + if before.id in self._ignored[Event.guild_channel_update]: + self._ignored[Event.guild_channel_update].remove(before.id) + return + + # Two channel updates are sent for a single edit: 1 for topic and 1 for category change. + # TODO: remove once support is added for ignoring multiple occurrences for the same channel. + help_categories = (Categories.help_available, Categories.help_dormant, Categories.help_in_use) + if after.category and after.category.id in help_categories: + return + + diff = DeepDiff(before, after) + changes = [] + done = [] + + diff_values = diff.get("values_changed", {}) + diff_values.update(diff.get("type_changes", {})) + + for key, value in diff_values.items(): + if not key: # Not sure why, but it happens + continue + + key = key[5:] # Remove "root." prefix + + if "[" in key: + key = key.split("[", 1)[0] + + if "." in key: + key = key.split(".", 1)[0] + + if key in done or key in CHANNEL_CHANGES_SUPPRESSED: + continue + + if key in CHANNEL_CHANGES_UNSUPPORTED: + changes.append(f"**{key.title()}** updated") + else: + new = value["new_value"] + old = value["old_value"] + + # Discord does not treat consecutive backticks ("``") as an empty inline code block, so the markdown + # formatting is broken when `new` and/or `old` are empty values. "None" is used for these cases so + # formatting is preserved. + changes.append(f"**{key.title()}:** `{old or 'None'}` **→** `{new or 'None'}`") + + done.append(key) + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + if after.category: + message = f"**{after.category}/#{after.name} (`{after.id}`)**\n{message}" + else: + message = f"**#{after.name}** (`{after.id}`)\n{message}" + + await self.send_log_message( + Icons.hash_blurple, Colour.blurple(), + "Channel updated", message + ) + + @Cog.listener() + async def on_guild_role_create(self, role: discord.Role) -> None: + """Log role create event to mod log.""" + if role.guild.id != GuildConstant.id: + return + + await self.send_log_message( + Icons.crown_green, Colours.soft_green, + "Role created", f"`{role.id}`" + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: discord.Role) -> None: + """Log role delete event to mod log.""" + if role.guild.id != GuildConstant.id: + return + + await self.send_log_message( + Icons.crown_red, Colours.soft_red, + "Role removed", f"{role.name} (`{role.id}`)" + ) + + @Cog.listener() + async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: + """Log role update event to mod log.""" + if before.guild.id != GuildConstant.id: + return + + diff = DeepDiff(before, after) + changes = [] + done = [] + + diff_values = diff.get("values_changed", {}) + diff_values.update(diff.get("type_changes", {})) + + for key, value in diff_values.items(): + if not key: # Not sure why, but it happens + continue + + key = key[5:] # Remove "root." prefix + + if "[" in key: + key = key.split("[", 1)[0] + + if "." in key: + key = key.split(".", 1)[0] + + if key in done or key == "color": + continue + + if key in ROLE_CHANGES_UNSUPPORTED: + changes.append(f"**{key.title()}** updated") + else: + new = value["new_value"] + old = value["old_value"] + + changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") + + done.append(key) + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + message = f"**{after.name}** (`{after.id}`)\n{message}" + + await self.send_log_message( + Icons.crown_blurple, Colour.blurple(), + "Role updated", message + ) + + @Cog.listener() + async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: + """Log guild update event to mod log.""" + if before.id != GuildConstant.id: + return + + diff = DeepDiff(before, after) + changes = [] + done = [] + + diff_values = diff.get("values_changed", {}) + diff_values.update(diff.get("type_changes", {})) + + for key, value in diff_values.items(): + if not key: # Not sure why, but it happens + continue + + key = key[5:] # Remove "root." prefix + + if "[" in key: + key = key.split("[", 1)[0] + + if "." in key: + key = key.split(".", 1)[0] + + if key in done: + continue + + new = value["new_value"] + old = value["old_value"] + + changes.append(f"**{key.title()}:** `{old}` **→** `{new}`") + + done.append(key) + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + message = f"**{after.name}** (`{after.id}`)\n{message}" + + await self.send_log_message( + Icons.guild_update, Colour.blurple(), + "Guild updated", message, + thumbnail=after.icon_url_as(format="png") + ) + + @Cog.listener() + async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: + """Log ban event to user log.""" + if guild.id != GuildConstant.id: + return + + if member.id in self._ignored[Event.member_ban]: + self._ignored[Event.member_ban].remove(member.id) + return + + await self.send_log_message( + Icons.user_ban, Colours.soft_red, + "User banned", f"{member} (`{member.id}`)", + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_member_join(self, member: discord.Member) -> None: + """Log member join event to user log.""" + if member.guild.id != GuildConstant.id: + return + + member_str = escape_markdown(str(member)) + message = f"{member_str} (`{member.id}`)" + now = datetime.utcnow() + difference = abs(relativedelta(now, member.created_at)) + + message += "\n\n**Account age:** " + humanize_delta(difference) + + if difference.days < 1 and difference.months < 1 and difference.years < 1: # New user account! + message = f"{Emojis.new} {message}" + + await self.send_log_message( + Icons.sign_in, Colours.soft_green, + "User joined", message, + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_member_remove(self, member: discord.Member) -> None: + """Log member leave event to user log.""" + if member.guild.id != GuildConstant.id: + return + + if member.id in self._ignored[Event.member_remove]: + self._ignored[Event.member_remove].remove(member.id) + return + + member_str = escape_markdown(str(member)) + await self.send_log_message( + Icons.sign_out, Colours.soft_red, + "User left", f"{member_str} (`{member.id}`)", + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: + """Log member unban event to mod log.""" + if guild.id != GuildConstant.id: + return + + if member.id in self._ignored[Event.member_unban]: + self._ignored[Event.member_unban].remove(member.id) + return + + member_str = escape_markdown(str(member)) + await self.send_log_message( + Icons.user_unban, Colour.blurple(), + "User unbanned", f"{member_str} (`{member.id}`)", + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.mod_log + ) + + @staticmethod + def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: + """Return a list of strings describing the roles added and removed.""" + changes = [] + before_roles = set(before) + after_roles = set(after) + + for role in (before_roles - after_roles): + changes.append(f"**Role removed:** {role.name} (`{role.id}`)") + + for role in (after_roles - before_roles): + changes.append(f"**Role added:** {role.name} (`{role.id}`)") + + return changes + + @Cog.listener() + async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: + """Log member update event to user log.""" + if before.guild.id != GuildConstant.id: + return + + if before.id in self._ignored[Event.member_update]: + self._ignored[Event.member_update].remove(before.id) + return + + changes = self.get_role_diff(before.roles, after.roles) + + # The regex is a simple way to exclude all sequence and mapping types. + diff = DeepDiff(before, after, exclude_regex_paths=r".*\[.*") + + # A type change seems to always take precedent over a value change. Furthermore, it will + # include the value change along with the type change anyway. Therefore, it's OK to + # "overwrite" values_changed; in practice there will never even be anything to overwrite. + diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} + + for attr, value in diff_values.items(): + if not attr: # Not sure why, but it happens. + continue + + attr = attr[5:] # Remove "root." prefix. + attr = attr.replace("_", " ").replace(".", " ").capitalize() + + new = value.get("new_value") + old = value.get("old_value") + + changes.append(f"**{attr}:** `{old}` **→** `{new}`") + + if not changes: + return + + message = "" + + for item in sorted(changes): + message += f"{Emojis.bullet} {item}\n" + + member_str = escape_markdown(str(after)) + message = f"**{member_str}** (`{after.id}`)\n{message}" + + await self.send_log_message( + icon_url=Icons.user_update, + colour=Colour.blurple(), + title="Member updated", + text=message, + thumbnail=after.avatar_url_as(static_format="png"), + channel_id=Channels.user_log + ) + + @Cog.listener() + async def on_message_delete(self, message: discord.Message) -> None: + """Log message delete event to message change log.""" + channel = message.channel + author = message.author + + # Ignore DMs. + if not message.guild: + return + + if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist: + return + + self._cached_deletes.append(message.id) + + if message.id in self._ignored[Event.message_delete]: + self._ignored[Event.message_delete].remove(message.id) + return + + if author.bot: + return + + author_str = escape_markdown(str(author)) + if channel.category: + response = ( + f"**Author:** {author_str} (`{author.id}`)\n" + f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + ) + else: + response = ( + f"**Author:** {author_str} (`{author.id}`)\n" + f"**Channel:** #{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + ) + + if message.attachments: + # Prepend the message metadata with the number of attachments + response = f"**Attachments:** {len(message.attachments)}\n" + response + + # Shorten the message content if necessary + content = message.clean_content + remaining_chars = 2040 - len(response) + + if len(content) > remaining_chars: + botlog_url = await self.upload_log(messages=[message], actor_id=message.author.id) + ending = f"\n\nMessage truncated, [full message here]({botlog_url})." + truncation_point = remaining_chars - len(ending) + content = f"{content[:truncation_point]}...{ending}" + + response += f"{content}" + + await self.send_log_message( + Icons.message_delete, Colours.soft_red, + "Message deleted", + response, + channel_id=Channels.message_log + ) + + @Cog.listener() + async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: + """Log raw message delete event to message change log.""" + if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist: + return + + await asyncio.sleep(1) # Wait here in case the normal event was fired + + if event.message_id in self._cached_deletes: + # It was in the cache and the normal event was fired, so we can just ignore it + self._cached_deletes.remove(event.message_id) + return + + if event.message_id in self._ignored[Event.message_delete]: + self._ignored[Event.message_delete].remove(event.message_id) + return + + channel = self.bot.get_channel(event.channel_id) + + if channel.category: + response = ( + f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{event.message_id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) + else: + response = ( + f"**Channel:** #{channel.name} (`{channel.id}`)\n" + f"**Message ID:** `{event.message_id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) + + await self.send_log_message( + Icons.message_delete, Colours.soft_red, + "Message deleted", + response, + channel_id=Channels.message_log + ) + + @Cog.listener() + async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: + """Log message edit event to message change log.""" + if ( + not msg_before.guild + or msg_before.guild.id != GuildConstant.id + or msg_before.channel.id in GuildConstant.modlog_blacklist + or msg_before.author.bot + ): + return + + self._cached_edits.append(msg_before.id) + + if msg_before.content == msg_after.content: + return + + author = msg_before.author + author_str = escape_markdown(str(author)) + + channel = msg_before.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + # Getting the difference per words and group them by type - add, remove, same + # Note that this is intended grouping without sorting + diff = difflib.ndiff(msg_before.clean_content.split(), msg_after.clean_content.split()) + diff_groups = tuple( + (diff_type, tuple(s[2:] for s in diff_words)) + for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) + ) + + content_before: t.List[str] = [] + content_after: t.List[str] = [] + + for index, (diff_type, words) in enumerate(diff_groups): + sub = ' '.join(words) + if diff_type == '-': + content_before.append(f"[{sub}](http://o.hi)") + elif diff_type == '+': + content_after.append(f"[{sub}](http://o.hi)") + elif diff_type == ' ': + if len(words) > 2: + sub = ( + f"{words[0] if index > 0 else ''}" + " ... " + f"{words[-1] if index < len(diff_groups) - 1 else ''}" + ) + content_before.append(sub) + content_after.append(sub) + + response = ( + f"**Author:** {author_str} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{msg_before.id}`\n" + "\n" + f"**Before**:\n{' '.join(content_before)}\n" + f"**After**:\n{' '.join(content_after)}\n" + "\n" + f"[Jump to message]({msg_after.jump_url})" + ) + + if msg_before.edited_at: + # Message was previously edited, to assist with self-bot detection, use the edited_at + # datetime as the baseline and create a human-readable delta between this edit event + # and the last time the message was edited + timestamp = msg_before.edited_at + delta = humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) + footer = f"Last edited {delta} ago" + else: + # Message was not previously edited, use the created_at datetime as the baseline, no + # delta calculation needed + timestamp = msg_before.created_at + footer = None + + await self.send_log_message( + Icons.message_edit, Colour.blurple(), "Message edited", response, + channel_id=Channels.message_log, timestamp_override=timestamp, footer=footer + ) + + @Cog.listener() + async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: + """Log raw message edit event to message change log.""" + try: + channel = self.bot.get_channel(int(event.data["channel_id"])) + message = await channel.fetch_message(event.message_id) + except discord.NotFound: # Was deleted before we got the event + return + + if ( + not message.guild + or message.guild.id != GuildConstant.id + or message.channel.id in GuildConstant.modlog_blacklist + or message.author.bot + ): + return + + await asyncio.sleep(1) # Wait here in case the normal event was fired + + if event.message_id in self._cached_edits: + # It was in the cache and the normal event was fired, so we can just ignore it + self._cached_edits.remove(event.message_id) + return + + author = message.author + channel = message.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + before_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) + + after_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + f"{message.clean_content}" + ) + + await self.send_log_message( + Icons.message_edit, Colour.blurple(), "Message edited (Before)", + before_response, channel_id=Channels.message_log + ) + + await self.send_log_message( + Icons.message_edit, Colour.blurple(), "Message edited (After)", + after_response, channel_id=Channels.message_log + ) + + @Cog.listener() + async def on_voice_state_update( + self, + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState + ) -> None: + """Log member voice state changes to the voice log channel.""" + if ( + member.guild.id != GuildConstant.id + or (before.channel and before.channel.id in GuildConstant.modlog_blacklist) + ): + return + + if member.id in self._ignored[Event.voice_state_update]: + self._ignored[Event.voice_state_update].remove(member.id) + return + + # Exclude all channel attributes except the name. + diff = DeepDiff( + before, + after, + exclude_paths=("root.session_id", "root.afk"), + exclude_regex_paths=r"root\.channel\.(?!name)", + ) + + # A type change seems to always take precedent over a value change. Furthermore, it will + # include the value change along with the type change anyway. Therefore, it's OK to + # "overwrite" values_changed; in practice there will never even be anything to overwrite. + diff_values = {**diff.get("values_changed", {}), **diff.get("type_changes", {})} + + icon = Icons.voice_state_blue + colour = Colour.blurple() + changes = [] + + for attr, values in diff_values.items(): + if not attr: # Not sure why, but it happens. + continue + + old = values["old_value"] + new = values["new_value"] + + attr = attr[5:] # Remove "root." prefix. + attr = VOICE_STATE_ATTRIBUTES.get(attr, attr.replace("_", " ").capitalize()) + + changes.append(f"**{attr}:** `{old}` **→** `{new}`") + + # Set the embed icon and colour depending on which attribute changed. + if any(name in attr for name in ("Channel", "deaf", "mute")): + if new is None or new is True: + # Left a channel or was muted/deafened. + icon = Icons.voice_state_red + colour = Colours.soft_red + elif old is None or old is True: + # Joined a channel or was unmuted/undeafened. + icon = Icons.voice_state_green + colour = Colours.soft_green + + if not changes: + return + + member_str = escape_markdown(str(member)) + message = "\n".join(f"{Emojis.bullet} {item}" for item in sorted(changes)) + message = f"**{member_str}** (`{member.id}`)\n{message}" + + await self.send_log_message( + icon_url=icon, + colour=colour, + title="Voice state updated", + text=message, + thumbnail=member.avatar_url_as(static_format="png"), + channel_id=Channels.voice_log + ) + + +def setup(bot: Bot) -> None: + """Load the ModLog cog.""" + bot.add_cog(ModLog(bot)) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py new file mode 100644 index 000000000..4af87c724 --- /dev/null +++ b/bot/exts/moderation/silence.py @@ -0,0 +1,170 @@ +import asyncio +import logging +from contextlib import suppress +from typing import Optional + +from discord import TextChannel +from discord.ext import commands, tasks +from discord.ext.commands import Context + +from bot.bot import Bot +from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles +from bot.converters import HushDurationConverter +from bot.utils.checks import with_role_check +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + + +class SilenceNotifier(tasks.Loop): + """Loop notifier for posting notices to `alert_channel` containing added channels.""" + + def __init__(self, alert_channel: TextChannel): + super().__init__(self._notifier, seconds=1, minutes=0, hours=0, count=None, reconnect=True, loop=None) + self._silenced_channels = {} + self._alert_channel = alert_channel + + def add_channel(self, channel: TextChannel) -> None: + """Add channel to `_silenced_channels` and start loop if not launched.""" + if not self._silenced_channels: + self.start() + log.info("Starting notifier loop.") + self._silenced_channels[channel] = self._current_loop + + def remove_channel(self, channel: TextChannel) -> None: + """Remove channel from `_silenced_channels` and stop loop if no channels remain.""" + with suppress(KeyError): + del self._silenced_channels[channel] + if not self._silenced_channels: + self.stop() + log.info("Stopping notifier loop.") + + async def _notifier(self) -> None: + """Post notice of `_silenced_channels` with their silenced duration to `_alert_channel` periodically.""" + # Wait for 15 minutes between notices with pause at start of loop. + if self._current_loop and not self._current_loop/60 % 15: + log.debug( + f"Sending notice with channels: " + f"{', '.join(f'#{channel} ({channel.id})' for channel in self._silenced_channels)}." + ) + channels_text = ', '.join( + f"{channel.mention} for {(self._current_loop-start)//60} min" + for channel, start in self._silenced_channels.items() + ) + await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") + + +class Silence(commands.Cog): + """Commands for stopping channel messages for `verified` role in a channel.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.muted_channels = set() + + self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) + self._get_instance_vars_event = asyncio.Event() + + async def _get_instance_vars(self) -> None: + """Get instance variables after they're available to get from the guild.""" + await self.bot.wait_until_guild_available() + guild = self.bot.get_guild(Guild.id) + self._verified_role = guild.get_role(Roles.verified) + self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) + self._mod_log_channel = self.bot.get_channel(Channels.mod_log) + self.notifier = SilenceNotifier(self._mod_log_channel) + self._get_instance_vars_event.set() + + @commands.command(aliases=("hush",)) + async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: + """ + Silence the current channel for `duration` minutes or `forever`. + + Duration is capped at 15 minutes, passing forever makes the silence indefinite. + Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. + """ + await self._get_instance_vars_event.wait() + log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") + if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): + await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") + return + if duration is None: + await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") + return + + await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") + + self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) + + @commands.command(aliases=("unhush",)) + async def unsilence(self, ctx: Context) -> None: + """ + Unsilence the current channel. + + If the channel was silenced indefinitely, notifications for the channel will stop. + """ + await self._get_instance_vars_event.wait() + log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") + if not await self._unsilence(ctx.channel): + await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") + else: + await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") + + async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: + """ + Silence `channel` for `self._verified_role`. + + If `persistent` is `True` add `channel` to notifier. + `duration` is only used for logging; if None is passed `persistent` should be True to not log None. + Return `True` if channel permissions were changed, `False` otherwise. + """ + current_overwrite = channel.overwrites_for(self._verified_role) + if current_overwrite.send_messages is False: + log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") + return False + await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) + self.muted_channels.add(channel) + if persistent: + log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") + self.notifier.add_channel(channel) + return True + + log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") + return True + + async def _unsilence(self, channel: TextChannel) -> bool: + """ + Unsilence `channel`. + + Check if `channel` is silenced through a `PermissionOverwrite`, + if it is unsilence it and remove it from the notifier. + Return `True` if channel permissions were changed, `False` otherwise. + """ + current_overwrite = channel.overwrites_for(self._verified_role) + if current_overwrite.send_messages is False: + await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) + log.info(f"Unsilenced channel #{channel} ({channel.id}).") + self.scheduler.cancel(channel.id) + self.notifier.remove_channel(channel) + self.muted_channels.discard(channel) + return True + log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") + return False + + def cog_unload(self) -> None: + """Send alert with silenced channels and cancel scheduled tasks on unload.""" + self.scheduler.cancel_all() + if self.muted_channels: + channels_string = ''.join(channel.mention for channel in self.muted_channels) + message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" + asyncio.create_task(self._mod_alerts_channel.send(message)) + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Silence cog.""" + bot.add_cog(Silence(bot)) diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py new file mode 100644 index 000000000..1d055afac --- /dev/null +++ b/bot/exts/moderation/slowmode.py @@ -0,0 +1,97 @@ +import logging +from datetime import datetime +from typing import Optional + +from dateutil.relativedelta import relativedelta +from discord import TextChannel +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Emojis, MODERATION_ROLES +from bot.converters import DurationDelta +from bot.decorators import with_role_check +from bot.utils import time + +log = logging.getLogger(__name__) + +SLOWMODE_MAX_DELAY = 21600 # seconds + + +class Slowmode(Cog): + """Commands for getting and setting slowmode delays of text channels.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @group(name='slowmode', aliases=['sm'], invoke_without_command=True) + async def slowmode_group(self, ctx: Context) -> None: + """Get or set the slowmode delay for the text channel this was invoked in or a given text channel.""" + await ctx.send_help(ctx.command) + + @slowmode_group.command(name='get', aliases=['g']) + async def get_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: + """Get the slowmode delay for a text channel.""" + # Use the channel this command was invoked in if one was not given + if channel is None: + channel = ctx.channel + + delay = relativedelta(seconds=channel.slowmode_delay) + humanized_delay = time.humanize_delta(delay) + + await ctx.send(f'The slowmode delay for {channel.mention} is {humanized_delay}.') + + @slowmode_group.command(name='set', aliases=['s']) + async def set_slowmode(self, ctx: Context, channel: Optional[TextChannel], delay: DurationDelta) -> None: + """Set the slowmode delay for a text channel.""" + # Use the channel this command was invoked in if one was not given + if channel is None: + channel = ctx.channel + + # Convert `dateutil.relativedelta.relativedelta` to `datetime.timedelta` + # Must do this to get the delta in a particular unit of time + utcnow = datetime.utcnow() + slowmode_delay = (utcnow + delay - utcnow).total_seconds() + + humanized_delay = time.humanize_delta(delay) + + # Ensure the delay is within discord's limits + if slowmode_delay <= SLOWMODE_MAX_DELAY: + log.info(f'{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.') + + await channel.edit(slowmode_delay=slowmode_delay) + await ctx.send( + f'{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}.' + ) + + else: + log.info( + f'{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, ' + 'which is not between 0 and 6 hours.' + ) + + await ctx.send( + f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.' + ) + + @slowmode_group.command(name='reset', aliases=['r']) + async def reset_slowmode(self, ctx: Context, channel: Optional[TextChannel]) -> None: + """Reset the slowmode delay for a text channel to 0 seconds.""" + # Use the channel this command was invoked in if one was not given + if channel is None: + channel = ctx.channel + + log.info(f'{ctx.author} reset the slowmode delay for #{channel} to 0 seconds.') + + await channel.edit(slowmode_delay=0) + await ctx.send( + f'{Emojis.check_mark} The slowmode delay for {channel.mention} has been reset to 0 seconds.' + ) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the Slowmode cog.""" + bot.add_cog(Slowmode(bot)) diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py new file mode 100644 index 000000000..0db3e800d --- /dev/null +++ b/bot/exts/moderation/verification.py @@ -0,0 +1,191 @@ +import logging +from contextlib import suppress + +from discord import Colour, Forbidden, Message, NotFound, Object +from discord.ext.commands import Cog, Context, command + +from bot import constants +from bot.bot import Bot +from bot.decorators import in_whitelist, without_role +from bot.exts.moderation.modlog import ModLog +from bot.utils.checks import InWhitelistCheckFailure, without_role_check + +log = logging.getLogger(__name__) + +WELCOME_MESSAGE = f""" +Hello! Welcome to the server, and thanks for verifying yourself! + +For your records, these are the documents you accepted: + +`1)` Our rules, here: +`2)` Our privacy policy, here: - you can find information on how to have \ +your information removed here as well. + +Feel free to review them at any point! + +Additionally, if you'd like to receive notifications for the announcements \ +we post in <#{constants.Channels.announcements}> +from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ +to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. + +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{constants.Channels.bot_commands}>. +""" + +BOT_MESSAGE_DELETE_DELAY = 10 + + +class Verification(Cog): + """User verification and role self-management.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Check new message event for messages to the checkpoint channel & process.""" + if message.channel.id != constants.Channels.verification: + return # Only listen for #checkpoint messages + + if message.author.bot: + # They're a bot, delete their message after the delay. + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return + + # if a user mentions a role or guild member + # alert the mods in mod-alerts channel + if message.mentions or message.role_mentions: + log.debug( + f"{message.author} mentioned one or more users " + f"and/or roles in {message.channel.name}" + ) + + embed_text = ( + f"{message.author.mention} sent a message in " + f"{message.channel.mention} that contained user and/or role mentions." + f"\n\n**Original message:**\n>>> {message.content}" + ) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=constants.Icons.filtering, + colour=Colour(constants.Colours.soft_red), + title=f"User/Role mentioned in {message.channel.name}", + text=embed_text, + thumbnail=message.author.avatar_url_as(static_format="png"), + channel_id=constants.Channels.mod_alerts, + ) + + ctx: Context = await self.bot.get_context(message) + if ctx.command is not None and ctx.command.name == "accept": + return + + if any(r.id == constants.Roles.verified for r in ctx.author.roles): + log.info( + f"{ctx.author} posted '{ctx.message.content}' " + "in the verification channel, but is already verified." + ) + return + + log.debug( + f"{ctx.author} posted '{ctx.message.content}' in the verification " + "channel. We are providing instructions how to verify." + ) + await ctx.send( + f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " + f"and gain access to the rest of the server.", + delete_after=20 + ) + + log.trace(f"Deleting the message posted by {ctx.author}") + with suppress(NotFound): + await ctx.message.delete() + + @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) + @without_role(constants.Roles.verified) + @in_whitelist(channels=(constants.Channels.verification,)) + async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Accept our rules and gain access to the rest of the server.""" + log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") + await ctx.author.add_roles(Object(constants.Roles.verified), reason="Accepted the rules") + try: + await ctx.author.send(WELCOME_MESSAGE) + except Forbidden: + log.info(f"Sending welcome message failed for {ctx.author}.") + finally: + log.trace(f"Deleting accept message by {ctx.author}.") + with suppress(NotFound): + self.mod_log.ignore(constants.Event.message_delete, ctx.message.id) + await ctx.message.delete() + + @command(name='subscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Subscribe to announcement notifications by assigning yourself the role.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if has_role: + await ctx.send(f"{ctx.author.mention} You're already subscribed!") + return + + log.debug(f"{ctx.author} called !subscribe. Assigning the 'Announcements' role.") + await ctx.author.add_roles(Object(constants.Roles.announcements), reason="Subscribed to announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Subscribed to <#{constants.Channels.announcements}> notifications.", + ) + + @command(name='unsubscribe') + @in_whitelist(channels=(constants.Channels.bot_commands,)) + async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args + """Unsubscribe from announcement notifications by removing the role from yourself.""" + has_role = False + + for role in ctx.author.roles: + if role.id == constants.Roles.announcements: + has_role = True + break + + if not has_role: + await ctx.send(f"{ctx.author.mention} You're already unsubscribed!") + return + + log.debug(f"{ctx.author} called !unsubscribe. Removing the 'Announcements' role.") + await ctx.author.remove_roles(Object(constants.Roles.announcements), reason="Unsubscribed from announcements") + + log.trace(f"Deleting the message posted by {ctx.author}.") + + await ctx.send( + f"{ctx.author.mention} Unsubscribed from <#{constants.Channels.announcements}> notifications." + ) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Check for & ignore any InWhitelistCheckFailure.""" + if isinstance(error, InWhitelistCheckFailure): + error.handled = True + + @staticmethod + def bot_check(ctx: Context) -> bool: + """Block any command within the verification channel that is not !accept.""" + if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): + return ctx.command.name == "accept" + else: + return True + + +def setup(bot: Bot) -> None: + """Load the Verification cog.""" + bot.add_cog(Verification(bot)) diff --git a/bot/exts/moderation/watchchannels/__init__.py b/bot/exts/moderation/watchchannels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py new file mode 100644 index 000000000..013d3ee03 --- /dev/null +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -0,0 +1,348 @@ +import asyncio +import logging +import re +import textwrap +from abc import abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Optional + +import dateutil.parser +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons +from bot.exts.moderation.modlog import ModLog +from bot.pagination import LinePaginator +from bot.utils import CogABCMeta, messages +from bot.utils.time import time_since + +log = logging.getLogger(__name__) + +URL_RE = re.compile(r"(https?://[^\s]+)") + + +@dataclass +class MessageHistory: + """Represents a watch channel's message history.""" + + last_author: Optional[int] = None + last_channel: Optional[int] = None + message_count: int = 0 + + +class WatchChannel(metaclass=CogABCMeta): + """ABC with functionality for relaying users' messages to a certain channel.""" + + @abstractmethod + def __init__( + self, + bot: Bot, + destination: int, + webhook_id: int, + api_endpoint: str, + api_default_params: dict, + logger: logging.Logger + ) -> None: + self.bot = bot + + self.destination = destination # E.g., Channels.big_brother_logs + self.webhook_id = webhook_id # E.g., Webhooks.big_brother + self.api_endpoint = api_endpoint # E.g., 'bot/infractions' + self.api_default_params = api_default_params # E.g., {'active': 'true', 'type': 'watch'} + self.log = logger # Logger of the child cog for a correct name in the logs + + self._consume_task = None + self.watched_users = defaultdict(dict) + self.message_queue = defaultdict(lambda: defaultdict(deque)) + self.consumption_queue = {} + self.retries = 5 + self.retry_delay = 10 + self.channel = None + self.webhook = None + self.message_history = MessageHistory() + + self._start = self.bot.loop.create_task(self.start_watchchannel()) + + @property + def modlog(self) -> ModLog: + """Provides access to the ModLog cog for alert purposes.""" + return self.bot.get_cog("ModLog") + + @property + def consuming_messages(self) -> bool: + """Checks if a consumption task is currently running.""" + if self._consume_task is None: + return False + + if self._consume_task.done(): + exc = self._consume_task.exception() + if exc: + self.log.exception( + "The message queue consume task has failed with:", + exc_info=exc + ) + return False + + return True + + async def start_watchchannel(self) -> None: + """Starts the watch channel by getting the channel, webhook, and user cache ready.""" + await self.bot.wait_until_guild_available() + + try: + self.channel = await self.bot.fetch_channel(self.destination) + except HTTPException: + self.log.exception(f"Failed to retrieve the text channel with id `{self.destination}`") + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + if self.channel is None or self.webhook is None: + self.log.error("Failed to start the watch channel; unloading the cog.") + + message = textwrap.dedent( + f""" + An error occurred while loading the text channel or webhook. + + TextChannel: {"**Failed to load**" if self.channel is None else "Loaded successfully"} + Webhook: {"**Failed to load**" if self.webhook is None else "Loaded successfully"} + + The Cog has been unloaded. + """ + ) + + await self.modlog.send_log_message( + title=f"Error: Failed to initialize the {self.__class__.__name__} watch channel", + text=message, + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + self.bot.remove_cog(self.__class__.__name__) + return + + if not await self.fetch_user_cache(): + await self.modlog.send_log_message( + title=f"Warning: Failed to retrieve user cache for the {self.__class__.__name__} watch channel", + text="Could not retrieve the list of watched users from the API and messages will not be relayed.", + ping_everyone=True, + icon_url=Icons.token_removed, + colour=Color.red() + ) + + async def fetch_user_cache(self) -> bool: + """ + Fetches watched users from the API and updates the watched user cache accordingly. + + This function returns `True` if the update succeeded. + """ + try: + data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) + except ResponseCodeError as err: + self.log.exception("Failed to fetch the watched users from the API", exc_info=err) + return False + + self.watched_users = defaultdict(dict) + + for entry in data: + user_id = entry.pop('user') + self.watched_users[user_id] = entry + + return True + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Queues up messages sent by watched users.""" + if msg.author.id in self.watched_users: + if not self.consuming_messages: + self._consume_task = self.bot.loop.create_task(self.consume_messages()) + + self.log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)") + self.message_queue[msg.author.id][msg.channel.id].append(msg) + + async def consume_messages(self, delay_consumption: bool = True) -> None: + """Consumes the message queues to log watched users' messages.""" + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) + + self.log.trace("Started consuming the message queue") + + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() + + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() + + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) + + self.consumption_queue.clear() + + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + + async def webhook_send( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Sends a message to the webhook with the specified kwargs.""" + username = messages.sub_clyde(username) + try: + await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send a message to the webhook", + exc_info=exc + ) + + async def relay_message(self, msg: Message) -> None: + """Relays the message to the relevant watch channel.""" + limit = BigBrotherConfig.header_message_limit + + if ( + msg.author.id != self.message_history.last_author + or msg.channel.id != self.message_history.last_channel + or self.message_history.message_count >= limit + ): + self.message_history = MessageHistory(last_author=msg.author.id, last_channel=msg.channel.id) + + await self.send_header(msg) + + cleaned_content = msg.clean_content + + if cleaned_content: + # Put all non-media URLs in a code block to prevent embeds + media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")} + for url in URL_RE.findall(cleaned_content): + if url not in media_urls: + cleaned_content = cleaned_content.replace(url, f"`{url}`") + await self.webhook_send( + cleaned_content, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + + if msg.attachments: + try: + await messages.send_attachments(msg, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.webhook_send( + embed=e, + username=msg.author.display_name, + avatar_url=msg.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + "Failed to send an attachment to the webhook", + exc_info=exc + ) + + self.message_history.message_count += 1 + + async def send_header(self, msg: Message) -> None: + """Sends a header embed with information about the relayed messages to the watch channel.""" + user_id = msg.author.id + + guild = self.bot.get_guild(GuildConfig.id) + actor = guild.get_member(self.watched_users[user_id]['actor']) + actor = actor.display_name if actor else self.watched_users[user_id]['actor'] + + inserted_at = self.watched_users[user_id]['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + + reason = self.watched_users[user_id]['reason'] + + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + footer = f"Added {time_delta} by {actor} | Reason: {reason}" + embed = Embed(description=f"{msg.author.mention} {message_jump}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) + + await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) + + async def list_watched_users( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Gives an overview of the watched user list for this channel. + + The optional kwarg `oldest_first` orders the list by oldest entry. + + The optional kwarg `update_cache` specifies whether the cache should + be refreshed by polling the API. + """ + if update_cache: + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update {self.__class__.__name__} user cache, serving from cache") + update_cache = False + + lines = [] + for user_id, user_data in self.watched_users.items(): + inserted_at = user_data['inserted_at'] + time_delta = self._get_time_delta(inserted_at) + lines.append(f"• <@{user_id}> (added {time_delta})") + + if oldest_first: + lines.reverse() + + lines = lines or ("There's nothing here yet.",) + + embed = Embed( + title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})", + color=Color.blue() + ) + await LinePaginator.paginate(lines, ctx, embed, empty=False) + + @staticmethod + def _get_time_delta(time_string: str) -> str: + """Returns the time in human-readable time delta format.""" + date_time = dateutil.parser.isoparse(time_string).replace(tzinfo=None) + time_delta = time_since(date_time, precision="minutes", max_units=1) + + return time_delta + + def _remove_user(self, user_id: int) -> None: + """Removes a user from a watch channel.""" + self.watched_users.pop(user_id, None) + self.message_queue.pop(user_id, None) + self.consumption_queue.pop(user_id, None) + + def cog_unload(self) -> None: + """Takes care of unloading the cog and canceling the consumption task.""" + self.log.trace("Unloading the cog") + if self._consume_task and not self._consume_task.done(): + self._consume_task.cancel() + try: + self._consume_task.result() + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py new file mode 100644 index 000000000..4ac916c9e --- /dev/null +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -0,0 +1,170 @@ +import logging +import textwrap +from collections import ChainMap + +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from bot.exts.moderation.infraction._utils import post_infraction +from ._watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class BigBrother(WatchChannel, Cog, name="Big Brother"): + """Monitors users by relaying their messages to a watch channel to assist with moderation.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.big_brother_logs, + webhook_id=Webhooks.big_brother, + api_endpoint='bot/infractions', + api_default_params={'active': 'true', 'type': 'watch', 'ordering': '-inserted_at'}, + logger=log + ) + + @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def bigbrother_group(self, ctx: Context) -> None: + """Monitors users by relaying their messages to the Big Brother watch channel.""" + await ctx.send_help(ctx.command) + + @bigbrother_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored by Big Brother. + + The optional kwarg `oldest_first` can be used to order the list by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @bigbrother_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows Big Brother monitored users ordered by oldest watched. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @bigbrother_group.command(name='watch', aliases=('w',)) + @with_role(*MODERATION_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#big-brother` channel. + + A `reason` for adding the user to Big Brother is required and will be displayed + in the header when relaying messages of this user to the watchchannel. + """ + await self.apply_watch(ctx, user, reason) + + @bigbrother_group.command(name='unwatch', aliases=('uw',)) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Stop relaying messages by the given `user`.""" + await self.apply_unwatch(ctx, user, reason) + + async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: + """ + Add `user` to watched users and apply a watch infraction with `reason`. + + A message indicating the result of the operation is sent to `ctx`. + The message will include `user`'s previous watch infraction history, if it exists. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Updating the user cache failed, can't watch user {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched.") + return + + response = await post_infraction(ctx, user, 'watch', reason, hidden=True, active=True) + + if response is not None: + self.watched_users[user.id] = response + msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother." + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + 'type': 'watch', + 'ordering': '-inserted_at' + } + ) + + if len(history) > 1: + total = f"({len(history) // 2} previous infractions in total)" + end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") + start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + else: + msg = ":x: Failed to post the infraction: response was empty." + + await ctx.send(msg) + + async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: + """ + Remove `user` from watched users and mark their infraction as inactive with `reason`. + + If `send_message` is True, a message indicating the result of the operation is sent to + `ctx`. + """ + active_watches = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + if active_watches: + log.trace("Active watches for user found. Attempting to remove.") + [infraction] = active_watches + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{infraction['id']}", + json={'active': False} + ) + + await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False) + + self._remove_user(user.id) + + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"Perma-banned user {user} was unwatched.") + return + log.trace("User is not banned. Sending message to channel") + message = f":white_check_mark: Messages sent by {user} will no longer be relayed." + + else: + log.trace("No active watches found for user.") + if not send_message: # Prevents a message being sent to the channel if part of a permanent ban + log.debug(f"{user} was not on the watch list; no removal necessary.") + return + log.trace("User is not perma banned. Send the error message.") + message = ":x: The specified user is currently not being watched." + + await ctx.send(message) + + +def setup(bot: Bot) -> None: + """Load the BigBrother cog.""" + bot.add_cog(BigBrother(bot)) diff --git a/bot/exts/moderation/watchchannels/talentpool.py b/bot/exts/moderation/watchchannels/talentpool.py new file mode 100644 index 000000000..2972f56e1 --- /dev/null +++ b/bot/exts/moderation/watchchannels/talentpool.py @@ -0,0 +1,269 @@ +import logging +import textwrap +from collections import ChainMap + +from discord import Color, Embed, Member +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks +from bot.converters import FetchedMember +from bot.decorators import with_role +from bot.pagination import LinePaginator +from bot.utils import time +from ._watchchannel import WatchChannel + +log = logging.getLogger(__name__) + + +class TalentPool(WatchChannel, Cog, name="Talentpool"): + """Relays messages of helper candidates to a watch channel to observe them.""" + + def __init__(self, bot: Bot) -> None: + super().__init__( + bot, + destination=Channels.talent_pool, + webhook_id=Webhooks.talent_pool, + api_endpoint='bot/nominations', + api_default_params={'active': 'true', 'ordering': '-inserted_at'}, + logger=log, + ) + + @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_group(self, ctx: Context) -> None: + """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" + await ctx.send_help(ctx.command) + + @nomination_group.command(name='watched', aliases=('all', 'list')) + @with_role(*MODERATION_ROLES) + async def watched_command( + self, ctx: Context, oldest_first: bool = False, update_cache: bool = True + ) -> None: + """ + Shows the users that are currently being monitored in the talent pool. + + The optional kwarg `oldest_first` can be used to order the list by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache) + + @nomination_group.command(name='oldest') + @with_role(*MODERATION_ROLES) + async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None: + """ + Shows talent pool monitored users ordered by oldest nomination. + + The optional kwarg `update_cache` can be used to update the user + cache using the API before listing the users. + """ + await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache) + + @nomination_group.command(name='watch', aliases=('w', 'add', 'a')) + @with_role(*STAFF_ROLES) + async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Relay messages sent by the given `user` to the `#talent-pool` channel. + + A `reason` for adding the user to the talent pool is required and will be displayed + in the header when relaying messages of this user to the channel. + """ + if user.bot: + await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") + return + + if isinstance(user, Member) and any(role.id in STAFF_ROLES for role in user.roles): + await ctx.send(":x: Nominating staff members, eh? Here's a cookie :cookie:") + return + + if not await self.fetch_user_cache(): + await ctx.send(f":x: Failed to update the user cache; can't add {user}") + return + + if user.id in self.watched_users: + await ctx.send(f":x: {user} is already being watched in the talent pool") + return + + # Manual request with `raise_for_status` as False because we want the actual response + session = self.bot.api_client.session + url = self.bot.api_client._url_for(self.api_endpoint) + kwargs = { + 'json': { + 'actor': ctx.author.id, + 'reason': reason, + 'user': user.id + }, + 'raise_for_status': False, + } + async with session.post(url, **kwargs) as resp: + response_data = await resp.json() + + if resp.status == 400 and response_data.get('user', False): + await ctx.send(":x: The specified user can't be found in the database tables") + return + else: + resp.raise_for_status() + + self.watched_users[user.id] = response_data + msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel" + + history = await self.bot.api_client.get( + self.api_endpoint, + params={ + "user__id": str(user.id), + "active": "false", + "ordering": "-inserted_at" + } + ) + + if history: + total = f"({len(history)} previous nominations in total)" + start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" + end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" + msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" + + await ctx.send(msg) + + @nomination_group.command(name='history', aliases=('info', 'search')) + @with_role(*MODERATION_ROLES) + async def history_command(self, ctx: Context, user: FetchedMember) -> None: + """Shows the specified user's nomination history.""" + result = await self.bot.api_client.get( + self.api_endpoint, + params={ + 'user__id': str(user.id), + 'ordering': "-active,-inserted_at" + } + ) + if not result: + await ctx.send(":warning: This user has never been nominated") + return + + embed = Embed( + title=f"Nominations for {user.display_name} `({user.id})`", + color=Color.blue() + ) + lines = [self._nomination_to_string(nomination) for nomination in result] + await LinePaginator.paginate( + lines, + ctx=ctx, + embed=embed, + empty=True, + max_lines=3, + max_size=1000 + ) + + @nomination_group.command(name='unwatch', aliases=('end', )) + @with_role(*MODERATION_ROLES) + async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """ + Ends the active nomination of the specified user with the given reason. + + Providing a `reason` is required. + """ + active_nomination = await self.bot.api_client.get( + self.api_endpoint, + params=ChainMap( + self.api_default_params, + {"user__id": str(user.id)} + ) + ) + + if not active_nomination: + await ctx.send(":x: The specified user does not have an active nomination") + return + + [nomination] = active_nomination + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination['id']}", + json={'end_reason': reason, 'active': False} + ) + await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") + self._remove_user(user.id) + + @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def nomination_edit_group(self, ctx: Context) -> None: + """Commands to edit nominations.""" + await ctx.send_help(ctx.command) + + @nomination_edit_group.command(name='reason') + @with_role(*MODERATION_ROLES) + async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None: + """ + Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. + + If the nomination is active, the reason for nominating the user will be edited; + If the nomination is no longer active, the reason for ending the nomination will be edited instead. + """ + try: + nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") + except ResponseCodeError as e: + if e.response.status == 404: + self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") + await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") + return + else: + raise + + field = "reason" if nomination["active"] else "end_reason" + + self.log.trace(f"Changing {field} for nomination with id {nomination_id} to {reason}") + + await self.bot.api_client.patch( + f"{self.api_endpoint}/{nomination_id}", + json={field: reason} + ) + + await ctx.send(f":white_check_mark: Updated the {field} of the nomination!") + + def _nomination_to_string(self, nomination_object: dict) -> str: + """Creates a string representation of a nomination.""" + guild = self.bot.get_guild(Guild.id) + + actor_id = nomination_object["actor"] + actor = guild.get_member(actor_id) + + active = nomination_object["active"] + log.debug(active) + log.debug(type(nomination_object["inserted_at"])) + + start_date = time.format_infraction(nomination_object["inserted_at"]) + if active: + lines = textwrap.dedent( + f""" + =============== + Status: **Active** + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + else: + end_date = time.format_infraction(nomination_object["ended_at"]) + lines = textwrap.dedent( + f""" + =============== + Status: Inactive + Date: {start_date} + Actor: {actor.mention if actor else actor_id} + Reason: {nomination_object["reason"]} + + End date: {end_date} + Unwatch reason: {nomination_object["end_reason"]} + Nomination ID: `{nomination_object["id"]}` + =============== + """ + ) + + return lines.strip() + + +def setup(bot: Bot) -> None: + """Load the TalentPool cog.""" + bot.add_cog(TalentPool(bot)) diff --git a/bot/exts/off_topic_names.py b/bot/exts/off_topic_names.py new file mode 100644 index 000000000..ce95450e0 --- /dev/null +++ b/bot/exts/off_topic_names.py @@ -0,0 +1,162 @@ +import asyncio +import difflib +import logging +from datetime import datetime, timedelta + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES +from bot.converters import OffTopicName +from bot.decorators import with_role +from bot.pagination import LinePaginator + +CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) +log = logging.getLogger(__name__) + + +async def update_names(bot: Bot) -> None: + """Background updater task that performs the daily channel name update.""" + while True: + # Since we truncate the compute timedelta to seconds, we add one second to ensure + # we go past midnight in the `seconds_to_sleep` set below. + today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) + next_midnight = today_at_midnight + timedelta(days=1) + seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 + await asyncio.sleep(seconds_to_sleep) + + try: + channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( + 'bot/off-topic-channel-names', params={'random_items': 3} + ) + except ResponseCodeError as e: + log.error(f"Failed to get new off topic channel names: code {e.response.status}") + continue + channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) + + await channel_0.edit(name=f'ot0-{channel_0_name}') + await channel_1.edit(name=f'ot1-{channel_1_name}') + await channel_2.edit(name=f'ot2-{channel_2_name}') + log.debug( + "Updated off-topic channel names to" + f" {channel_0_name}, {channel_1_name} and {channel_2_name}" + ) + + +class OffTopicNames(Cog): + """Commands related to managing the off-topic category channel names.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.updater_task = None + + self.bot.loop.create_task(self.init_offtopic_updater()) + + def cog_unload(self) -> None: + """Cancel any running updater tasks on cog unload.""" + if self.updater_task is not None: + self.updater_task.cancel() + + async def init_offtopic_updater(self) -> None: + """Start off-topic channel updating event loop if it hasn't already started.""" + await self.bot.wait_until_guild_available() + if self.updater_task is None: + coro = update_names(self.bot) + self.updater_task = self.bot.loop.create_task(coro) + + @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def otname_group(self, ctx: Context) -> None: + """Add or list items from the off-topic channel name rotation.""" + await ctx.send_help(ctx.command) + + @otname_group.command(name='add', aliases=('a',)) + @with_role(*MODERATION_ROLES) + async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """ + Adds a new off-topic name to the rotation. + + The name is not added if it is too similar to an existing name. + """ + existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') + close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) + + if close_match: + match = close_match[0] + log.info( + f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" + ) + await ctx.send( + f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " + "Use `!otn forceadd` to override this check." + ) + else: + await self._add_name(ctx, name) + + @otname_group.command(name='forceadd', aliases=('fa',)) + @with_role(*MODERATION_ROLES) + async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Forcefully adds a new off-topic name to the rotation.""" + await self._add_name(ctx, name) + + async def _add_name(self, ctx: Context, name: str) -> None: + """Adds an off-topic channel name to the site storage.""" + await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) + + log.info(f"{ctx.author} added the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Added `{name}` to the names list.") + + @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Removes a off-topic name from the rotation.""" + await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') + + log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Removed `{name}` from the names list.") + + @otname_group.command(name='list', aliases=('l',)) + @with_role(*MODERATION_ROLES) + async def list_command(self, ctx: Context) -> None: + """ + Lists all currently known off-topic channel names in a paginator. + + Restricted to Moderator and above to not spoil the surprise. + """ + result = await self.bot.api_client.get('bot/off-topic-channel-names') + lines = sorted(f"• {name}" for name in result) + embed = Embed( + title=f"Known off-topic names (`{len(result)}` total)", + colour=Colour.blue() + ) + if result: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + + @otname_group.command(name='search', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: + """Search for an off-topic name.""" + result = await self.bot.api_client.get('bot/off-topic-channel-names') + in_matches = {name for name in result if query in name} + close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) + lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) + embed = Embed( + title="Query results", + colour=Colour.blue() + ) + + if lines: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Nothing found." + await ctx.send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the OffTopicNames cog.""" + bot.add_cog(OffTopicNames(bot)) diff --git a/bot/exts/utils/__init__.py b/bot/exts/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py new file mode 100644 index 000000000..866fd2b68 --- /dev/null +++ b/bot/exts/utils/bot.py @@ -0,0 +1,385 @@ +import ast +import logging +import re +import time +from typing import Optional, Tuple + +from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord.ext.commands import Cog, Context, command, group + +from bot.bot import Bot +from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs +from bot.decorators import with_role +from bot.exts.filters.token_remover import TokenRemover +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +RE_MARKDOWN = re.compile(r'([*_~`|>])') + + +class BotCog(Cog, name="Bot"): + """Bot information commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + # Stores allowed channels plus epoch time since last call. + self.channel_cooldowns = { + Channels.python_discussion: 0, + } + + # These channels will also work, but will not be subject to cooldown + self.channel_whitelist = ( + Channels.bot_commands, + ) + + # Stores improperly formatted Python codeblock message ids and the corresponding bot message + self.codeblock_message_ids = {} + + @group(invoke_without_command=True, name="bot", hidden=True) + @with_role(Roles.verified) + async def botinfo_group(self, ctx: Context) -> None: + """Bot informational commands.""" + await ctx.send_help(ctx.command) + + @botinfo_group.command(name='about', aliases=('info',), hidden=True) + @with_role(Roles.verified) + async def about_command(self, ctx: Context) -> None: + """Get information about the bot.""" + embed = Embed( + description="A utility bot designed just for the Python server! Try `!help` for more info.", + url="https://github.com/python-discord/bot" + ) + + embed.add_field(name="Total Users", value=str(len(self.bot.get_guild(Guild.id).members))) + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=URLs.bot_avatar + ) + + await ctx.send(embed=embed) + + @command(name='echo', aliases=('print',)) + @with_role(*MODERATION_ROLES) + async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Repeat the given message in either a specified channel or the current channel.""" + if channel is None: + await ctx.send(text) + else: + await channel.send(text) + + @command(name='embed') + @with_role(*MODERATION_ROLES) + async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Send the input within an embed to either a specified channel or the current channel.""" + embed = Embed(description=text) + + if channel is None: + await ctx.send(embed=embed) + else: + await channel.send(embed=embed) + + def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: + """ + Strip msg in order to find Python code. + + Tries to strip out Python code out of msg and returns the stripped block or + None if the block is a valid Python codeblock. + """ + if msg.count("\n") >= 3: + # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. + if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: + log.trace( + "Someone wrote a message that was already a " + "valid Python syntax highlighted code block. No action taken." + ) + return None + + else: + # Stripping backticks from every line of the message. + log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") + content = "" + for line in msg.splitlines(keepends=True): + content += line.strip("`") + + content = content.strip() + + # Remove "Python" or "Py" from start of the message if it exists. + log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") + pycode = False + if content.lower().startswith("python"): + content = content[6:] + pycode = True + elif content.lower().startswith("py"): + content = content[2:] + pycode = True + + if pycode: + content = content.splitlines(keepends=True) + + # Check if there might be code in the first line, and preserve it. + first_line = content[0] + if " " in content[0]: + first_space = first_line.index(" ") + content[0] = first_line[first_space:] + content = "".join(content) + + # If there's no code we can just get rid of the first line. + else: + content = "".join(content[1:]) + + # Strip it again to remove any leading whitespace. This is neccessary + # if the first line of the message looked like ```python + old = content.strip() + + # Strips REPL code out of the message if there is any. + content, repl_code = self.repl_stripping(old) + if old != content: + return (content, old), repl_code + + # Try to apply indentation fixes to the code. + content = self.fix_indentation(content) + + # Check if the code contains backticks, if it does ignore the message. + if "`" in content: + log.trace("Detected ` inside the code, won't reply") + return None + else: + log.trace(f"Returning message.\n\n{content}\n\n") + return (content,), repl_code + + def fix_indentation(self, msg: str) -> str: + """Attempts to fix badly indented code.""" + def unindent(code: str, skip_spaces: int = 0) -> str: + """Unindents all code down to the number of spaces given in skip_spaces.""" + final = "" + current = code[0] + leading_spaces = 0 + + # Get numbers of spaces before code in the first line. + while current == " ": + current = code[leading_spaces + 1] + leading_spaces += 1 + leading_spaces -= skip_spaces + + # If there are any, remove that number of spaces from every line. + if leading_spaces > 0: + for line in code.splitlines(keepends=True): + line = line[leading_spaces:] + final += line + return final + else: + return code + + # Apply fix for "all lines are overindented" case. + msg = unindent(msg) + + # If the first line does not end with a colon, we can be + # certain the next line will be on the same indentation level. + # + # If it does end with a colon, we will need to indent all successive + # lines one additional level. + first_line = msg.splitlines()[0] + code = "".join(msg.splitlines(keepends=True)[1:]) + if not first_line.endswith(":"): + msg = f"{first_line}\n{unindent(code)}" + else: + msg = f"{first_line}\n{unindent(code, 4)}" + return msg + + def repl_stripping(self, msg: str) -> Tuple[str, bool]: + """ + Strip msg in order to extract Python code out of REPL output. + + Tries to strip out REPL Python code out of msg and returns the stripped msg. + + Returns True for the boolean if REPL code was found in the input msg. + """ + final = "" + for line in msg.splitlines(keepends=True): + if line.startswith(">>>") or line.startswith("..."): + final += line[4:] + log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") + if not final: + log.trace(f"Found no REPL code in \n\n{msg}\n\n") + return msg, False + else: + log.trace(f"Found REPL code in \n\n{msg}\n\n") + return final.rstrip(), True + + def has_bad_ticks(self, msg: Message) -> bool: + """Check to see if msg contains ticks that aren't '`'.""" + not_backticks = [ + "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", + "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", + "\u3003\u3003\u3003" + ] + + return msg.content[:3] in not_backticks + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """ + Detect poorly formatted Python code in new messages. + + If poorly formatted code is detected, send the user a helpful message explaining how to do + properly formatted Python syntax highlighting codeblocks. + """ + is_help_channel = ( + getattr(msg.channel, "category", None) + and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) + ) + parse_codeblock = ( + ( + is_help_channel + or msg.channel.id in self.channel_cooldowns + or msg.channel.id in self.channel_whitelist + ) + and not msg.author.bot + and len(msg.content.splitlines()) > 3 + and not TokenRemover.find_token_in_message(msg) + ) + + if parse_codeblock: # no token in the msg + on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 + if not on_cooldown or DEBUG_MODE: + try: + if self.has_bad_ticks(msg): + ticks = msg.content[:3] + content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) + if content is None: + return + + content, repl_code = content + + if len(content) == 2: + content = content[1] + else: + content = content[0] + + space_left = 204 + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto = ( + "It looks like you are trying to paste code into this channel.\n\n" + "You seem to be using the wrong symbols to indicate where the codeblock should start. " + f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" + "**Here is an example of how it should look:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + else: + howto = "" + content = self.codeblock_stripping(msg.content, False) + if content is None: + return + + content, repl_code = content + # Attempts to parse the message into an AST node. + # Invalid Python code will raise a SyntaxError. + tree = ast.parse(content[0]) + + # Multiple lines of single words could be interpreted as expressions. + # This check is to avoid all nodes being parsed as expressions. + # (e.g. words over multiple lines) + if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: + # Shorten the code to 10 lines and/or 204 characters. + space_left = 204 + if content and repl_code: + content = content[1] + else: + content = content[0] + + if len(content) >= space_left: + current_length = 0 + lines_walked = 0 + for line in content.splitlines(keepends=True): + if current_length + len(line) > space_left or lines_walked == 10: + break + current_length += len(line) + lines_walked += 1 + content = content[:current_length] + "#..." + + content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) + howto += ( + "It looks like you're trying to paste code into this channel.\n\n" + "Discord has support for Markdown, which allows you to post code with full " + "syntax highlighting. Please use these whenever you paste code, as this " + "helps improve the legibility and makes it easier for us to help you.\n\n" + f"**To do this, use the following method:**\n" + f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" + "**This will result in the following:**\n" + f"```python\n{content}\n```" + ) + + log.debug(f"{msg.author} posted something that needed to be put inside python code " + "blocks. Sending the user some instructions.") + else: + log.trace("The code consists only of expressions, not sending instructions") + + if howto != "": + # Increase amount of codeblock correction in stats + self.bot.stats.incr("codeblock_corrections") + howto_embed = Embed(description=howto) + bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) + self.codeblock_message_ids[msg.id] = bot_message.id + + self.bot.loop.create_task( + wait_for_deletion(bot_message, user_ids=(msg.author.id,), client=self.bot) + ) + else: + return + + if msg.channel.id not in self.channel_whitelist: + self.channel_cooldowns[msg.channel.id] = time.time() + + except SyntaxError: + log.trace( + f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " + "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " + f"The message that was posted was:\n\n{msg.content}\n\n" + ) + + @Cog.listener() + async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: + """Check to see if an edited message (previously called out) still contains poorly formatted code.""" + if ( + # Checks to see if the message was called out by the bot + payload.message_id not in self.codeblock_message_ids + # Makes sure that there is content in the message + or payload.data.get("content") is None + # Makes sure there's a channel id in the message payload + or payload.data.get("channel_id") is None + ): + return + + # Retrieve channel and message objects for use later + channel = self.bot.get_channel(int(payload.data.get("channel_id"))) + user_message = await channel.fetch_message(payload.message_id) + + # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None + has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) + + # If the message is fixed, delete the bot message and the entry from the id dictionary + if has_fixed_codeblock is None: + bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) + await bot_message.delete() + del self.codeblock_message_ids[payload.message_id] + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + + +def setup(bot: Bot) -> None: + """Load the Bot cog.""" + bot.add_cog(BotCog(bot)) diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py new file mode 100644 index 000000000..d9a7aafe1 --- /dev/null +++ b/bot/exts/utils/clean.py @@ -0,0 +1,272 @@ +import logging +import random +import re +from typing import Iterable, Optional + +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext import commands +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import ( + Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES +) +from bot.decorators import with_role +from bot.exts.moderation.modlog import ModLog + +log = logging.getLogger(__name__) + + +class Clean(Cog): + """ + A cog that allows messages to be deleted in bulk, while applying various filters. + + You can delete messages sent by a specific user, messages sent by bots, all messages, or messages that match a + specific regular expression. + + The deleted messages are saved and uploaded to the database via an API endpoint, and a URL is returned which can be + used to view the messages in the Discord dark theme style. + """ + + def __init__(self, bot: Bot): + self.bot = bot + self.cleaning = False + + @property + def mod_log(self) -> ModLog: + """Get currently loaded ModLog cog instance.""" + return self.bot.get_cog("ModLog") + + async def _clean_messages( + self, + amount: int, + ctx: Context, + channels: Iterable[TextChannel], + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + until_message: Optional[Message] = None, + ) -> None: + """A helper function that does the actual message cleaning.""" + def predicate_bots_only(message: Message) -> bool: + """Return True if the message was sent by a bot.""" + return message.author.bot + + def predicate_specific_user(message: Message) -> bool: + """Return True if the message was sent by the user provided in the _clean_messages call.""" + return message.author == user + + def predicate_regex(message: Message) -> bool: + """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" + content = [message.content] + + # Add the content for all embed attributes + for embed in message.embeds: + content.append(embed.title) + content.append(embed.description) + content.append(embed.footer.text) + content.append(embed.author.name) + for field in embed.fields: + content.append(field.name) + content.append(field.value) + + # Get rid of empty attributes and turn it into a string + content = [attr for attr in content if attr] + content = "\n".join(content) + + # Now let's see if there's a regex match + if not content: + return False + else: + return bool(re.search(regex.lower(), content.lower())) + + # Is this an acceptable amount of messages to clean? + if amount > CleanMessages.message_limit: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description=f"You cannot clean more than {CleanMessages.message_limit} messages." + ) + await ctx.send(embed=embed) + return + + # Are we already performing a clean? + if self.cleaning: + embed = Embed( + color=Colour(Colours.soft_red), + title=random.choice(NEGATIVE_REPLIES), + description="Please wait for the currently ongoing clean operation to complete." + ) + await ctx.send(embed=embed) + return + + # Set up the correct predicate + if bots_only: + predicate = predicate_bots_only # Delete messages from bots + elif user: + predicate = predicate_specific_user # Delete messages from specific user + elif regex: + predicate = predicate_regex # Delete messages that match regex + else: + predicate = None # Delete all messages + + # Default to using the invoking context's channel + if not channels: + channels = [ctx.channel] + + # Delete the invocation first + self.mod_log.ignore(Event.message_delete, ctx.message.id) + await ctx.message.delete() + + messages = [] + message_ids = [] + self.cleaning = True + + # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. + for channel in channels: + async for message in channel.history(limit=amount): + + # If at any point the cancel command is invoked, we should stop. + if not self.cleaning: + return + + # If we are looking for specific message. + if until_message: + + # we could use ID's here however in case if the message we are looking for gets deleted, + # we won't have a way to figure that out thus checking for datetime should be more reliable + if message.created_at < until_message.created_at: + # means we have found the message until which we were supposed to be deleting. + break + + # Since we will be using `delete_messages` method of a TextChannel and we need message objects to + # use it as well as to send logs we will start appending messages here instead adding them from + # purge. + messages.append(message) + + # If the message passes predicate, let's save it. + if predicate is None or predicate(message): + message_ids.append(message.id) + + self.cleaning = False + + # Now let's delete the actual messages with purge. + self.mod_log.ignore(Event.message_delete, *message_ids) + for channel in channels: + if until_message: + for i in range(0, len(messages), 100): + # while purge automatically handles the amount of messages + # delete_messages only allows for up to 100 messages at once + # thus we need to paginate the amount to always be <= 100 + await channel.delete_messages(messages[i:i + 100]) + else: + messages += await channel.purge(limit=amount, check=predicate) + + # Reverse the list to restore chronological order + if messages: + messages = reversed(messages) + log_url = await self.mod_log.upload_log(messages, ctx.author.id) + else: + # Can't build an embed, nothing to clean! + embed = Embed( + color=Colour(Colours.soft_red), + description="No matching messages could be found." + ) + await ctx.send(embed=embed, delete_after=10) + return + + # Build the embed and send it + target_channels = ", ".join(channel.mention for channel in channels) + + message = ( + f"**{len(message_ids)}** messages deleted in {target_channels} by **{ctx.author.name}**\n\n" + f"A log of the deleted messages can be found [here]({log_url})." + ) + + await self.mod_log.send_log_message( + icon_url=Icons.message_bulk_delete, + colour=Colour(Colours.soft_red), + title="Bulk message delete", + text=message, + channel_id=Channels.mod_log, + ) + + @group(invoke_without_command=True, name="clean", aliases=["purge"]) + @with_role(*MODERATION_ROLES) + async def clean_group(self, ctx: Context) -> None: + """Commands for cleaning messages in channels.""" + await ctx.send_help(ctx.command) + + @clean_group.command(name="user", aliases=["users"]) + @with_role(*MODERATION_ROLES) + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, user=user, channels=channels) + + @clean_group.command(name="all", aliases=["everything"]) + @with_role(*MODERATION_ROLES) + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, channels=channels) + + @clean_group.command(name="bots", aliases=["bot"]) + @with_role(*MODERATION_ROLES) + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, bots_only=True, channels=channels) + + @clean_group.command(name="regex", aliases=["word", "expression"]) + @with_role(*MODERATION_ROLES) + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channels: commands.Greedy[TextChannel] = None + ) -> None: + """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" + await self._clean_messages(amount, ctx, regex=regex, channels=channels) + + @clean_group.command(name="message", aliases=["messages"]) + @with_role(*MODERATION_ROLES) + async def clean_message(self, ctx: Context, message: Message) -> None: + """Delete all messages until certain message, stop cleaning after hitting the `message`.""" + await self._clean_messages( + CleanMessages.message_limit, + ctx, + channels=[message.channel], + until_message=message + ) + + @clean_group.command(name="stop", aliases=["cancel", "abort"]) + @with_role(*MODERATION_ROLES) + async def clean_cancel(self, ctx: Context) -> None: + """If there is an ongoing cleaning process, attempt to immediately cancel it.""" + self.cleaning = False + + embed = Embed( + color=Colour.blurple(), + description="Clean interrupted." + ) + await ctx.send(embed=embed, delete_after=10) + + +def setup(bot: Bot) -> None: + """Load the Clean cog.""" + bot.add_cog(Clean(bot)) diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/eval.py new file mode 100644 index 000000000..eb8bfb1cf --- /dev/null +++ b/bot/exts/utils/eval.py @@ -0,0 +1,202 @@ +import contextlib +import inspect +import logging +import pprint +import re +import textwrap +import traceback +from io import StringIO +from typing import Any, Optional, Tuple + +import discord +from discord.ext.commands import Cog, Context, group + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role +from bot.interpreter import Interpreter + +log = logging.getLogger(__name__) + + +class CodeEval(Cog): + """Owner and admin feature that evaluates code and returns the result to the channel.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.env = {} + self.ln = 0 + self.stdout = StringIO() + + self.interpreter = Interpreter(bot) + + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: + """Format the eval output into a string & attempt to format it into an Embed.""" + self._ = out + + res = "" + + # Erase temp input we made + if inp.startswith("_ = "): + inp = inp[4:] + + # Get all non-empty lines + lines = [line for line in inp.split("\n") if line.strip()] + if len(lines) != 1: + lines += [""] + + # Create the input dialog + for i, line in enumerate(lines): + if i == 0: + # Start dialog + start = f"In [{self.ln}]: " + + else: + # Indent the 3 dots correctly; + # Normally, it's something like + # In [X]: + # ...: + # + # But if it's + # In [XX]: + # ...: + # + # You can see it doesn't look right. + # This code simply indents the dots + # far enough to align them. + # we first `str()` the line number + # then we get the length + # and use `str.rjust()` + # to indent it. + start = "...: ".rjust(len(str(self.ln)) + 7) + + if i == len(lines) - 2: + if line.startswith("return"): + line = line[6:].strip() + + # Combine everything + res += (start + line + "\n") + + self.stdout.seek(0) + text = self.stdout.read() + self.stdout.close() + self.stdout = StringIO() + + if text: + res += (text + "\n") + + if out is None: + # No output, return the input statement + return (res, None) + + res += f"Out[{self.ln}]: " + + if isinstance(out, discord.Embed): + # We made an embed? Send that as embed + res += "" + res = (res, out) + + else: + if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): + # Leave out the traceback message + out = "\n" + "\n".join(out.split("\n")[1:]) + + if isinstance(out, str): + pretty = out + else: + pretty = pprint.pformat(out, compact=True, width=60) + + if pretty != str(out): + # We're using the pretty version, start on the next line + res += "\n" + + if pretty.count("\n") > 20: + # Text too long, shorten + li = pretty.split("\n") + + pretty = ("\n".join(li[:3]) # First 3 lines + + "\n ...\n" # Ellipsis to indicate removed lines + + "\n".join(li[-3:])) # last 3 lines + + # Add the output + res += pretty + res = (res, None) + + return res # Return (text, embed) + + async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: + """Eval the input code string & send an embed to the invoking context.""" + self.ln += 1 + + if code.startswith("exit"): + self.ln = 0 + self.env = {} + return await ctx.send("```Reset history!```") + + env = { + "message": ctx.message, + "author": ctx.message.author, + "channel": ctx.channel, + "guild": ctx.guild, + "ctx": ctx, + "self": self, + "bot": self.bot, + "inspect": inspect, + "discord": discord, + "contextlib": contextlib + } + + self.env.update(env) + + # Ignore this code, it works + code_ = """ +async def func(): # (None,) -> Any + try: + with contextlib.redirect_stdout(self.stdout): +{0} + if '_' in locals(): + if inspect.isawaitable(_): + _ = await _ + return _ + finally: + self.env.update(locals()) +""".format(textwrap.indent(code, ' ')) + + try: + exec(code_, self.env) # noqa: B102,S102 + func = self.env['func'] + res = await func() + + except Exception: + res = traceback.format_exc() + + out, embed = self._format(code, res) + await ctx.send(f"```py\n{out}```", embed=embed) + + @group(name='internal', aliases=('int',)) + @with_role(Roles.owners, Roles.admins) + async def internal_group(self, ctx: Context) -> None: + """Internal commands. Top secret!""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @internal_group.command(name='eval', aliases=('e',)) + @with_role(Roles.admins, Roles.owners) + async def eval(self, ctx: Context, *, code: str) -> None: + """Run eval in a REPL-like format.""" + code = code.strip("`") + if re.match('py(thon)?\n', code): + code = "\n".join(code.split("\n")[1:]) + + if not re.search( # Check if it's an expression + r"^(return|import|for|while|def|class|" + r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( + code.split("\n")) == 1: + code = "_ = " + code + + await self._eval(ctx, code) + + +def setup(bot: Bot) -> None: + """Load the CodeEval cog.""" + bot.add_cog(CodeEval(bot)) diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py new file mode 100644 index 000000000..671397650 --- /dev/null +++ b/bot/exts/utils/extensions.py @@ -0,0 +1,289 @@ +import functools +import importlib +import inspect +import logging +import pkgutil +import typing as t +from enum import Enum + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group + +from bot import exts +from bot.bot import Bot +from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +def walk_extensions() -> t.Iterator[str]: + """Yield extension names from the bot.exts subpackage.""" + + def on_error(name: str) -> t.NoReturn: + raise ImportError(name=name) # pragma: no cover + + for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error): + if module.name.rsplit(".", maxsplit=1)[-1].startswith("_"): + # Ignore module/package names starting with an underscore. + continue + + if module.ispkg: + imported = importlib.import_module(module.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. + continue + + yield module.name + + +UNLOAD_BLACKLIST = {f"{exts.__name__}.utils.extensions", f"{exts.__name__}.moderation.modlog"} +EXTENSIONS = frozenset(walk_extensions()) +BASE_PATH_LEN = len(exts.__name__.split(".")) + + +class Action(Enum): + """Represents an action to perform on an extension.""" + + # Need to be partial otherwise they are considered to be function definitions. + LOAD = functools.partial(Bot.load_extension) + UNLOAD = functools.partial(Bot.unload_extension) + RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + name = ext.rsplit(".", maxsplit=1)[-1] + if argument == name: + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise commands.BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): + """Extension management commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) + async def extensions_group(self, ctx: Context) -> None: + """Load, unload, reload, and list loaded extensions.""" + await ctx.send_help(ctx.command) + + @extensions_group.command(name="load", aliases=("l",)) + async def load_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Load extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "*" in extensions or "**" in extensions: + extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + + msg = self.batch_manage(Action.LOAD, *extensions) + await ctx.send(msg) + + @extensions_group.command(name="unload", aliases=("ul",)) + async def unload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Unload currently loaded extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + + if blacklisted: + msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" + else: + if "*" in extensions or "**" in extensions: + extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + + msg = self.batch_manage(Action.UNLOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="reload", aliases=("r",)) + async def reload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Reload extensions given their fully qualified or unqualified names. + + If an extension fails to be reloaded, it will be rolled-back to the prior working state. + + If '\*' is given as the name, all currently loaded extensions will be reloaded. + If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "**" in extensions: + extensions = EXTENSIONS + elif "*" in extensions: + extensions = set(self.bot.extensions.keys()) | set(extensions) + extensions.remove("*") + + msg = self.batch_manage(Action.RELOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="list", aliases=("all",)) + async def list_command(self, ctx: Context) -> None: + """ + Get a list of all extensions, including their loaded status. + + Grey indicates that the extension is unloaded. + Green indicates that the extension is currently loaded. + """ + embed = Embed(colour=Colour.blurple()) + embed.set_author( + name="Extensions List", + url=URLs.github_bot_repo, + icon_url=URLs.bot_avatar + ) + + lines = [] + categories = self.group_extension_statuses() + for category, extensions in sorted(categories.items()): + # Treat each category as a single line by concatenating everything. + # This ensures the paginator will not cut off a page in the middle of a category. + category = category.replace("_", " ").title() + extensions = "\n".join(sorted(extensions)) + lines.append(f"**{category}**\n{extensions}\n") + + log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") + await LinePaginator.paginate(lines, ctx, embed, scale_to_size=700, empty=False) + + def group_extension_statuses(self) -> t.Mapping[str, str]: + """Return a mapping of extension names and statuses to their categories.""" + categories = {} + + for ext in EXTENSIONS: + if ext in self.bot.extensions: + status = Emojis.status_online + else: + status = Emojis.status_offline + + path = ext.split(".") + if len(path) > BASE_PATH_LEN + 1: + category = " - ".join(path[BASE_PATH_LEN:-1]) + else: + category = "uncategorised" + + categories.setdefault(category, []).append(f"{status} {path[-1]}") + + return categories + + def batch_manage(self, action: Action, *extensions: str) -> str: + """ + Apply an action to multiple extensions and return a message with the results. + + If only one extension is given, it is deferred to `manage()`. + """ + if len(extensions) == 1: + msg, _ = self.manage(action, extensions[0]) + return msg + + verb = action.name.lower() + failures = {} + + for extension in extensions: + _, error = self.manage(action, extension) + if error: + failures[extension] = error + + emoji = ":x:" if failures else ":ok_hand:" + msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + + if failures: + failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) + msg += f"\nFailures:```{failures}```" + + log.debug(f"Batch {verb}ed extensions.") + + return msg + + def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: + """Apply an action to an extension and return the status message and any error message.""" + verb = action.name.lower() + error_msg = None + + try: + action.value(self.bot, ext) + except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): + if action is Action.RELOAD: + # When reloading, just load the extension if it was not loaded. + return self.manage(Action.LOAD, ext) + + msg = f":x: Extension `{ext}` is already {verb}ed." + log.debug(msg[4:]) + except Exception as e: + if hasattr(e, "original"): + e = e.original + + log.exception(f"Extension '{ext}' failed to {verb}.") + + error_msg = f"{e.__class__.__name__}: {e}" + msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" + else: + msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." + log.debug(msg[10:]) + + return msg, error_msg + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators and core developers to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Handle BadArgument errors locally to prevent the help command from showing.""" + if isinstance(error, commands.BadArgument): + await ctx.send(str(error)) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Extensions cog.""" + bot.add_cog(Extensions(bot)) diff --git a/bot/exts/utils/jams.py b/bot/exts/utils/jams.py new file mode 100644 index 000000000..b3102db2f --- /dev/null +++ b/bot/exts/utils/jams.py @@ -0,0 +1,150 @@ +import logging +import typing as t + +from discord import CategoryChannel, Guild, Member, PermissionOverwrite, Role +from discord.ext import commands +from more_itertools import unique_everseen + +from bot.bot import Bot +from bot.constants import Roles +from bot.decorators import with_role + +log = logging.getLogger(__name__) + +MAX_CHANNELS = 50 +CATEGORY_NAME = "Code Jam" + + +class CodeJams(commands.Cog): + """Manages the code-jam related parts of our server.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @commands.command() + @with_role(Roles.admins) + async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: + """ + Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. + + The first user passed will always be the team leader. + """ + # Ignore duplicate members + members = list(unique_everseen(members)) + + # We had a little issue during Code Jam 4 here, the greedy converter did it's job + # and ignored anything which wasn't a valid argument which left us with teams of + # two members or at some times even 1 member. This fixes that by checking that there + # are always 3 members in the members list. + if len(members) < 3: + await ctx.send( + ":no_entry_sign: One of your arguments was invalid\n" + f"There must be a minimum of 3 valid members in your team. Found: {len(members)}" + " members" + ) + return + + team_channel = await self.create_channels(ctx.guild, team_name, members) + await self.add_roles(ctx.guild, members) + + await ctx.send( + f":ok_hand: Team created: {team_channel}\n" + f"**Team Leader:** {members[0].mention}\n" + f"**Team Members:** {' '.join(member.mention for member in members[1:])}" + ) + + async def get_category(self, guild: Guild) -> CategoryChannel: + """ + Return a code jam category. + + If all categories are full or none exist, create a new category. + """ + for category in guild.categories: + # Need 2 available spaces: one for the text channel and one for voice. + if category.name == CATEGORY_NAME and MAX_CHANNELS - len(category.channels) >= 2: + return category + + return await self.create_category(guild) + + @staticmethod + async def create_category(guild: Guild) -> CategoryChannel: + """Create a new code jam category and return it.""" + log.info("Creating a new code jam category.") + + category_overwrites = { + guild.default_role: PermissionOverwrite(read_messages=False), + guild.me: PermissionOverwrite(read_messages=True) + } + + return await guild.create_category_channel( + CATEGORY_NAME, + overwrites=category_overwrites, + reason="It's code jam time!" + ) + + @staticmethod + def get_overwrites(members: t.List[Member], guild: Guild) -> t.Dict[t.Union[Member, Role], PermissionOverwrite]: + """Get code jam team channels permission overwrites.""" + # First member is always the team leader + team_channel_overwrites = { + members[0]: PermissionOverwrite( + manage_messages=True, + read_messages=True, + manage_webhooks=True, + connect=True + ), + guild.default_role: PermissionOverwrite(read_messages=False, connect=False), + guild.get_role(Roles.verified): PermissionOverwrite( + read_messages=False, + connect=False + ) + } + + # Rest of members should just have read_messages + for member in members[1:]: + team_channel_overwrites[member] = PermissionOverwrite( + read_messages=True, + connect=True + ) + + return team_channel_overwrites + + async def create_channels(self, guild: Guild, team_name: str, members: t.List[Member]) -> str: + """Create team text and voice channels. Return the mention for the text channel.""" + # Get permission overwrites and category + team_channel_overwrites = self.get_overwrites(members, guild) + code_jam_category = await self.get_category(guild) + + # Create a text channel for the team + team_channel = await guild.create_text_channel( + team_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + # Create a voice channel for the team + team_voice_name = " ".join(team_name.split("-")).title() + + await guild.create_voice_channel( + team_voice_name, + overwrites=team_channel_overwrites, + category=code_jam_category + ) + + return team_channel.mention + + @staticmethod + async def add_roles(guild: Guild, members: t.List[Member]) -> None: + """Assign team leader and jammer roles.""" + # Assign team leader role + await members[0].add_roles(guild.get_role(Roles.team_leaders)) + + # Assign rest of roles + jammer_role = guild.get_role(Roles.jammers) + for member in members: + await member.add_roles(jammer_role) + + +def setup(bot: Bot) -> None: + """Load the CodeJams cog.""" + bot.add_cog(CodeJams(bot)) diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py new file mode 100644 index 000000000..670493bcf --- /dev/null +++ b/bot/exts/utils/reminders.py @@ -0,0 +1,427 @@ +import asyncio +import logging +import random +import textwrap +import typing as t +from datetime import datetime, timedelta +from operator import itemgetter + +import discord +from dateutil.parser import isoparse +from dateutil.relativedelta import relativedelta +from discord.ext.commands import Cog, Context, Greedy, group + +from bot.bot import Bot +from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES +from bot.converters import Duration +from bot.pagination import LinePaginator +from bot.utils.checks import without_role_check +from bot.utils.messages import send_denial +from bot.utils.scheduling import Scheduler +from bot.utils.time import humanize_delta + +log = logging.getLogger(__name__) + +WHITELISTED_CHANNELS = Guild.reminder_whitelist +MAXIMUM_REMINDERS = 5 + +Mentionable = t.Union[discord.Member, discord.Role] + + +class Reminders(Cog): + """Provide in-channel reminder functionality.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + self.bot.loop.create_task(self.reschedule_reminders()) + + def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + async def reschedule_reminders(self) -> None: + """Get all current reminders from the API and reschedule them.""" + await self.bot.wait_until_guild_available() + response = await self.bot.api_client.get( + 'bot/reminders', + params={'active': 'true'} + ) + + now = datetime.utcnow() + + for reminder in response: + is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + if not is_valid: + continue + + remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) + + # If the reminder is already overdue ... + if remind_at < now: + late = relativedelta(now, remind_at) + await self.send_reminder(reminder, late) + else: + self.schedule_reminder(reminder) + + def ensure_valid_reminder( + self, + reminder: dict, + cancel_task: bool = True + ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" + user = self.bot.get_user(reminder['author']) + channel = self.bot.get_channel(reminder['channel_id']) + is_valid = True + if not user or not channel: + is_valid = False + log.info( + f"Reminder {reminder['id']} invalid: " + f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." + ) + asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + + return is_valid, user, channel + + @staticmethod + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: str, + delivery_dt: t.Optional[datetime], + ) -> None: + """Send an embed confirming the reminder change was made successfully.""" + embed = discord.Embed() + embed.colour = discord.Colour.green() + embed.title = random.choice(POSITIVE_REPLIES) + embed.description = on_success + + footer_str = f"ID: {reminder_id}" + if delivery_dt: + # Reminder deletion will have a `None` `delivery_dt` + footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + + embed.set_footer(text=footer_str) + + await ctx.send(embed=embed) + + @staticmethod + async def _check_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> t.Tuple[bool, str]: + """ + Returns whether or not the list of mentions is allowed. + + Conditions: + - Role reminders are Mods+ + - Reminders for other users are Helpers+ + + If mentions aren't allowed, also return the type of mention(s) disallowed. + """ + if without_role_check(ctx, *STAFF_ROLES): + return False, "members/roles" + elif without_role_check(ctx, *MODERATION_ROLES): + return all(isinstance(mention, discord.Member) for mention in mentions), "roles" + else: + return True, "" + + @staticmethod + async def validate_mentions(ctx: Context, mentions: t.Iterable[Mentionable]) -> bool: + """ + Filter mentions to see if the user can mention, and sends a denial if not allowed. + + Returns whether or not the validation is successful. + """ + mentions_allowed, disallowed_mentions = await Reminders._check_mentions(ctx, mentions) + + if not mentions or mentions_allowed: + return True + else: + await send_denial(ctx, f"You can't mention other {disallowed_mentions} in your reminder!") + return False + + def get_mentionables(self, mention_ids: t.List[int]) -> t.Iterator[Mentionable]: + """Converts Role and Member ids to their corresponding objects if possible.""" + guild = self.bot.get_guild(Guild.id) + for mention_id in mention_ids: + if (mentionable := (guild.get_member(mention_id) or guild.get_role(mention_id))): + yield mentionable + + def schedule_reminder(self, reminder: dict) -> None: + """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" + reminder_id = reminder["id"] + reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) + + async def _remind() -> None: + await self.send_reminder(reminder) + + log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") + await self._delete_reminder(reminder_id) + + self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) + + async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: + """Delete a reminder from the database, given its ID, and cancel the running task.""" + await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) + + if cancel_task: + # Now we can remove it from the schedule list + self.scheduler.cancel(reminder_id) + + async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: + """ + Edits a reminder in the database given the ID and payload. + + Returns the edited reminder. + """ + # Send the request to update the reminder in the database + reminder = await self.bot.api_client.patch( + 'bot/reminders/' + str(reminder_id), + json=payload + ) + return reminder + + async def _reschedule_reminder(self, reminder: dict) -> None: + """Reschedule a reminder object.""" + log.trace(f"Cancelling old task #{reminder['id']}") + self.scheduler.cancel(reminder["id"]) + + log.trace(f"Scheduling new task #{reminder['id']}") + self.schedule_reminder(reminder) + + async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: + """Send the reminder.""" + is_valid, user, channel = self.ensure_valid_reminder(reminder) + if not is_valid: + return + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.set_author( + icon_url=Icons.remind_blurple, + name="It has arrived!" + ) + + embed.description = f"Here's your reminder: `{reminder['content']}`." + + if reminder.get("jump_url"): # keep backward compatibility + embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" + + if late: + embed.colour = discord.Colour.red() + embed.set_author( + icon_url=Icons.remind_red, + name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" + ) + + additional_mentions = ' '.join( + mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) + ) + + await channel.send( + content=f"{user.mention} {additional_mentions}", + embed=embed + ) + await self._delete_reminder(reminder["id"]) + + @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) + async def remind_group( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """Commands for managing your reminders.""" + await ctx.invoke(self.new_reminder, mentions=mentions, expiration=expiration, content=content) + + @remind_group.command(name="new", aliases=("add", "create")) + async def new_reminder( + self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + ) -> None: + """ + Set yourself a simple reminder. + + Expiration is parsed per: http://strftime.org/ + """ + # If the user is not staff, we need to verify whether or not to make a reminder at all. + if without_role_check(ctx, *STAFF_ROLES): + + # If they don't have permission to set a reminder in this channel + if ctx.channel.id not in WHITELISTED_CHANNELS: + await send_denial(ctx, "Sorry, you can't do that here!") + return + + # Get their current active reminders + active_reminders = await self.bot.api_client.get( + 'bot/reminders', + params={ + 'author__id': str(ctx.author.id) + } + ) + + # Let's limit this, so we don't get 10 000 + # reminders from kip or something like that :P + if len(active_reminders) > MAXIMUM_REMINDERS: + await send_denial(ctx, "You have too many active reminders!") + return + + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + + # Now we can attempt to actually set the reminder. + reminder = await self.bot.api_client.post( + 'bot/reminders', + json={ + 'author': ctx.author.id, + 'channel_id': ctx.message.channel.id, + 'jump_url': ctx.message.jump_url, + 'content': content, + 'expiration': expiration.isoformat(), + 'mentions': mention_ids, + } + ) + + now = datetime.utcnow() - timedelta(seconds=1) + humanized_delta = humanize_delta(relativedelta(expiration, now)) + mention_string = ( + f"Your reminder will arrive in {humanized_delta} " + f"and will mention {len(mentions)} other(s)!" + ) + + # Confirm to the user that it worked. + await self._send_confirmation( + ctx, + on_success=mention_string, + reminder_id=reminder["id"], + delivery_dt=expiration, + ) + + self.schedule_reminder(reminder) + + @remind_group.command(name="list") + async def list_reminders(self, ctx: Context) -> None: + """View a paginated embed of all reminders for your user.""" + # Get all the user's reminders from the database. + data = await self.bot.api_client.get( + 'bot/reminders', + params={'author__id': str(ctx.author.id)} + ) + + now = datetime.utcnow() + + # Make a list of tuples so it can be sorted by time. + reminders = sorted( + ( + (rem['content'], rem['expiration'], rem['id'], rem['mentions']) + for rem in data + ), + key=itemgetter(1) + ) + + lines = [] + + for content, remind_at, id_, mentions in reminders: + # Parse and humanize the time, make it pretty :D + remind_datetime = isoparse(remind_at).replace(tzinfo=None) + time = humanize_delta(relativedelta(remind_datetime, now)) + + mentions = ", ".join( + # Both Role and User objects have the `name` attribute + mention.name for mention in self.get_mentionables(mentions) + ) + mention_string = f"\n**Mentions:** {mentions}" if mentions else "" + + text = textwrap.dedent(f""" + **Reminder #{id_}:** *expires in {time}* (ID: {id_}){mention_string} + {content} + """).strip() + + lines.append(text) + + embed = discord.Embed() + embed.colour = discord.Colour.blurple() + embed.title = f"Reminders for {ctx.author}" + + # Remind the user that they have no reminders :^) + if not lines: + embed.description = "No active reminders could be found." + await ctx.send(embed=embed) + return + + # Construct the embed and paginate it. + embed.colour = discord.Colour.blurple() + + await LinePaginator.paginate( + lines, + ctx, embed, + max_lines=3, + empty=True + ) + + @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) + async def edit_reminder_group(self, ctx: Context) -> None: + """Commands for modifying your current reminders.""" + await ctx.send_help(ctx.command) + + @edit_reminder_group.command(name="duration", aliases=("time",)) + async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: + """ + Edit one of your reminder's expiration. + + Expiration is parsed per: http://strftime.org/ + """ + await self.edit_reminder(ctx, id_, {'expiration': expiration.isoformat()}) + + @edit_reminder_group.command(name="content", aliases=("reason",)) + async def edit_reminder_content(self, ctx: Context, id_: int, *, content: str) -> None: + """Edit one of your reminder's content.""" + await self.edit_reminder(ctx, id_, {"content": content}) + + @edit_reminder_group.command(name="mentions", aliases=("pings",)) + async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: + """Edit one of your reminder's mentions.""" + # Remove duplicate mentions + mentions = set(mentions) + mentions.discard(ctx.author) + + # Filter mentions to see if the user can mention members/roles + if not await self.validate_mentions(ctx, mentions): + return + + mention_ids = [mention.id for mention in mentions] + await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) + + async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: + """Edits a reminder with the given payload, then sends a confirmation message.""" + reminder = await self._edit_reminder(id_, payload) + + # Parse the reminder expiration back into a datetime + expiration = isoparse(reminder["expiration"]).replace(tzinfo=None) + + # Send a confirmation message to the channel + await self._send_confirmation( + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, + ) + await self._reschedule_reminder(reminder) + + @remind_group.command("delete", aliases=("remove", "cancel")) + async def delete_reminder(self, ctx: Context, id_: int) -> None: + """Delete one of your active reminders.""" + await self._delete_reminder(id_) + await self._send_confirmation( + ctx, + on_success="That reminder has been deleted successfully!", + reminder_id=id_, + delivery_dt=None, + ) + + +def setup(bot: Bot) -> None: + """Load the Reminders cog.""" + bot.add_cog(Reminders(bot)) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py new file mode 100644 index 000000000..52c8b6f88 --- /dev/null +++ b/bot/exts/utils/snekbox.py @@ -0,0 +1,349 @@ +import asyncio +import contextlib +import datetime +import logging +import re +import textwrap +from functools import partial +from signal import Signals +from typing import Optional, Tuple + +from discord import HTTPException, Message, NotFound, Reaction, User +from discord.ext.commands import Cog, Context, command, guild_only + +from bot.bot import Bot +from bot.constants import Categories, Channels, Roles, URLs +from bot.decorators import in_whitelist +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") +FORMATTED_CODE_REGEX = re.compile( + r"^\s*" # any leading whitespace from the beginning of the string + r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)" # match the exact same delimiter from the start again + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive +) +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL # "." also matches newlines +) + +MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) +EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) +EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) + +SIGKILL = 9 + +REEVAL_EMOJI = '\U0001f501' # :repeat: +REEVAL_TIMEOUT = 30 + + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_eval(self, code: str) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the eval output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + + url = URLs.paste_service.format(key="documents") + try: + async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: + data = await resp.json() + + if "key" in data: + return URLs.paste_service.format(key=data["key"]) + except Exception: + # 400 (Bad Request) means there are too many characters + log.exception("Failed to upload full output to paste service!") + + @staticmethod + def prepare_input(code: str) -> str: + """Extract code from the Markdown, format it, and insert it into the code template.""" + match = FORMATTED_CODE_REGEX.fullmatch(code) + if match: + code, block, lang, delim = match.group("code", "block", "lang", "delim") + code = textwrap.dedent(code) + if block: + info = (f"'{lang}' highlighted" if lang else "plain") + " code block" + else: + info = f"{delim}-enclosed inline code" + log.trace(f"Extracted {info} for evaluation:\n{code}") + else: + code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) + log.trace( + f"Eval message contains unformatted or badly formatted code, " + f"stripping whitespace only:\n{code}" + ) + + return code + + @staticmethod + def get_results_message(results: dict) -> Tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + stdout, returncode = results["stdout"], results["returncode"] + msg = f"Your eval job has completed with return code {returncode}" + error = "" + + if returncode is None: + msg = "Your eval job has failed" + error = stdout.strip() + elif returncode == 128 + SIGKILL: + msg = "Your eval job timed out or ran out of memory" + elif returncode == 255: + msg = "Your eval job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + try: + name = Signals(returncode - 128).name + msg = f"{msg} ({name})" + except ValueError: + pass + + return msg, error + + @staticmethod + def get_status_emoji(results: dict) -> str: + """Return an emoji corresponding to the status code or lack of output in result.""" + if not results["stdout"].strip(): # No output + return ":warning:" + elif results["returncode"] == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + + async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + """ + Format the output and return a tuple of the formatted output and a URL to the full output. + + Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters + and upload the full output to a paste service. + """ + log.trace("Formatting output...") + + output = output.rstrip("\n") + original_output = output # To be uploaded to a pasting service if needed + paste_link = None + + if "<@" in output: + output = output.replace("<@", "<@\u200B") # Zero-width space + + if " 0: + output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] + output = output[:11] # Limiting to only 11 lines + output = "\n".join(output) + + if lines > 10: + truncated = True + if len(output) >= 1000: + output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + else: + output = f"{output}\n... (truncated - too many lines)" + elif len(output) >= 1000: + truncated = True + output = f"{output[:1000]}\n... (truncated - too long)" + + if truncated: + paste_link = await self.upload_output(original_output) + + output = output or "[No output]" + + return output, paste_link + + async def send_eval(self, ctx: Context, code: str) -> Message: + """ + Evaluate code, format it, and send the output to the corresponding channel. + + Return the bot response. + """ + async with ctx.typing(): + results = await self.post_eval(code) + msg, error = self.get_results_message(results) + + if error: + output, paste_link = error, None + else: + output, paste_link = await self.format_output(results["stdout"]) + + icon = self.get_status_emoji(results) + msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + # Collect stats of eval fails + successes + if icon == ":x:": + self.bot.stats.incr("snekbox.python.fail") + else: + self.bot.stats.incr("snekbox.python.success") + + filter_cog = self.bot.get_cog("Filtering") + filter_triggered = False + if filter_cog: + filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + if filter_triggered: + response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + else: + response = await ctx.send(msg) + self.bot.loop.create_task( + wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) + ) + + log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + return response + + async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + """ + Check if the eval session should continue. + + Return the new code to evaluate or None if the eval session should be terminated. + """ + _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + + with contextlib.suppress(NotFound): + try: + _, new_message = await self.bot.wait_for( + 'message_edit', + check=_predicate_eval_message_edit, + timeout=REEVAL_TIMEOUT + ) + await ctx.message.add_reaction(REEVAL_EMOJI) + await self.bot.wait_for( + 'reaction_add', + check=_predicate_emoji_reaction, + timeout=10 + ) + + code = await self.get_code(new_message) + await ctx.message.clear_reactions() + with contextlib.suppress(HTTPException): + await response.delete() + + except asyncio.TimeoutError: + await ctx.message.clear_reactions() + return None + + return code + + async def get_code(self, message: Message) -> Optional[str]: + """ + Return the code from `message` to be evaluated. + + If the message is an invocation of the eval command, return the first argument or None if it + doesn't exist. Otherwise, return the full content of the message. + """ + log.trace(f"Getting context for message {message.id}.") + new_ctx = await self.bot.get_context(message) + + if new_ctx.command is self.eval_command: + log.trace(f"Message {message.id} invokes eval command.") + split = message.content.split(maxsplit=1) + code = split[1] if len(split) > 1 else None + else: + log.trace(f"Message {message.id} does not invoke eval command.") + code = message.content + + return code + + @command(name="eval", aliases=("e",)) + @guild_only() + @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) + async def eval_command(self, ctx: Context, *, code: str = None) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + if ctx.author.id in self.jobs: + await ctx.send( + f"{ctx.author.mention} You've already got a job running - " + "please wait for it to finish!" + ) + return + + if not code: # None or empty string + await ctx.send_help(ctx.command) + return + + if Roles.helpers in (role.id for role in ctx.author.roles): + self.bot.stats.incr("snekbox_usages.roles.helpers") + else: + self.bot.stats.incr("snekbox_usages.roles.developers") + + if ctx.channel.category_id == Categories.help_in_use: + self.bot.stats.incr("snekbox_usages.channels.help") + elif ctx.channel.id == Channels.bot_commands: + self.bot.stats.incr("snekbox_usages.channels.bot_commands") + else: + self.bot.stats.incr("snekbox_usages.channels.topical") + + log.info(f"Received code from {ctx.author} for evaluation:\n{code}") + + while True: + self.jobs[ctx.author.id] = datetime.datetime.now() + code = self.prepare_input(code) + try: + response = await self.send_eval(ctx, code) + finally: + del self.jobs[ctx.author.id] + + code = await self.continue_eval(ctx, response) + if not code: + break + log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + + +def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: + """Return True if the edited message is the context message and the content was indeed modified.""" + return new_msg.id == ctx.message.id and old_msg.content != new_msg.content + + +def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI + + +def setup(bot: Bot) -> None: + """Load the Snekbox cog.""" + bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py new file mode 100644 index 000000000..d96abbd5a --- /dev/null +++ b/bot/exts/utils/utils.py @@ -0,0 +1,265 @@ +import difflib +import logging +import re +import unicodedata +from email.parser import HeaderParser +from io import StringIO +from typing import Tuple, Union + +from discord import Colour, Embed, utils +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES +from bot.decorators import in_whitelist, with_role +from bot.pagination import LinePaginator +from bot.utils import messages + +log = logging.getLogger(__name__) + +ZEN_OF_PYTHON = """\ +Beautiful is better than ugly. +Explicit is better than implicit. +Simple is better than complex. +Complex is better than complicated. +Flat is better than nested. +Sparse is better than dense. +Readability counts. +Special cases aren't special enough to break the rules. +Although practicality beats purity. +Errors should never pass silently. +Unless explicitly silenced. +In the face of ambiguity, refuse the temptation to guess. +There should be one-- and preferably only one --obvious way to do it. +Although that way may not be obvious at first unless you're Dutch. +Now is better than never. +Although never is often better than *right* now. +If the implementation is hard to explain, it's a bad idea. +If the implementation is easy to explain, it may be a good idea. +Namespaces are one honking great idea -- let's do more of those! +""" + +ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" + + +class Utils(Cog): + """A selection of utilities which don't have a clear category.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.base_pep_url = "http://www.python.org/dev/peps/pep-" + self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" + + @command(name='pep', aliases=('get_pep', 'p')) + async def pep_command(self, ctx: Context, pep_number: str) -> None: + """Fetches information about a PEP and sends it to the channel.""" + if pep_number.isdigit(): + pep_number = int(pep_number) + else: + await ctx.send_help(ctx.command) + return + + # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. + if pep_number == 0: + return await self.send_pep_zero(ctx) + + possible_extensions = ['.txt', '.rst'] + found_pep = False + for extension in possible_extensions: + # Attempt to fetch the PEP + pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" + log.trace(f"Requesting PEP {pep_number} with {pep_url}") + response = await self.bot.http_session.get(pep_url) + + if response.status == 200: + log.trace("PEP found") + found_pep = True + + pep_content = await response.text() + + # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 + pep_header = HeaderParser().parse(StringIO(pep_content)) + + # Assemble the embed + pep_embed = Embed( + title=f"**PEP {pep_number} - {pep_header['Title']}**", + description=f"[Link]({self.base_pep_url}{pep_number:04})", + ) + + pep_embed.set_thumbnail(url=ICON_URL) + + # Add the interesting information + fields_to_check = ("Status", "Python-Version", "Created", "Type") + for field in fields_to_check: + # Check for a PEP metadata field that is present but has an empty value + # embed field values can't contain an empty string + if pep_header.get(field, ""): + pep_embed.add_field(name=field, value=pep_header[field]) + + elif response.status != 404: + # any response except 200 and 404 is expected + found_pep = True # actually not, but it's easier to display this way + log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " + f"{response.status}.\n{response.text}") + + error_message = "Unexpected HTTP error during PEP search. Please let us know." + pep_embed = Embed(title="Unexpected error", description=error_message) + pep_embed.colour = Colour.red() + break + + if not found_pep: + log.trace("PEP was not found") + not_found = f"PEP {pep_number} does not exist." + pep_embed = Embed(title="PEP not found", description=not_found) + pep_embed.colour = Colour.red() + + await ctx.message.channel.send(embed=pep_embed) + + @command() + @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) + async def charinfo(self, ctx: Context, *, characters: str) -> None: + """Shows you information on up to 50 unicode characters.""" + match = re.match(r"<(a?):(\w+):(\d+)>", characters) + if match: + return await messages.send_denial( + ctx, + "**Non-Character Detected**\n" + "Only unicode characters can be processed, but a custom Discord emoji " + "was found. Please remove it and try again." + ) + + if len(characters) > 50: + return await messages.send_denial(ctx, f"Too many characters ({len(characters)}/50)") + + def get_info(char: str) -> Tuple[str, str]: + digit = f"{ord(char):x}" + if len(digit) <= 4: + u_code = f"\\u{digit:>04}" + else: + u_code = f"\\U{digit:>08}" + url = f"https://www.compart.com/en/unicode/U+{digit:>04}" + name = f"[{unicodedata.name(char, '')}]({url})" + info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}" + return info, u_code + + char_list, raw_list = zip(*(get_info(c) for c in characters)) + embed = Embed().set_author(name="Character Info") + + if len(characters) > 1: + # Maximum length possible is 502 out of 1024, so there's no need to truncate. + embed.add_field(name='Full Raw Text', value=f"`{''.join(raw_list)}`", inline=False) + + await LinePaginator.paginate(char_list, ctx, embed, max_lines=10, max_size=2000, empty=False) + + @command() + async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: + """ + Show the Zen of Python. + + Without any arguments, the full Zen will be produced. + If an integer is provided, the line with that index will be produced. + If a string is provided, the line which matches best will be produced. + """ + embed = Embed( + colour=Colour.blurple(), + title="The Zen of Python", + description=ZEN_OF_PYTHON + ) + + if search_value is None: + embed.title += ", by Tim Peters" + await ctx.send(embed=embed) + return + + zen_lines = ZEN_OF_PYTHON.splitlines() + + # handle if it's an index int + if isinstance(search_value, int): + upper_bound = len(zen_lines) - 1 + lower_bound = -1 * upper_bound + if not (lower_bound <= search_value <= upper_bound): + raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") + + embed.title += f" (line {search_value % len(zen_lines)}):" + embed.description = zen_lines[search_value] + await ctx.send(embed=embed) + return + + # Try to handle first exact word due difflib.SequenceMatched may use some other similar word instead + # exact word. + for i, line in enumerate(zen_lines): + for word in line.split(): + if word.lower() == search_value.lower(): + embed.title += f" (line {i}):" + embed.description = line + await ctx.send(embed=embed) + return + + # handle if it's a search string and not exact word + matcher = difflib.SequenceMatcher(None, search_value.lower()) + + best_match = "" + match_index = 0 + best_ratio = 0 + + for index, line in enumerate(zen_lines): + matcher.set_seq2(line.lower()) + + # the match ratio needs to be adjusted because, naturally, + # longer lines will have worse ratios than shorter lines when + # fuzzy searching for keywords. this seems to work okay. + adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() + + if adjusted_ratio > best_ratio: + best_ratio = adjusted_ratio + best_match = line + match_index = index + + if not best_match: + raise BadArgument("I didn't get a match! Please try again with a different search term.") + + embed.title += f" (line {match_index}):" + embed.description = best_match + await ctx.send(embed=embed) + + @command(aliases=("poll",)) + @with_role(*MODERATION_ROLES) + async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None: + """ + Build a quick voting poll with matching reactions with the provided options. + + A maximum of 20 options can be provided, as Discord supports a max of 20 + reactions on a single message. + """ + if len(title) > 256: + raise BadArgument("The title cannot be longer than 256 characters.") + if len(options) < 2: + raise BadArgument("Please provide at least 2 options.") + if len(options) > 20: + raise BadArgument("I can only handle 20 options!") + + codepoint_start = 127462 # represents "regional_indicator_a" unicode value + options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} + embed = Embed(title=title, description="\n".join(options.values())) + message = await ctx.send(embed=embed) + for reaction in options: + await message.add_reaction(reaction) + + async def send_pep_zero(self, ctx: Context) -> None: + """Send information about PEP 0.""" + pep_embed = Embed( + title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**", + description="[Link](https://www.python.org/dev/peps/)" + ) + pep_embed.set_thumbnail(url=ICON_URL) + pep_embed.add_field(name="Status", value="Active") + pep_embed.add_field(name="Created", value="13-Jul-2000") + pep_embed.add_field(name="Type", value="Informational") + + await ctx.send(embed=pep_embed) + + +def setup(bot: Bot) -> None: + """Load the Utils cog.""" + bot.add_cog(Utils(bot)) diff --git a/tests/bot/cogs/__init__.py b/tests/bot/cogs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/backend/__init__.py b/tests/bot/cogs/backend/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/backend/sync/__init__.py b/tests/bot/cogs/backend/sync/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/backend/sync/test_base.py b/tests/bot/cogs/backend/sync/test_base.py deleted file mode 100644 index 3009aacb6..000000000 --- a/tests/bot/cogs/backend/sync/test_base.py +++ /dev/null @@ -1,404 +0,0 @@ -import asyncio -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend.sync._syncers import Syncer, _Diff -from tests import helpers - - -class TestSyncer(Syncer): - """Syncer subclass with mocks for abstract methods for testing purposes.""" - - name = "test" - _get_diff = mock.AsyncMock() - _sync = mock.AsyncMock() - - -class SyncerBaseTests(unittest.TestCase): - """Tests for the syncer base class.""" - - def setUp(self): - self.bot = helpers.MockBot() - - def test_instantiation_fails_without_abstract_methods(self): - """The class must have abstract methods implemented.""" - with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): - Syncer(self.bot) - - -class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): - """Tests for sending the sync confirmation prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - - def mock_get_channel(self): - """Fixture to return a mock channel and message for when `get_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - mock_channel.send.return_value = mock_message - self.bot.get_channel.return_value = mock_channel - - return mock_channel, mock_message - - def mock_fetch_channel(self): - """Fixture to return a mock channel and message for when `fetch_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - self.bot.get_channel.return_value = None - mock_channel.send.return_value = mock_message - self.bot.fetch_channel.return_value = mock_channel - - return mock_channel, mock_message - - async def test_send_prompt_edits_and_returns_message(self): - """The given message should be edited to display the prompt and then should be returned.""" - msg = helpers.MockMessage() - ret_val = await self.syncer._send_prompt(msg) - - msg.edit.assert_called_once() - self.assertIn("content", msg.edit.call_args[1]) - self.assertEqual(ret_val, msg) - - async def test_send_prompt_gets_dev_core_channel(self): - """The dev-core channel should be retrieved if an extant message isn't given.""" - subtests = ( - (self.bot.get_channel, self.mock_get_channel), - (self.bot.fetch_channel, self.mock_fetch_channel), - ) - - for method, mock_ in subtests: - with self.subTest(method=method, msg=mock_.__name__): - mock_() - await self.syncer._send_prompt() - - method.assert_called_once_with(constants.Channels.dev_core) - - async def test_send_prompt_returns_none_if_channel_fetch_fails(self): - """None should be returned if there's an HTTPException when fetching the channel.""" - self.bot.get_channel.return_value = None - self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") - - ret_val = await self.syncer._send_prompt() - - self.assertIsNone(ret_val) - - async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): - """A new message mentioning core devs should be sent and returned if message isn't given.""" - for mock_ in (self.mock_get_channel, self.mock_fetch_channel): - with self.subTest(msg=mock_.__name__): - mock_channel, mock_message = mock_() - ret_val = await self.syncer._send_prompt() - - mock_channel.send.assert_called_once() - self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) - self.assertEqual(ret_val, mock_message) - - async def test_send_prompt_adds_reactions(self): - """The message should have reactions for confirmation added.""" - extant_message = helpers.MockMessage() - subtests = ( - (extant_message, lambda: (None, extant_message)), - (None, self.mock_get_channel), - (None, self.mock_fetch_channel), - ) - - for message_arg, mock_ in subtests: - subtest_msg = "Extant message" if mock_.__name__ == "" else mock_.__name__ - - with self.subTest(msg=subtest_msg): - _, mock_message = mock_() - await self.syncer._send_prompt(message_arg) - - calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] - mock_message.add_reaction.assert_has_calls(calls) - - -class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): - """Tests for waiting for a sync confirmation reaction on the prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) - - @staticmethod - def get_message_reaction(emoji): - """Fixture to return a mock message an reaction from the given `emoji`.""" - message = helpers.MockMessage() - reaction = helpers.MockReaction(emoji=emoji, message=message) - - return message, reaction - - def test_reaction_check_for_valid_emoji_and_authors(self): - """Should return True if authors are identical or are a bot and a core dev, respectively.""" - user_subtests = ( - ( - helpers.MockMember(id=77), - helpers.MockMember(id=77), - "identical users", - ), - ( - helpers.MockMember(id=77, bot=True), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "bot author and core-dev reactor", - ), - ) - - for emoji in self.syncer._REACTION_EMOJIS: - for author, user, msg in user_subtests: - with self.subTest(author=author, user=user, emoji=emoji, msg=msg): - message, reaction = self.get_message_reaction(emoji) - ret_val = self.syncer._reaction_check(author, message, reaction, user) - - self.assertTrue(ret_val) - - def test_reaction_check_for_invalid_reactions(self): - """Should return False for invalid reaction events.""" - valid_emoji = self.syncer._REACTION_EMOJIS[0] - subtests = ( - ( - helpers.MockMember(id=77), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "users are not identical", - ), - ( - helpers.MockMember(id=77, bot=True), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43), - "reactor lacks the core-dev role", - ), - ( - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - "reactor is a bot", - ), - ( - helpers.MockMember(id=77), - helpers.MockMessage(id=95), - helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), - helpers.MockMember(id=77), - "messages are not identical", - ), - ( - helpers.MockMember(id=77), - *self.get_message_reaction("InVaLiD"), - helpers.MockMember(id=77), - "emoji is invalid", - ), - ) - - for *args, msg in subtests: - kwargs = dict(zip(("author", "message", "reaction", "user"), args)) - with self.subTest(**kwargs, msg=msg): - ret_val = self.syncer._reaction_check(*args) - self.assertFalse(ret_val) - - async def test_wait_for_confirmation(self): - """The message should always be edited and only return True if the emoji is a check mark.""" - subtests = ( - (constants.Emojis.check_mark, True, None), - ("InVaLiD", False, None), - (None, False, asyncio.TimeoutError), - ) - - for emoji, ret_val, side_effect in subtests: - for bot in (True, False): - with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): - # Set up mocks - message = helpers.MockMessage() - member = helpers.MockMember(bot=bot) - - self.bot.wait_for.reset_mock() - self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) - self.bot.wait_for.side_effect = side_effect - - # Call the function - actual_return = await self.syncer._wait_for_confirmation(member, message) - - # Perform assertions - self.bot.wait_for.assert_called_once() - self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) - - message.edit.assert_called_once() - kwargs = message.edit.call_args[1] - self.assertIn("content", kwargs) - - # Core devs should only be mentioned if the author is a bot. - if bot: - self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - else: - self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - - self.assertIs(actual_return, ret_val) - - -class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for main function orchestrating the sync.""" - - def setUp(self): - self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) - self.syncer = TestSyncer(self.bot) - - async def test_sync_respects_confirmation_result(self): - """The sync should abort if confirmation fails and continue if confirmed.""" - mock_message = helpers.MockMessage() - subtests = ( - (True, mock_message), - (False, None), - ) - - for confirmed, message in subtests: - with self.subTest(confirmed=confirmed): - self.syncer._sync.reset_mock() - self.syncer._get_diff.reset_mock() - - diff = _Diff({1, 2, 3}, {4, 5}, None) - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(confirmed, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - - if confirmed: - self.syncer._sync.assert_called_once_with(diff) - else: - self.syncer._sync.assert_not_called() - - async def test_sync_diff_size(self): - """The diff size should be correctly calculated.""" - subtests = ( - (6, _Diff({1, 2}, {3, 4}, {5, 6})), - (5, _Diff({1, 2, 3}, None, {4, 5})), - (0, _Diff(None, None, None)), - (0, _Diff(set(), set(), set())), - ) - - for size, diff in subtests: - with self.subTest(size=size, diff=diff): - self.syncer._get_diff.reset_mock() - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) - - async def test_sync_message_edited(self): - """The message should be edited if one was sent, even if the sync has an API error.""" - subtests = ( - (None, None, False), - (helpers.MockMessage(), None, True), - (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), - ) - - for message, side_effect, should_edit in subtests: - with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): - self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(True, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - if should_edit: - message.edit.assert_called_once() - self.assertIn("content", message.edit.call_args[1]) - - async def test_sync_confirmation_context_redirect(self): - """If ctx is given, a new message should be sent and author should be ctx's author.""" - mock_member = helpers.MockMember() - subtests = ( - (None, self.bot.user, None), - (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), - ) - - for ctx, author, message in subtests: - with self.subTest(ctx=ctx, author=author, message=message): - if ctx is not None: - ctx.send.return_value = message - - # Make sure `_get_diff` returns a MagicMock, not an AsyncMock - self.syncer._get_diff.return_value = mock.MagicMock() - - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild, ctx) - - if ctx is not None: - ctx.send.assert_called_once() - - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_small_diff(self): - """Should always return True and the given message if the diff size is too small.""" - author = helpers.MockMember() - expected_message = helpers.MockMessage() - - for size in (3, 2): # pragma: no cover - with self.subTest(size=size): - self.syncer._send_prompt = mock.AsyncMock() - self.syncer._wait_for_confirmation = mock.AsyncMock() - - coro = self.syncer._get_confirmation_result(size, author, expected_message) - result, actual_message = await coro - - self.assertTrue(result) - self.assertEqual(actual_message, expected_message) - self.syncer._send_prompt.assert_not_called() - self.syncer._wait_for_confirmation.assert_not_called() - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_large_diff(self): - """Should return True if confirmed and False if _send_prompt fails or aborted.""" - author = helpers.MockMember() - mock_message = helpers.MockMessage() - - subtests = ( - (True, mock_message, True, "confirmed"), - (False, None, False, "_send_prompt failed"), - (False, mock_message, False, "aborted"), - ) - - for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover - with self.subTest(msg=msg): - self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) - - coro = self.syncer._get_confirmation_result(4, author) - actual_result, actual_message = await coro - - self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None - self.assertIs(actual_result, expected_result) - self.assertEqual(actual_message, expected_message) - - if expected_message: - self.syncer._wait_for_confirmation.assert_called_once_with( - author, expected_message - ) diff --git a/tests/bot/cogs/backend/sync/test_cog.py b/tests/bot/cogs/backend/sync/test_cog.py deleted file mode 100644 index e40552817..000000000 --- a/tests/bot/cogs/backend/sync/test_cog.py +++ /dev/null @@ -1,416 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot import constants -from bot.api import ResponseCodeError -from bot.cogs.backend import sync -from bot.cogs.backend.sync._cog import Sync -from bot.cogs.backend.sync._syncers import Syncer -from tests import helpers -from tests.base import CommandTestCase - - -class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): - """Tests for the sync extension.""" - - @staticmethod - def test_extension_setup(): - """The Sync cog should be added.""" - bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() - - -class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): - """Base class for Sync cog tests. Sets up patches for syncers.""" - - def setUp(self): - self.bot = helpers.MockBot() - - self.role_syncer_patcher = mock.patch( - "bot.cogs.backend.sync._syncers.RoleSyncer", - autospec=Syncer, - spec_set=True - ) - self.user_syncer_patcher = mock.patch( - "bot.cogs.backend.sync._syncers.UserSyncer", - autospec=Syncer, - spec_set=True - ) - self.RoleSyncer = self.role_syncer_patcher.start() - self.UserSyncer = self.user_syncer_patcher.start() - - self.cog = Sync(self.bot) - - def tearDown(self): - self.role_syncer_patcher.stop() - self.user_syncer_patcher.stop() - - @staticmethod - def response_error(status: int) -> ResponseCodeError: - """Fixture to return a ResponseCodeError with the given status code.""" - response = mock.MagicMock() - response.status = status - - return ResponseCodeError(response) - - -class SyncCogTests(SyncCogTestCase): - """Tests for the Sync cog.""" - - @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. - self.RoleSyncer.reset_mock() - self.UserSyncer.reset_mock() - self.bot.loop.create_task = mock.MagicMock() - - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - Sync(self.bot) - - self.RoleSyncer.assert_called_once_with(self.bot) - self.UserSyncer.assert_called_once_with(self.bot) - sync_guild.assert_called_once_with() - self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) - - async def test_sync_cog_sync_guild(self): - """Roles and users should be synced only if a guild is successfully retrieved.""" - for guild in (helpers.MockGuild(), None): - with self.subTest(guild=guild): - self.bot.reset_mock() - self.cog.role_syncer.reset_mock() - self.cog.user_syncer.reset_mock() - - self.bot.get_guild = mock.MagicMock(return_value=guild) - - await self.cog.sync_guild() - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.get_guild.assert_called_once_with(constants.Guild.id) - - if guild is None: - self.cog.role_syncer.sync.assert_not_called() - self.cog.user_syncer.sync.assert_not_called() - else: - self.cog.role_syncer.sync.assert_called_once_with(guild) - self.cog.user_syncer.sync.assert_called_once_with(guild) - - async def patch_user_helper(self, side_effect: BaseException) -> None: - """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" - self.bot.api_client.patch.reset_mock(side_effect=True) - self.bot.api_client.patch.side_effect = side_effect - - user_id, updated_information = 5, {"key": 123} - await self.cog.patch_user(user_id, updated_information) - - self.bot.api_client.patch.assert_called_once_with( - f"bot/users/{user_id}", - json=updated_information, - ) - - async def test_sync_cog_patch_user(self): - """A PATCH request should be sent and 404 errors ignored.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - await self.patch_user_helper(side_effect) - - async def test_sync_cog_patch_user_non_404(self): - """A PATCH request should be sent and the error raised if it's not a 404.""" - with self.assertRaises(ResponseCodeError): - await self.patch_user_helper(self.response_error(500)) - - -class SyncCogListenerTests(SyncCogTestCase): - """Tests for the listeners of the Sync cog.""" - - def setUp(self): - super().setUp() - self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - - self.guild_id_patcher = mock.patch("bot.cogs.backend.sync._cog.constants.Guild.id", 5) - self.guild_id = self.guild_id_patcher.start() - - self.guild = helpers.MockGuild(id=self.guild_id) - self.other_guild = helpers.MockGuild(id=0) - - def tearDown(self): - self.guild_id_patcher.stop() - - async def test_sync_cog_on_guild_role_create(self): - """A POST request should be sent with the new role's data.""" - self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - role = helpers.MockRole(**role_data, guild=self.guild) - await self.cog.on_guild_role_create(role) - - self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) - - async def test_sync_cog_on_guild_role_create_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_create(role) - self.bot.api_client.post.assert_not_awaited() - - async def test_sync_cog_on_guild_role_delete(self): - """A DELETE request should be sent.""" - self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) - - role = helpers.MockRole(id=99, guild=self.guild) - await self.cog.on_guild_role_delete(role) - - self.bot.api_client.delete.assert_called_once_with("bot/roles/99") - - async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_delete(role) - self.bot.api_client.delete.assert_not_awaited() - - async def test_sync_cog_on_guild_role_update(self): - """A PUT request should be sent if the colour, name, permissions, or position changes.""" - self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) - - role_data = { - "colour": 49, - "id": 777, - "name": "rolename", - "permissions": 8, - "position": 23, - } - subtests = ( - (True, ("colour", "name", "permissions", "position")), - (False, ("hoist", "mentionable")), - ) - - for should_put, attributes in subtests: - for attribute in attributes: - with self.subTest(should_put=should_put, changed_attribute=attribute): - self.bot.api_client.put.reset_mock() - - after_role_data = role_data.copy() - after_role_data[attribute] = 876 - - before_role = helpers.MockRole(**role_data, guild=self.guild) - after_role = helpers.MockRole(**after_role_data, guild=self.guild) - - await self.cog.on_guild_role_update(before_role, after_role) - - if should_put: - self.bot.api_client.put.assert_called_once_with( - f"bot/roles/{after_role.id}", - json=after_role_data - ) - else: - self.bot.api_client.put.assert_not_called() - - async def test_sync_cog_on_guild_role_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - role = helpers.MockRole(guild=self.other_guild) - await self.cog.on_guild_role_update(role, role) - self.bot.api_client.put.assert_not_awaited() - - async def test_sync_cog_on_member_remove(self): - """Member should be patched to set in_guild as False.""" - self.assertTrue(self.cog.on_member_remove.__cog_listener__) - - member = helpers.MockMember(guild=self.guild) - await self.cog.on_member_remove(member) - - self.cog.patch_user.assert_called_once_with( - member.id, - json={"in_guild": False} - ) - - async def test_sync_cog_on_member_remove_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_remove(member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_member_update_roles(self): - """Members should be patched if their roles have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - # Roles are intentionally unsorted. - before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] - before_member = helpers.MockMember(roles=before_roles, guild=self.guild) - after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - data = {"roles": sorted(role.id for role in after_member.roles)} - self.cog.patch_user.assert_called_once_with(after_member.id, json=data) - - async def test_sync_cog_on_member_update_other(self): - """Members should not be patched if other attributes have changed.""" - self.assertTrue(self.cog.on_member_update.__cog_listener__) - - subtests = ( - ("activities", discord.Game("Pong"), discord.Game("Frogger")), - ("nick", "old nick", "new nick"), - ("status", discord.Status.online, discord.Status.offline), - ) - - for attribute, old_value, new_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) - after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) - - await self.cog.on_member_update(before_member, after_member) - - self.cog.patch_user.assert_not_called() - - async def test_sync_cog_on_member_update_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_update(member, member) - self.cog.patch_user.assert_not_awaited() - - async def test_sync_cog_on_user_update(self): - """A user should be patched only if the name, discriminator, or avatar changes.""" - self.assertTrue(self.cog.on_user_update.__cog_listener__) - - before_data = { - "name": "old name", - "discriminator": "1234", - "bot": False, - } - - subtests = ( - (True, "name", "name", "new name", "new name"), - (True, "discriminator", "discriminator", "8765", 8765), - (False, "bot", "bot", True, True), - ) - - for should_patch, attribute, api_field, value, api_value in subtests: - with self.subTest(attribute=attribute): - self.cog.patch_user.reset_mock() - - after_data = before_data.copy() - after_data[attribute] = value - before_user = helpers.MockUser(**before_data) - after_user = helpers.MockUser(**after_data) - - await self.cog.on_user_update(before_user, after_user) - - if should_patch: - self.cog.patch_user.assert_called_once() - - # Don't care if *all* keys are present; only the changed one is required - call_args = self.cog.patch_user.call_args - self.assertEqual(call_args.args[0], after_user.id) - self.assertIn("json", call_args.kwargs) - - self.assertIn("ignore_404", call_args.kwargs) - self.assertTrue(call_args.kwargs["ignore_404"]) - - json = call_args.kwargs["json"] - self.assertIn(api_field, json) - self.assertEqual(json[api_field], api_value) - else: - self.cog.patch_user.assert_not_called() - - async def on_member_join_helper(self, side_effect: Exception) -> dict: - """ - Helper to set `side_effect` for on_member_join and assert a PUT request was sent. - - The request data for the mock member is returned. All exceptions will be re-raised. - """ - member = helpers.MockMember( - discriminator="1234", - roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], - guild=self.guild, - ) - - data = { - "discriminator": int(member.discriminator), - "id": member.id, - "in_guild": True, - "name": member.name, - "roles": sorted(role.id for role in member.roles) - } - - self.bot.api_client.put.reset_mock(side_effect=True) - self.bot.api_client.put.side_effect = side_effect - - try: - await self.cog.on_member_join(member) - except Exception: - raise - finally: - self.bot.api_client.put.assert_called_once_with( - f"bot/users/{member.id}", - json=data - ) - - return data - - async def test_sync_cog_on_member_join(self): - """Should PUT user's data or POST it if the user doesn't exist.""" - for side_effect in (None, self.response_error(404)): - with self.subTest(side_effect=side_effect): - self.bot.api_client.post.reset_mock() - data = await self.on_member_join_helper(side_effect) - - if side_effect: - self.bot.api_client.post.assert_called_once_with("bot/users", json=data) - else: - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_non_404(self): - """ResponseCodeError should be re-raised if status code isn't a 404.""" - with self.assertRaises(ResponseCodeError): - await self.on_member_join_helper(self.response_error(500)) - - self.bot.api_client.post.assert_not_called() - - async def test_sync_cog_on_member_join_ignores_guilds(self): - """Events from other guilds should be ignored.""" - member = helpers.MockMember(guild=self.other_guild) - await self.cog.on_member_join(member) - self.bot.api_client.post.assert_not_awaited() - self.bot.api_client.put.assert_not_awaited() - - -class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): - """Tests for the commands in the Sync cog.""" - - async def test_sync_roles_command(self): - """sync() should be called on the RoleSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_roles_command.callback(self.cog, ctx) - - self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_sync_users_command(self): - """sync() should be called on the UserSyncer.""" - ctx = helpers.MockContext() - await self.cog.sync_users_command.callback(self.cog, ctx) - - self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) - - async def test_commands_require_admin(self): - """The sync commands should only run if the author has the administrator permission.""" - cmds = ( - self.cog.sync_group, - self.cog.sync_roles_command, - self.cog.sync_users_command, - ) - - for cmd in cmds: - with self.subTest(cmd=cmd): - await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/backend/sync/test_roles.py b/tests/bot/cogs/backend/sync/test_roles.py deleted file mode 100644 index 99d682ede..000000000 --- a/tests/bot/cogs/backend/sync/test_roles.py +++ /dev/null @@ -1,157 +0,0 @@ -import unittest -from unittest import mock - -import discord - -from bot.cogs.backend.sync._syncers import RoleSyncer, _Diff, _Role -from tests import helpers - - -def fake_role(**kwargs): - """Fixture to return a dictionary representing a role with default values set.""" - kwargs.setdefault("id", 9) - kwargs.setdefault("name", "fake role") - kwargs.setdefault("colour", 7) - kwargs.setdefault("permissions", 0) - kwargs.setdefault("position", 55) - - return kwargs - - -class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between roles in the DB and roles in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - @staticmethod - def get_guild(*roles): - """Fixture to return a guild object with the given roles.""" - guild = helpers.MockGuild() - guild.roles = [] - - for role in roles: - mock_role = helpers.MockRole(**role) - mock_role.colour = discord.Colour(role["colour"]) - mock_role.permissions = discord.Permissions(role["permissions"]) - guild.roles.append(mock_role) - - return guild - - async def test_empty_diff_for_identical_roles(self): - """No differences should be found if the roles in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_roles(self): - """Only updated roles should be added to the 'updated' set of the diff.""" - updated_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] - guild = self.get_guild(updated_role, fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_Role(**updated_role)}, set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_roles(self): - """Only new roles should be added to the 'created' set of the diff.""" - new_role = fake_role(id=41, name="new") - - self.bot.api_client.get.return_value = [fake_role()] - guild = self.get_guild(fake_role(), new_role) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new_role)}, set(), set()) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_deleted_roles(self): - """Only deleted roles should be added to the 'deleted' set of the diff.""" - deleted_role = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [fake_role(), deleted_role] - guild = self.get_guild(fake_role()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), {_Role(**deleted_role)}) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_deleted_roles(self): - """When roles are added, updated, and removed, all of them are returned properly.""" - new = fake_role(id=41, name="new") - updated = fake_role(id=71, name="updated") - deleted = fake_role(id=61, name="deleted") - - self.bot.api_client.get.return_value = [ - fake_role(), - fake_role(id=71, name="updated name"), - deleted, - ] - guild = self.get_guild(fake_role(), new, updated) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) - - self.assertEqual(actual_diff, expected_diff) - - -class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync roles.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) - - async def test_sync_created_roles(self): - """Only POST requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(role_tuples, set(), set()) - await self.syncer._sync(diff) - - calls = [mock.call("bot/roles", json=role) for role in roles] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(roles)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_roles(self): - """Only PUT requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), role_tuples, set()) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_deleted_roles(self): - """Only DELETE requests should be made with the correct payload.""" - roles = [fake_role(id=111), fake_role(id=222)] - - role_tuples = {_Role(**role) for role in roles} - diff = _Diff(set(), set(), role_tuples) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] - self.bot.api_client.delete.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/backend/sync/test_users.py b/tests/bot/cogs/backend/sync/test_users.py deleted file mode 100644 index 51dcbe48a..000000000 --- a/tests/bot/cogs/backend/sync/test_users.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest -from unittest import mock - -from bot.cogs.backend.sync._syncers import UserSyncer, _Diff, _User -from tests import helpers - - -def fake_user(**kwargs): - """Fixture to return a dictionary representing a user with default values set.""" - kwargs.setdefault("id", 43) - kwargs.setdefault("name", "bob the test man") - kwargs.setdefault("discriminator", 1337) - kwargs.setdefault("roles", (666,)) - kwargs.setdefault("in_guild", True) - - return kwargs - - -class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): - """Tests for determining differences between users in the DB and users in the Guild cache.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - @staticmethod - def get_guild(*members): - """Fixture to return a guild object with the given members.""" - guild = helpers.MockGuild() - guild.members = [] - - for member in members: - member = member.copy() - del member["in_guild"] - - mock_member = helpers.MockMember(**member) - mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] - - guild.members.append(mock_member) - - return guild - - async def test_empty_diff_for_no_users(self): - """When no users are given, an empty diff should be returned.""" - guild = self.get_guild() - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_identical_users(self): - """No differences should be found if the users in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_updated_users(self): - """Only updated users should be added to the 'updated' set of the diff.""" - updated_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] - guild = self.get_guild(updated_user, fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**updated_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_users(self): - """Only new users should be added to the 'created' set of the diff.""" - new_user = fake_user(id=99, name="new") - - self.bot.api_client.get.return_value = [fake_user()] - guild = self.get_guild(fake_user(), new_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, set(), None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_sets_in_guild_false_for_leaving_users(self): - """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_diff_for_new_updated_and_leaving_users(self): - """When users are added, updated, and removed, all of them are returned properly.""" - new_user = fake_user(id=99, name="new") - updated_user = fake_user(id=55, name="updated") - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] - guild = self.get_guild(fake_user(), new_user, updated_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) - - self.assertEqual(actual_diff, expected_diff) - - async def test_empty_diff_for_db_users_not_in_guild(self): - """When the DB knows a user the guild doesn't, no difference is found.""" - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] - guild = self.get_guild(fake_user()) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) - - self.assertEqual(actual_diff, expected_diff) - - -class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): - """Tests for the API requests that sync users.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) - - async def test_sync_created_users(self): - """Only POST requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(user_tuples, set(), None) - await self.syncer._sync(diff) - - calls = [mock.call("bot/users", json=user) for user in users] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(users)) - - self.bot.api_client.put.assert_not_called() - self.bot.api_client.delete.assert_not_called() - - async def test_sync_updated_users(self): - """Only PUT requests should be made with the correct payload.""" - users = [fake_user(id=111), fake_user(id=222)] - - user_tuples = {_User(**user) for user in users} - diff = _Diff(set(), user_tuples, None) - await self.syncer._sync(diff) - - calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(users)) - - self.bot.api_client.post.assert_not_called() - self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/backend/test_logging.py b/tests/bot/cogs/backend/test_logging.py deleted file mode 100644 index c867773e2..000000000 --- a/tests/bot/cogs/backend/test_logging.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot import constants -from bot.cogs.backend.logging import Logging -from tests.helpers import MockBot, MockTextChannel - - -class LoggingTests(unittest.IsolatedAsyncioTestCase): - """Test cases for connected login.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Logging(self.bot) - self.dev_log = MockTextChannel(id=1234, name="dev-log") - - @patch("bot.cogs.backend.logging.DEBUG_MODE", False) - async def test_debug_mode_false(self): - """Should send connected message to dev-log.""" - self.bot.get_channel.return_value = self.dev_log - - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) - self.dev_log.send.assert_awaited_once() - - @patch("bot.cogs.backend.logging.DEBUG_MODE", True) - async def test_debug_mode_true(self): - """Should not send anything to dev-log.""" - await self.cog.startup_greeting() - self.bot.wait_until_guild_available.assert_awaited_once_with() - self.bot.get_channel.assert_not_called() diff --git a/tests/bot/cogs/filters/__init__.py b/tests/bot/cogs/filters/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/filters/test_antimalware.py b/tests/bot/cogs/filters/test_antimalware.py deleted file mode 100644 index b00211f47..000000000 --- a/tests/bot/cogs/filters/test_antimalware.py +++ /dev/null @@ -1,165 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.cogs.filters import antimalware -from bot.constants import Channels, STAFF_ROLES -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): - """Test the AntiMalware cog.""" - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = MockBot() - self.bot.filter_list_cache = { - "FILE_FORMAT.True": { - ".first": {}, - ".second": {}, - ".third": {}, - } - } - self.cog = antimalware.AntiMalware(self.bot) - self.message = MockMessage() - self.whitelist = [".first", ".second", ".third"] - - async def test_message_with_allowed_attachment(self): - """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_message_without_attachment(self): - """Messages without attachments should result in no action.""" - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_direct_message_with_attachment(self): - """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.guild = None - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_message_with_illegal_extension_gets_deleted(self): - """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_called_once() - - async def test_message_send_by_staff(self): - """A message send by a member of staff should be ignored.""" - staff_role = MockRole(id=STAFF_ROLES[0]) - self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_python_file_redirect_embed_description(self): - """A message containing a .py file should result in an embed redirecting the user to our paste site""" - attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - - self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - - async def test_txt_file_redirect_embed_description(self): - """A message containing a .txt file should result in the correct embed.""" - attachment = MockAttachment(filename="python.txt") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.TXT_EMBED_DESCRIPTION = Mock() - antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - cmd_channel = self.bot.get_channel(Channels.bot_commands) - - self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) - antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - - async def test_other_disallowed_extension_embed_description(self): - """Test the description for a non .py/.txt disallowed extension.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - meta_channel = self.bot.get_channel(Channels.meta) - - self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( - joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", - meta_channel_mention=meta_channel.mention - ) - - async def test_removing_deleted_message_logs(self): - """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - self.message.delete.assert_called_once() - - async def test_message_with_illegal_attachment_logs(self): - """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - - async def test_get_disallowed_extensions(self): - """The return value should include all non-whitelisted extensions.""" - test_values = ( - ([], []), - (self.whitelist, []), - ([".first"], []), - ([".first", ".disallowed"], [".disallowed"]), - ([".disallowed"], [".disallowed"]), - ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), - ) - - for extensions, expected_disallowed_extensions in test_values: - with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog._get_disallowed_extensions(self.message) - self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.TestCase): - """Tests setup of the `AntiMalware` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_antispam.py b/tests/bot/cogs/filters/test_antispam.py deleted file mode 100644 index 8a3d8d02e..000000000 --- a/tests/bot/cogs/filters/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.cogs.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): - """Tests validation of the antispam cog configuration.""" - - def test_default_antispam_config_is_valid(self): - """The default antispam configuration is valid.""" - validation_errors = antispam.validate_config() - self.assertEqual(validation_errors, {}) - - def test_unknown_rule_returns_error(self): - """Configuring an unknown rule returns an error.""" - self.assertEqual( - antispam.validate_config({'invalid-rule': {}}), - {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} - ) - - def test_missing_keys_returns_error(self): - """Not configuring required keys returns an error.""" - keys = (('interval', 'max'), ('max', 'interval')) - for configured_key, unconfigured_key in keys: - with self.subTest( - configured_key=configured_key, - unconfigured_key=unconfigured_key - ): - config = {'burst': {configured_key: 10}} - error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - - self.assertEqual( - antispam.validate_config(config), - {'burst': error} - ) diff --git a/tests/bot/cogs/filters/test_security.py b/tests/bot/cogs/filters/test_security.py deleted file mode 100644 index 82679f69c..000000000 --- a/tests/bot/cogs/filters/test_security.py +++ /dev/null @@ -1,54 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from discord.ext.commands import NoPrivateMessage - -from bot.cogs.filters import security -from tests.helpers import MockBot, MockContext - - -class SecurityCogTests(unittest.TestCase): - """Tests the `Security` cog.""" - - def setUp(self): - """Attach an instance of the cog to the class for tests.""" - self.bot = MockBot() - self.cog = security.Security(self.bot) - self.ctx = MockContext() - - def test_check_additions(self): - """The cog should add its checks after initialization.""" - self.bot.check.assert_any_call(self.cog.check_on_guild) - self.bot.check.assert_any_call(self.cog.check_not_bot) - - def test_check_not_bot_returns_false_for_humans(self): - """The bot check should return `True` when invoked with human authors.""" - self.ctx.author.bot = False - self.assertTrue(self.cog.check_not_bot(self.ctx)) - - def test_check_not_bot_returns_true_for_robots(self): - """The bot check should return `False` when invoked with robotic authors.""" - self.ctx.author.bot = True - self.assertFalse(self.cog.check_not_bot(self.ctx)) - - def test_check_on_guild_raises_when_outside_of_guild(self): - """When invoked outside of a guild, `check_on_guild` should cause an error.""" - self.ctx.guild = None - - with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): - self.cog.check_on_guild(self.ctx) - - def test_check_on_guild_returns_true_inside_of_guild(self): - """When invoked inside of a guild, `check_on_guild` should return `True`.""" - self.ctx.guild = "lemon's lemonade stand" - self.assertTrue(self.cog.check_on_guild(self.ctx)) - - -class SecurityCogLoadTests(unittest.TestCase): - """Tests loading the `Security` cog.""" - - def test_security_cog_load(self): - """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/filters/test_token_remover.py b/tests/bot/cogs/filters/test_token_remover.py deleted file mode 100644 index 55b284ef9..000000000 --- a/tests/bot/cogs/filters/test_token_remover.py +++ /dev/null @@ -1,310 +0,0 @@ -import unittest -from re import Match -from unittest import mock -from unittest.mock import MagicMock - -from discord import Colour, NotFound - -from bot import constants -from bot.cogs.filters import token_remover -from bot.cogs.filters.token_remover import Token, TokenRemover -from bot.cogs.moderation.modlog import ModLog -from tests.helpers import MockBot, MockMessage, autospec - - -class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): - """Tests the `TokenRemover` cog.""" - - def setUp(self): - """Adds the cog, a bot, and a message to the instance for usage in tests.""" - self.bot = MockBot() - self.cog = TokenRemover(bot=self.bot) - - self.msg = MockMessage(id=555, content="hello world") - self.msg.channel.mention = "#lemonade-stand" - self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) - self.msg.author.avatar_url_as.return_value = "picture-lemon.png" - - def test_is_valid_user_id_valid(self): - """Should consider user IDs valid if they decode entirely to ASCII digits.""" - ids = ( - "NDcyMjY1OTQzMDYyNDEzMzMy", - "NDc1MDczNjI5Mzk5NTQ3OTA0", - "NDY3MjIzMjMwNjUwNzc3NjQx", - ) - - for user_id in ids: - with self.subTest(user_id=user_id): - result = TokenRemover.is_valid_user_id(user_id) - self.assertTrue(result) - - def test_is_valid_user_id_invalid(self): - """Should consider non-digit and non-ASCII IDs invalid.""" - ids = ( - ("SGVsbG8gd29ybGQ", "non-digit ASCII"), - ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), - ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), - ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), - ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for user_id, msg in ids: - with self.subTest(msg=msg): - result = TokenRemover.is_valid_user_id(user_id) - self.assertFalse(result) - - def test_is_valid_timestamp_valid(self): - """Should consider timestamps valid if they're greater than the Discord epoch.""" - timestamps = ( - "XsyRkw", - "Xrim9Q", - "XsyR-w", - "XsySD_", - "Dn9r_A", - ) - - for timestamp in timestamps: - with self.subTest(timestamp=timestamp): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertTrue(result) - - def test_is_valid_timestamp_invalid(self): - """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" - timestamps = ( - ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), - ("ew", "123"), - ("AoIKgA", "42076800"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for timestamp, msg in timestamps: - with self.subTest(msg=msg): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertFalse(result) - - def test_mod_log_property(self): - """The `mod_log` property should ask the bot to return the `ModLog` cog.""" - self.bot.get_cog.return_value = 'lemon' - self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) - self.bot.get_cog.assert_called_once_with('ModLog') - - async def test_on_message_edit_uses_on_message(self): - """The edit listener should delegate handling of the message to the normal listener.""" - self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) - - await self.cog.on_message_edit(MockMessage(), self.msg) - self.cog.on_message.assert_awaited_once_with(self.msg) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_takes_action(self, find_token_in_message, take_action): - """Should take action if a valid token is found when a message is sent.""" - cog = TokenRemover(self.bot) - found_token = "foobar" - find_token_in_message.return_value = found_token - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_awaited_once_with(cog, self.msg, found_token) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): - """Shouldn't take action if a valid token isn't found when a message is sent.""" - cog = TokenRemover(self.bot) - find_token_in_message.return_value = False - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_not_awaited() - - @autospec(TokenRemover, "find_token_in_message") - async def test_on_message_ignores_dms_bots(self, find_token_in_message): - """Shouldn't parse a message if it is a DM or authored by a bot.""" - cog = TokenRemover(self.bot) - dm_msg = MockMessage(guild=None) - bot_msg = MockMessage(author=MagicMock(bot=True)) - - for msg in (dm_msg, bot_msg): - await cog.on_message(msg) - find_token_in_message.assert_not_called() - - @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") - def test_find_token_no_matches(self, token_re): - """None should be returned if the regex matches no tokens in a message.""" - token_re.finditer.return_value = () - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - @autospec("bot.cogs.filters.token_remover", "Token") - @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") - def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): - """The first match with a valid user ID and timestamp should be returned as a `Token`.""" - matches = [ - mock.create_autospec(Match, spec_set=True, instance=True), - mock.create_autospec(Match, spec_set=True, instance=True), - ] - tokens = [ - mock.create_autospec(Token, spec_set=True, instance=True), - mock.create_autospec(Token, spec_set=True, instance=True), - ] - - token_re.finditer.return_value = matches - token_cls.side_effect = tokens - is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. - is_valid_timestamp.return_value = True - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertEqual(tokens[1], return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - @autospec("bot.cogs.filters.token_remover", "Token") - @autospec("bot.cogs.filters.token_remover", "TOKEN_RE") - def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): - """None should be returned if no matches have valid user IDs or timestamps.""" - token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] - token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) - is_valid_id.return_value = False - is_valid_timestamp.return_value = False - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - def test_regex_invalid_tokens(self): - """Messages without anything looking like a token are not matched.""" - tokens = ( - "", - "lemon wins", - "..", - "x.y", - "x.y.", - ".y.z", - ".y.", - "..z", - "x..z", - " . . ", - "\n.\n.\n", - "hellö.world.bye", - "base64.nötbåse64.morebase64", - "19jd3J.dfkm3d.€víł§tüff", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.findall(token) - self.assertEqual(len(results), 0) - - def test_regex_valid_tokens(self): - """Messages that look like tokens should be matched.""" - # Don't worry, these tokens have been invalidated. - tokens = ( - "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", - "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", - "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", - "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.fullmatch(token) - self.assertIsNotNone(results, f"{token} was not matched by the regex") - - def test_regex_matches_multiple_valid(self): - """Should support multiple matches in the middle of a string.""" - token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" - token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" - message = f"garbage {token_1} hello {token_2} world" - - results = token_remover.TOKEN_RE.finditer(message) - results = [match[0] for match in results] - self.assertCountEqual((token_1, token_2), results) - - @autospec("bot.cogs.filters.token_remover", "LOG_MESSAGE") - def test_format_log_message(self, log_message): - """Should correctly format the log message with info from the message and token.""" - token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - log_message.format.return_value = "Howdy" - - return_value = TokenRemover.format_log_message(self.msg, token) - - self.assertEqual(return_value, log_message.format.return_value) - log_message.format.assert_called_once_with( - author=self.msg.author, - author_id=self.msg.author.id, - channel=self.msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac="x" * len(token.hmac), - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - @autospec("bot.cogs.filters.token_remover", "log") - @autospec(TokenRemover, "format_log_message") - async def test_take_action(self, format_log_message, logger, mod_log_property): - """Should delete the message and send a mod log.""" - cog = TokenRemover(self.bot) - mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) - token = mock.create_autospec(Token, spec_set=True, instance=True) - log_msg = "testing123" - - mod_log_property.return_value = mod_log - format_log_message.return_value = log_msg - - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) - ) - - format_log_message.assert_called_once_with(self.msg, token) - logger.debug.assert_called_with(log_msg) - self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - - mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) - mod_log.send_log_message.assert_called_once_with( - icon_url=constants.Icons.token_removed, - colour=Colour(constants.Colours.soft_red), - title="Token removed!", - text=log_msg, - thumbnail=self.msg.author.avatar_url_as.return_value, - channel_id=constants.Channels.mod_alerts - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - async def test_take_action_delete_failure(self, mod_log_property): - """Shouldn't send any messages if the token message can't be deleted.""" - cog = TokenRemover(self.bot) - mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) - self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) - - token = mock.create_autospec(Token, spec_set=True, instance=True) - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_not_awaited() - - -class TokenRemoverExtensionTests(unittest.TestCase): - """Tests for the token_remover extension.""" - - @autospec("bot.cogs.filters.token_remover", "TokenRemover") - def test_extension_setup(self, cog): - """The TokenRemover cog should be added.""" - bot = MockBot() - token_remover.setup(bot) - - cog.assert_called_once_with(bot) - bot.add_cog.assert_called_once() - self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/cogs/info/__init__.py b/tests/bot/cogs/info/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/info/test_information.py b/tests/bot/cogs/info/test_information.py deleted file mode 100644 index 895a8328e..000000000 --- a/tests/bot/cogs/info/test_information.py +++ /dev/null @@ -1,584 +0,0 @@ -import asyncio -import textwrap -import unittest -import unittest.mock - -import discord - -from bot import constants -from bot.cogs.info import information -from bot.utils.checks import InWhitelistCheckFailure -from tests import helpers - -COG_PATH = "bot.cogs.info.information.Information" - - -class InformationCogTests(unittest.TestCase): - """Tests the Information cog.""" - - @classmethod - def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = helpers.MockBot() - - self.cog = information.Information(self.bot) - - self.ctx = helpers.MockContext() - self.ctx.author.roles.append(self.moderator_role) - - def test_roles_command_command(self): - """Test if the `role_info` command correctly returns the `moderator_role`.""" - self.ctx.guild.roles.append(self.moderator_role) - - self.cog.roles_info.can_run = unittest.mock.AsyncMock() - self.cog.roles_info.can_run.return_value = True - - coroutine = self.cog.roles_info.callback(self.cog, self.ctx) - - self.assertIsNone(asyncio.run(coroutine)) - self.ctx.send.assert_called_once() - - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - - self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - - def test_role_info_command(self): - """Tests the `role info` command.""" - dummy_role = helpers.MockRole( - name="Dummy", - id=112233445566778899, - colour=discord.Colour.blurple(), - position=10, - members=[self.ctx.author], - permissions=discord.Permissions(0) - ) - - admin_role = helpers.MockRole( - name="Admins", - id=998877665544332211, - colour=discord.Colour.red(), - position=3, - members=[self.ctx.author], - permissions=discord.Permissions(0), - ) - - self.ctx.guild.roles.append([dummy_role, admin_role]) - - self.cog.role_info.can_run = unittest.mock.AsyncMock() - self.cog.role_info.can_run.return_value = True - - coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) - - self.assertIsNone(asyncio.run(coroutine)) - - self.assertEqual(self.ctx.send.call_count, 2) - - (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list - - dummy_embed = dummy_kwargs["embed"] - admin_embed = admin_kwargs["embed"] - - self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) - - self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) - self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") - self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") - self.assertEqual(dummy_embed.fields[3].value, "1") - self.assertEqual(dummy_embed.fields[4].value, "10") - self.assertEqual(dummy_embed.fields[5].value, "0") - - self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, discord.Colour.red()) - - @unittest.mock.patch('bot.cogs.info.information.time_since') - def test_server_info_command(self, time_since_patch): - time_since_patch.return_value = '2 days ago' - - self.ctx.guild = helpers.MockGuild( - features=('lemons', 'apples'), - region="The Moon", - roles=[self.moderator_role], - channels=[ - discord.TextChannel( - state={}, - guild=self.ctx.guild, - data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} - ), - discord.CategoryChannel( - state={}, - guild=self.ctx.guild, - data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} - ), - discord.VoiceChannel( - state={}, - guild=self.ctx.guild, - data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} - ) - ], - members=[ - *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), - *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), - *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), - *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), - ], - member_count=1_234, - icon_url='a-lemon.jpg', - ) - - coroutine = self.cog.server_info.callback(self.cog, self.ctx) - self.assertIsNone(asyncio.run(coroutine)) - - time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual( - embed.description, - textwrap.dedent( - f""" - **Server information** - Created: {time_since_patch.return_value} - Voice region: {self.ctx.guild.region} - Features: {', '.join(self.ctx.guild.features)} - - **Channel counts** - Category channels: 1 - Text channels: 1 - Voice channels: 1 - Staff channels: 0 - - **Member counts** - Members: {self.ctx.guild.member_count:,} - Staff members: 0 - Roles: {len(self.ctx.guild.roles)} - - **Member statuses** - {constants.Emojis.status_online} 2 - {constants.Emojis.status_idle} 1 - {constants.Emojis.status_dnd} 4 - {constants.Emojis.status_offline} 3 - """ - ) - ) - self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') - - -class UserInfractionHelperMethodTests(unittest.TestCase): - """Tests for the helper methods of the `!user` command.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - self.member = helpers.MockMember(id=1234) - - def test_user_command_helper_method_get_requests(self): - """The helper methods should form the correct get requests.""" - test_values = ( - { - "helper_method": self.cog.basic_user_infraction_counts, - "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.expanded_user_infraction_counts, - "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), - }, - { - "helper_method": self.cog.user_nomination_counts, - "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), - }, - ) - - for test_value in test_values: - helper_method = test_value["helper_method"] - endpoint, params = test_value["expected_args"] - - with self.subTest(method=helper_method, endpoint=endpoint, params=params): - asyncio.run(helper_method(self.member)) - self.bot.api_client.get.assert_called_once_with(endpoint, params=params) - self.bot.api_client.get.reset_mock() - - def _method_subtests(self, method, test_values, default_header): - """Helper method that runs the subtests for the different helper methods.""" - for test_value in test_values: - api_response = test_value["api response"] - expected_lines = test_value["expected_lines"] - - with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): - self.bot.api_client.get.return_value = api_response - - expected_output = "\n".join(default_header + expected_lines) - actual_output = asyncio.run(method(self.member)) - - self.assertEqual(expected_output, actual_output) - - def test_basic_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list both the total and active number of non-hidden infractions.""" - test_values = ( - # No infractions means zero counts - { - "api response": [], - "expected_lines": ["Total: 0", "Active: 0"], - }, - # Simple, single-infraction dictionaries - { - "api response": [{"type": "ban", "active": True}], - "expected_lines": ["Total: 1", "Active: 1"], - }, - { - "api response": [{"type": "ban", "active": False}], - "expected_lines": ["Total: 1", "Active: 0"], - }, - # Multiple infractions with various `active` status - { - "api response": [ - {"type": "ban", "active": True}, - {"type": "kick", "active": False}, - {"type": "ban", "active": True}, - {"type": "ban", "active": False}, - ], - "expected_lines": ["Total: 4", "Active: 2"], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) - - def test_expanded_user_infraction_counts_returns_correct_strings(self): - """The method should correctly list the total and active number of all infractions split by infraction type.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never received an infraction."], - }, - # Shows non-hidden inactive infraction as expected - { - "api response": [{"type": "kick", "active": False, "hidden": False}], - "expected_lines": ["Kicks: 1"], - }, - # Shows non-hidden active infraction as expected - { - "api response": [{"type": "mute", "active": True, "hidden": False}], - "expected_lines": ["Mutes: 1 (1 active)"], - }, - # Shows hidden inactive infraction as expected - { - "api response": [{"type": "superstar", "active": False, "hidden": True}], - "expected_lines": ["Superstars: 1"], - }, - # Shows hidden active infraction as expected - { - "api response": [{"type": "ban", "active": True, "hidden": True}], - "expected_lines": ["Bans: 1 (1 active)"], - }, - # Correctly displays tally of multiple infractions of mixed properties in alphabetical order - { - "api response": [ - {"type": "kick", "active": False, "hidden": True}, - {"type": "ban", "active": True, "hidden": True}, - {"type": "superstar", "active": True, "hidden": True}, - {"type": "mute", "active": True, "hidden": True}, - {"type": "ban", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - {"type": "note", "active": False, "hidden": True}, - {"type": "warn", "active": False, "hidden": False}, - {"type": "note", "active": False, "hidden": True}, - ], - "expected_lines": [ - "Bans: 2 (1 active)", - "Kicks: 1", - "Mutes: 1 (1 active)", - "Notes: 3", - "Superstars: 1 (1 active)", - "Warns: 1", - ], - }, - ) - - header = ["**Infractions**"] - - self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) - - def test_user_nomination_counts_returns_correct_strings(self): - """The method should list the number of active and historical nominations for the user.""" - test_values = ( - { - "api response": [], - "expected_lines": ["This user has never been nominated."], - }, - { - "api response": [{'active': True}], - "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], - }, - { - "api response": [{'active': True}, {'active': False}], - "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], - }, - { - "api response": [{'active': False}], - "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], - }, - { - "api response": [{'active': False}, {'active': False}], - "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], - }, - - ) - - header = ["**Nominations**"] - - self._method_subtests(self.cog.user_nomination_counts, test_values, header) - - -@unittest.mock.patch("bot.cogs.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) -@unittest.mock.patch("bot.cogs.info.information.constants.MODERATION_CHANNELS", new=[50]) -class UserEmbedTests(unittest.TestCase): - """Tests for the creation of the `!user` embed.""" - - def setUp(self): - """Common set-up steps done before for each test.""" - self.bot = helpers.MockBot() - self.bot.api_client.get = unittest.mock.AsyncMock() - self.cog = information.Information(self.bot) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): - """The embed should use the string representation of the user if they don't have a nick.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = None - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Mr. Hemlock") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_nick_in_title_if_available(self): - """The embed should use the nick if it's available.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - user = helpers.MockMember() - user.nick = "Cat lover" - user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_ignores_everyone_role(self): - """Created `!user` embeds should not contain mention of the @everyone-role.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) - admins_role = helpers.MockRole(name='Admins') - admins_role.colour = 100 - - # A `MockMember` has the @Everyone role by default; we add the Admins to that. - user = helpers.MockMember(roles=[admins_role], top_role=admins_role) - - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertIn("&Admins", embed.description) - self.assertNotIn("&Everyone", embed.description) - - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): - """The embed should contain expanded infractions and nomination info in mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "expanded infractions info" - nomination_counts.return_value = "nomination info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - nomination_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - expanded infractions info - - nomination info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): - """The embed should contain only basic infraction data outside of mod channels.""" - ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - infraction_counts.return_value = "basic infractions info" - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - infraction_counts.assert_called_once_with(user) - - self.assertEqual( - textwrap.dedent(f""" - **User Information** - Created: {"1 year ago"} - Profile: {user.mention} - ID: {user.id} - - **Member Information** - Joined: {"1 year ago"} - Roles: &Moderators - - basic infractions info - """).strip(), - embed.description - ) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): - """The embed should be created with the colour of the top role, if a top role is available.""" - ctx = helpers.MockContext() - - moderators_role = helpers.MockRole(name='Moderators') - moderators_role.colour = 100 - - user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): - """The embed should be created with a blurple colour if the user has no assigned roles.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - self.assertEqual(embed.colour, discord.Colour.blurple()) - - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) - def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): - """The embed thumbnail should be set to the user's avatar in `png` format.""" - ctx = helpers.MockContext() - - user = helpers.MockMember(id=217) - user.avatar_url_as.return_value = "avatar url" - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - - user.avatar_url_as.assert_called_once_with(static_format="png") - self.assertEqual(embed.thumbnail.url, "avatar url") - - -@unittest.mock.patch("bot.cogs.info.information.constants") -class UserCommandTests(unittest.TestCase): - """Tests for the `!user` command.""" - - def setUp(self): - """Set up steps executed before each test is run.""" - self.bot = helpers.MockBot() - self.cog = information.Information(self.bot) - - self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) - self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) - self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) - - self.author = helpers.MockMember(id=1, name="syntaxaire") - self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) - self.target = helpers.MockMember(id=3, name="__fluzz__") - - def test_regular_member_cannot_target_another_member(self, constants): - """A regular user should not be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.author) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") - - def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): - """A regular user should not be able to use this command outside of bot-commands.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) - - msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InWhitelistCheckFailure, msg=msg): - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): - """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): - """A user should target itself with `!user` when a `user` argument was not provided.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) - - create_embed.assert_called_once_with(ctx, self.author) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): - """Staff members should be able to bypass the bot-commands channel restriction.""" - constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot_commands = 50 - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - - create_embed.assert_called_once_with(ctx, self.moderator) - ctx.send.assert_called_once() - - @unittest.mock.patch("bot.cogs.info.information.Information.create_user_embed") - def test_moderators_can_target_another_member(self, create_embed, constants): - """A moderator should be able to use `!user` targeting another user.""" - constants.MODERATION_ROLES = [self.moderator_role.id] - constants.STAFF_ROLES = [self.moderator_role.id] - - ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) - - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) - - create_embed.assert_called_once_with(ctx, self.target) - ctx.send.assert_called_once() diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/cogs/moderation/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/moderation/infraction/__init__.py b/tests/bot/cogs/moderation/infraction/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/moderation/infraction/test_infractions.py b/tests/bot/cogs/moderation/infraction/test_infractions.py deleted file mode 100644 index 2df61d431..000000000 --- a/tests/bot/cogs/moderation/infraction/test_infractions.py +++ /dev/null @@ -1,55 +0,0 @@ -import textwrap -import unittest -from unittest.mock import AsyncMock, Mock, patch - -from bot.cogs.moderation.infraction.infractions import Infractions -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole - - -class TruncationTests(unittest.IsolatedAsyncioTestCase): - """Tests for ban and kick command reason truncation.""" - - def setUp(self): - self.bot = MockBot() - self.cog = Infractions(self.bot) - self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) - self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) - self.guild = MockGuild(id=4567) - self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - - @patch("bot.cogs.moderation.infraction._utils.get_active_infraction") - @patch("bot.cogs.moderation.infraction._utils.post_infraction") - async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): - """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = None - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.bot.get_cog.return_value = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.ctx.guild.ban = Mock() - - await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) - self.ctx.guild.ban.assert_called_once_with( - self.target, - reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), - delete_message_days=0 - ) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value - ) - - @patch("bot.cogs.moderation.infraction._utils.post_infraction") - async def test_apply_kick_reason_truncation(self, post_infraction_mock): - """Should truncate reason for `Member.kick`.""" - post_infraction_mock.return_value = {"foo": "bar"} - - self.cog.apply_infraction = AsyncMock() - self.cog.mod_log.ignore = Mock() - self.target.kick = Mock() - - await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) - self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value - ) diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py deleted file mode 100644 index 5e4d90251..000000000 --- a/tests/bot/cogs/moderation/test_incidents.py +++ /dev/null @@ -1,770 +0,0 @@ -import asyncio -import enum -import logging -import typing as t -import unittest -from unittest.mock import AsyncMock, MagicMock, call, patch - -import aiohttp -import discord - -from bot.cogs.moderation import incidents -from bot.constants import Colours -from tests.helpers import ( - MockAsyncWebhook, - MockAttachment, - MockBot, - MockMember, - MockMessage, - MockReaction, - MockRole, - MockTextChannel, - MockUser, -) - - -class MockAsyncIterable: - """ - Helper for mocking asynchronous for loops. - - It does not appear that the `unittest` library currently provides anything that would - allow us to simply mock an async iterator, such as `discord.TextChannel.history`. - - We therefore write our own helper to wrap a regular synchronous iterable, and feed - its values via `__anext__` rather than `__next__`. - - This class was written for the purposes of testing the `Incidents` cog - it may not - be generic enough to be placed in the `tests.helpers` module. - """ - - def __init__(self, messages: t.Iterable): - """Take a sync iterable to be wrapped.""" - self.iter_messages = iter(messages) - - def __aiter__(self): - """Return `self` as we provide the `__anext__` method.""" - return self - - async def __anext__(self): - """ - Feed the next item, or raise `StopAsyncIteration`. - - Since we're wrapping a sync iterator, it will communicate that it has been depleted - by raising a `StopIteration`. The `async for` construct does not expect it, and we - therefore need to substitute it for the appropriate exception type. - """ - try: - return next(self.iter_messages) - except StopIteration: - raise StopAsyncIteration - - -class MockSignal(enum.Enum): - A = "A" - B = "B" - - -mock_404 = discord.NotFound( - response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response - message="Not found", -) - - -class TestDownloadFile(unittest.IsolatedAsyncioTestCase): - """Collection of tests for the `download_file` helper function.""" - - async def test_download_file_success(self): - """If `to_file` succeeds, function returns the acquired `discord.File`.""" - file = MagicMock(discord.File, filename="bigbadlemon.jpg") - attachment = MockAttachment(to_file=AsyncMock(return_value=file)) - - acquired_file = await incidents.download_file(attachment) - self.assertIs(file, acquired_file) - - async def test_download_file_404(self): - """If `to_file` encounters a 404, function handles the exception & returns None.""" - attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) - - acquired_file = await incidents.download_file(attachment) - self.assertIsNone(acquired_file) - - async def test_download_file_fail(self): - """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" - arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") - attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) - - with self.assertLogs(logger=incidents.log, level=logging.ERROR): - acquired_file = await incidents.download_file(attachment) - - self.assertIsNone(acquired_file) - - -class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): - """Collection of tests for the `make_embed` helper function.""" - - async def test_make_embed_actioned(self): - """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) - - self.assertEqual(embed.colour.value, Colours.soft_green) - self.assertIn("Actioned", embed.footer.text) - - async def test_make_embed_not_actioned(self): - """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) - - self.assertEqual(embed.colour.value, Colours.soft_red) - self.assertIn("Rejected", embed.footer.text) - - async def test_make_embed_content(self): - """Incident content appears as embed description.""" - incident = MockMessage(content="this is an incident") - embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - - self.assertEqual(incident.content, embed.description) - - async def test_make_embed_with_attachment_succeeds(self): - """Incident's attachment is downloaded and displayed in the embed's image field.""" - file = MagicMock(discord.File, filename="bigbadjoe.jpg") - attachment = MockAttachment(filename="bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) - - # Patch `download_file` to return our `file` - with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=file)): - embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - - self.assertIs(file, returned_file) - self.assertEqual("attachment://bigbadjoe.jpg", embed.image.url) - - async def test_make_embed_with_attachment_fails(self): - """Incident's attachment fails to download, proxy url is linked instead.""" - attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) - - # Patch `download_file` to return None as if the download failed - with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=None)): - embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - - self.assertIsNone(returned_file) - - # The author name field is simply expected to have something in it, we do not assert the message - self.assertGreater(len(embed.author.name), 0) - self.assertEqual(embed.author.url, "discord.com/bigbadjoe.jpg") # However, it should link the exact url - - -@patch("bot.constants.Channels.incidents", 123) -class TestIsIncident(unittest.TestCase): - """ - Collection of tests for the `is_incident` helper function. - - In `setUp`, we will create a mock message which should qualify as an incident. Each - test case will then mutate this instance to make it **not** qualify, in various ways. - - Notice that we patch the #incidents channel id globally for this class. - """ - - def setUp(self) -> None: - """Prepare a mock message which should qualify as an incident.""" - self.incident = MockMessage( - channel=MockTextChannel(id=123), - content="this is an incident", - author=MockUser(bot=False), - pinned=False, - ) - - def test_is_incident_true(self): - """Message qualifies as an incident if unchanged.""" - self.assertTrue(incidents.is_incident(self.incident)) - - def check_false(self): - """Assert that `self.incident` does **not** qualify as an incident.""" - self.assertFalse(incidents.is_incident(self.incident)) - - def test_is_incident_false_channel(self): - """Message doesn't qualify if sent outside of #incidents.""" - self.incident.channel = MockTextChannel(id=456) - self.check_false() - - def test_is_incident_false_content(self): - """Message doesn't qualify if content begins with hash symbol.""" - self.incident.content = "# this is a comment message" - self.check_false() - - def test_is_incident_false_author(self): - """Message doesn't qualify if author is a bot.""" - self.incident.author = MockUser(bot=True) - self.check_false() - - def test_is_incident_false_pinned(self): - """Message doesn't qualify if it is pinned.""" - self.incident.pinned = True - self.check_false() - - -class TestOwnReactions(unittest.TestCase): - """Assertions for the `own_reactions` function.""" - - def test_own_reactions(self): - """Only bot's own emoji are extracted from the input incident.""" - reactions = ( - MockReaction(emoji="A", me=True), - MockReaction(emoji="B", me=True), - MockReaction(emoji="C", me=False), - ) - message = MockMessage(reactions=reactions) - self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) - - -@patch("bot.cogs.moderation.incidents.ALL_SIGNALS", {"A", "B"}) -class TestHasSignals(unittest.TestCase): - """ - Assertions for the `has_signals` function. - - We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` - as appropriate. - """ - - def test_has_signals_true(self): - """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" - message = MockMessage() - own_reactions = MagicMock(return_value={"A", "B"}) - - with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): - self.assertTrue(incidents.has_signals(message)) - - def test_has_signals_false(self): - """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" - message = MockMessage() - own_reactions = MagicMock(return_value={"A", "C"}) - - with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): - self.assertFalse(incidents.has_signals(message)) - - -@patch("bot.cogs.moderation.incidents.Signal", MockSignal) -class TestAddSignals(unittest.IsolatedAsyncioTestCase): - """ - Assertions for the `add_signals` coroutine. - - These are all fairly similar and could go into a single test function, but I found the - patching & sub-testing fairly awkward in that case and decided to split them up - to avoid unnecessary syntax noise. - """ - - def setUp(self): - """Prepare a mock incident message for tests to use.""" - self.incident = MockMessage() - - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value=set())) - async def test_add_signals_missing(self): - """All emoji are added when none are present.""" - await incidents.add_signals(self.incident) - self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) - - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) - async def test_add_signals_partial(self): - """Only missing emoji are added when some are present.""" - await incidents.add_signals(self.incident) - self.incident.add_reaction.assert_has_calls([call("B")]) - - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) - async def test_add_signals_present(self): - """No emoji are added when all are present.""" - await incidents.add_signals(self.incident) - self.incident.add_reaction.assert_not_called() - - -class TestIncidents(unittest.IsolatedAsyncioTestCase): - """ - Tests for bound methods of the `Incidents` cog. - - Use this as a base class for `Incidents` tests - it will prepare a fresh instance - for each test function, but not make any assertions on its own. Tests can mutate - the instance as they wish. - """ - - def setUp(self): - """ - Prepare a fresh `Incidents` instance for each test. - - Note that this will not schedule `crawl_incidents` in the background, as everything - is being mocked. The `crawl_task` attribute will end up being None. - """ - self.cog_instance = incidents.Incidents(MockBot()) - - -@patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test -class TestCrawlIncidents(TestIncidents): - """ - Tests for the `Incidents.crawl_incidents` coroutine. - - Apart from `test_crawl_incidents_waits_until_cache_ready`, all tests in this class - will patch the return values of `is_incident` and `has_signal` and then observe - whether the `AsyncMock` for `add_signals` was awaited or not. - - The `add_signals` mock is added by each test separately to ensure it is clean (has not - been awaited by another test yet). The mock can be reset, but this appears to be the - cleaner way. - - For each test, we inject a mock channel with a history of 1 message only (see: `setUp`). - """ - - def setUp(self): - """For each test, ensure `bot.get_channel` returns a channel with 1 arbitrary message.""" - super().setUp() # First ensure we get `cog_instance` from parent - - incidents_history = MagicMock(return_value=MockAsyncIterable([MockMessage()])) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(history=incidents_history)) - - async def test_crawl_incidents_waits_until_cache_ready(self): - """ - The coroutine will await the `wait_until_guild_available` event. - - Since this task is schedule in the `__init__`, it is critical that it waits for the - cache to be ready, so that it can safely get the #incidents channel. - """ - await self.cog_instance.crawl_incidents() - self.cog_instance.bot.wait_until_guild_available.assert_awaited() - - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) - async def test_crawl_incidents_noop_if_is_not_incident(self): - """Signals are not added for a non-incident message.""" - await self.cog_instance.crawl_incidents() - incidents.add_signals.assert_not_awaited() - - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals - async def test_crawl_incidents_noop_if_message_already_has_signals(self): - """Signals are not added for messages which already have them.""" - await self.cog_instance.crawl_incidents() - incidents.add_signals.assert_not_awaited() - - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals - async def test_crawl_incidents_add_signals_called(self): - """Message has signals added as it does not have them yet and qualifies as an incident.""" - await self.cog_instance.crawl_incidents() - incidents.add_signals.assert_awaited_once() - - -class TestArchive(TestIncidents): - """Tests for the `Incidents.archive` coroutine.""" - - async def test_archive_webhook_not_found(self): - """ - Method recovers and returns False when the webhook is not found. - - Implicitly, this also tests that the error is handled internally and doesn't - propagate out of the method, which is just as important. - """ - self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) - self.assertFalse( - await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) - ) - - async def test_archive_relays_incident(self): - """ - If webhook is found, method relays `incident` properly. - - This test will assert that the fetched webhook's `send` method is fed the correct arguments, - and that the `archive` method returns True. - """ - webhook = MockAsyncWebhook() - self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook - - # Define our own `incident` to be archived - incident = MockMessage( - content="this is an incident", - author=MockUser(name="author_name", avatar_url="author_avatar"), - id=123, - ) - built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this - - with patch("bot.cogs.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): - archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) - - # Now we check that the webhook was given the correct args, and that `archive` returned True - webhook.send.assert_called_once_with( - embed=built_embed, - username="author_name", - avatar_url="author_avatar", - file=None, - ) - self.assertTrue(archive_return) - - async def test_archive_clyde_username(self): - """ - The archive webhook username is cleansed using `sub_clyde`. - - Discord will reject any webhook with "clyde" in the username field, as it impersonates - the official Clyde bot. Since we do not control what the username will be (the incident - author name is used), we must ensure the name is cleansed, otherwise the relay may fail. - - This test assumes the username is passed as a kwarg. If this test fails, please review - whether the passed argument is being retrieved correctly. - """ - webhook = MockAsyncWebhook() - self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) - await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) - - self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) - - -class TestMakeConfirmationTask(TestIncidents): - """ - Tests for the `Incidents.make_confirmation_task` method. - - Writing tests for this method is difficult, as it mostly just delegates the provided - information elsewhere. There is very little internal logic. Whether our approach - works conceptually is difficult to prove using unit tests. - """ - - def test_make_confirmation_task_check(self): - """ - The internal check will recognize the passed incident. - - This is a little tricky - we first pass a message with a specific `id` in, and then - retrieve the built check from the `call_args` of the `wait_for` method. This relies - on the check being passed as a kwarg. - - Once the check is retrieved, we assert that it gives True for our incident's `id`, - and False for any other. - - If this function begins to fail, first check that `created_check` is being retrieved - correctly. It should be the function that is built locally in the tested method. - """ - self.cog_instance.make_confirmation_task(MockMessage(id=123)) - - self.cog_instance.bot.wait_for.assert_called_once() - created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"] - - # The `message_id` matches the `id` of our incident - self.assertTrue(created_check(payload=MagicMock(message_id=123))) - - # This `message_id` does not match - self.assertFalse(created_check(payload=MagicMock(message_id=0))) - - -@patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2}) -@patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable -class TestProcessEvent(TestIncidents): - """Tests for the `Incidents.process_event` coroutine.""" - - async def test_process_event_bad_role(self): - """The reaction is removed when the author lacks all allowed roles.""" - incident = MockMessage() - member = MockMember(roles=[MockRole(id=0)]) # Must have role 1 or 2 - - await self.cog_instance.process_event("reaction", incident, member) - incident.remove_reaction.assert_called_once_with("reaction", member) - - async def test_process_event_bad_emoji(self): - """ - The reaction is removed when an invalid emoji is used. - - This requires that we pass in a `member` with valid roles, as we need the role check - to succeed. - """ - incident = MockMessage() - member = MockMember(roles=[MockRole(id=1)]) # Member has allowed role - - await self.cog_instance.process_event("invalid_signal", incident, member) - incident.remove_reaction.assert_called_once_with("invalid_signal", member) - - async def test_process_event_no_archive_on_investigating(self): - """Message is not archived on `Signal.INVESTIGATING`.""" - with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: - await self.cog_instance.process_event( - reaction=incidents.Signal.INVESTIGATING.value, - incident=MockMessage(), - member=MockMember(roles=[MockRole(id=1)]), - ) - - mocked_archive.assert_not_called() - - async def test_process_event_no_delete_if_archive_fails(self): - """ - Original message is not deleted when `Incidents.archive` returns False. - - This is the way of signaling that the relay failed, and we should not remove the original, - as that would result in losing the incident record. - """ - incident = MockMessage() - - with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): - await self.cog_instance.process_event( - reaction=incidents.Signal.ACTIONED.value, - incident=incident, - member=MockMember(roles=[MockRole(id=1)]) - ) - - incident.delete.assert_not_called() - - async def test_process_event_confirmation_task_is_awaited(self): - """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" - mock_task = AsyncMock() - - with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): - await self.cog_instance.process_event( - reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(), - member=MockMember(roles=[MockRole(id=1)]) - ) - - mock_task.assert_awaited() - - async def test_process_event_confirmation_task_timeout_is_handled(self): - """ - Confirmation task `asyncio.TimeoutError` is handled gracefully. - - We have `make_confirmation_task` return a mock with a side effect, and then catch the - exception should it propagate out of `process_event`. This is so that we can then manually - fail the test with a more informative message than just the plain traceback. - """ - mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) - - try: - with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): - await self.cog_instance.process_event( - reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(), - member=MockMember(roles=[MockRole(id=1)]) - ) - except asyncio.TimeoutError: - self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!") - - -class TestResolveMessage(TestIncidents): - """Tests for the `Incidents.resolve_message` coroutine.""" - - async def test_resolve_message_pass_message_id(self): - """Method will call `_get_message` with the passed `message_id`.""" - await self.cog_instance.resolve_message(123) - self.cog_instance.bot._connection._get_message.assert_called_once_with(123) - - async def test_resolve_message_in_cache(self): - """ - No API call is made if the queried message exists in the cache. - - We mock the `_get_message` return value regardless of input. Whether it finds the message - internally is considered d.py's responsibility, not ours. - """ - cached_message = MockMessage(id=123) - self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message) - - return_value = await self.cog_instance.resolve_message(123) - - self.assertIs(return_value, cached_message) - self.cog_instance.bot.get_channel.assert_not_called() # The `fetch_message` line was never hit - - async def test_resolve_message_not_in_cache(self): - """ - The message is retrieved from the API if it isn't cached. - - This is desired behaviour for messages which exist, but were sent before the bot's - current session. - """ - self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - - # API returns our message - uncached_message = MockMessage() - fetch_message = AsyncMock(return_value=uncached_message) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) - - retrieved_message = await self.cog_instance.resolve_message(123) - self.assertIs(retrieved_message, uncached_message) - - async def test_resolve_message_doesnt_exist(self): - """ - If the API returns a 404, the function handles it gracefully and returns None. - - This is an edge-case happening with racing events - event A will relay the message - to the archive and delete the original. Once event B acquires the `event_lock`, - it will not find the message in the cache, and will ask the API. - """ - self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - - fetch_message = AsyncMock(side_effect=mock_404) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) - - self.assertIsNone(await self.cog_instance.resolve_message(123)) - - async def test_resolve_message_fetch_fails(self): - """ - Non-404 errors are handled, logged & None is returned. - - In contrast with a 404, this should make an error-level log. We assert that at least - one such log was made - we do not make any assertions about the log's message. - """ - self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - - arbitrary_error = discord.HTTPException( - response=MagicMock(aiohttp.ClientResponse), - message="Arbitrary error", - ) - fetch_message = AsyncMock(side_effect=arbitrary_error) - self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) - - with self.assertLogs(logger=incidents.log, level=logging.ERROR): - self.assertIsNone(await self.cog_instance.resolve_message(123)) - - -@patch("bot.constants.Channels.incidents", 123) -class TestOnRawReactionAdd(TestIncidents): - """ - Tests for the `Incidents.on_raw_reaction_add` listener. - - Writing tests for this listener comes with additional complexity due to the listener - awaiting the `crawl_task` task. See `asyncSetUp` for further details, which attempts - to make unit testing this function possible. - """ - - def setUp(self): - """ - Prepare & assign `payload` attribute. - - This attribute represents an *ideal* payload which will not be rejected by the - listener. As each test will receive a fresh instance, it can be mutated to - observe how the listener's behaviour changes with different attributes on - the passed payload. - """ - super().setUp() # Ensure `cog_instance` is assigned - - self.payload = MagicMock( - discord.RawReactionActionEvent, - channel_id=123, # Patched at class level - message_id=456, - member=MockMember(bot=False), - emoji="reaction", - ) - - async def asyncSetUp(self): # noqa: N802 - """ - Prepare an empty task and assign it as `crawl_task`. - - It appears that the `unittest` framework does not provide anything for mocking - asyncio tasks. An `AsyncMock` instance can be called and then awaited, however, - it does not provide the `done` method or any other parts of the `asyncio.Task` - interface. - - Although we do not need to make any assertions about the task itself while - testing the listener, the code will still await it and call the `done` method, - and so we must inject something that will not fail on either action. - - Note that this is done in an `asyncSetUp`, which runs after `setUp`. - The justification is that creating an actual task requires the event - loop to be ready, which is not the case in the `setUp`. - """ - mock_task = asyncio.create_task(AsyncMock()()) # Mock async func, then a coro - self.cog_instance.crawl_task = mock_task - - async def test_on_raw_reaction_add_wrong_channel(self): - """ - Events outside of #incidents will be ignored. - - We check this by asserting that `resolve_message` was never queried. - """ - self.payload.channel_id = 0 - self.cog_instance.resolve_message = AsyncMock() - - await self.cog_instance.on_raw_reaction_add(self.payload) - self.cog_instance.resolve_message.assert_not_called() - - async def test_on_raw_reaction_add_user_is_bot(self): - """ - Events dispatched by bot accounts will be ignored. - - We check this by asserting that `resolve_message` was never queried. - """ - self.payload.member = MockMember(bot=True) - self.cog_instance.resolve_message = AsyncMock() - - await self.cog_instance.on_raw_reaction_add(self.payload) - self.cog_instance.resolve_message.assert_not_called() - - async def test_on_raw_reaction_add_message_doesnt_exist(self): - """ - Listener gracefully handles the case where `resolve_message` gives None. - - We check this by asserting that `process_event` was never called. - """ - self.cog_instance.process_event = AsyncMock() - self.cog_instance.resolve_message = AsyncMock(return_value=None) - - await self.cog_instance.on_raw_reaction_add(self.payload) - self.cog_instance.process_event.assert_not_called() - - async def test_on_raw_reaction_add_message_is_not_an_incident(self): - """ - The event won't be processed if the related message is not an incident. - - This is an edge-case that can happen if someone manually leaves a reaction - on a pinned message, or a comment. - - We check this by asserting that `process_event` was never called. - """ - self.cog_instance.process_event = AsyncMock() - self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) - - with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)): - await self.cog_instance.on_raw_reaction_add(self.payload) - - self.cog_instance.process_event.assert_not_called() - - async def test_on_raw_reaction_add_valid_event_is_processed(self): - """ - If the reaction event is valid, it is passed to `process_event`. - - This is the case when everything goes right: - * The reaction was placed in #incidents, and not by a bot - * The message was found successfully - * The message qualifies as an incident - - Additionally, we check that all arguments were passed as expected. - """ - incident = MockMessage(id=1) - - self.cog_instance.process_event = AsyncMock() - self.cog_instance.resolve_message = AsyncMock(return_value=incident) - - with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)): - await self.cog_instance.on_raw_reaction_add(self.payload) - - self.cog_instance.process_event.assert_called_with( - "reaction", # Defined in `self.payload` - incident, - self.payload.member, - ) - - -class TestOnMessage(TestIncidents): - """ - Tests for the `Incidents.on_message` listener. - - Notice the decorators mocking the `is_incident` return value. The `is_incidents` - function is tested in `TestIsIncident` - here we do not worry about it. - """ - - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) - async def test_on_message_incident(self): - """Messages qualifying as incidents are passed to `add_signals`.""" - incident = MockMessage() - - with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: - await self.cog_instance.on_message(incident) - - mock_add_signals.assert_called_once_with(incident) - - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) - async def test_on_message_non_incident(self): - """Messages not qualifying as incidents are ignored.""" - with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: - await self.cog_instance.on_message(MockMessage()) - - mock_add_signals.assert_not_called() diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py deleted file mode 100644 index f2809f40a..000000000 --- a/tests/bot/cogs/moderation/test_modlog.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - -import discord - -from bot.cogs.moderation.modlog import ModLog -from tests.helpers import MockBot, MockTextChannel - - -class ModLogTests(unittest.IsolatedAsyncioTestCase): - """Tests for moderation logs.""" - - def setUp(self): - self.bot = MockBot() - self.cog = ModLog(self.bot) - self.channel = MockTextChannel() - - async def test_log_entry_description_truncation(self): - """Test that embed description for ModLog entry is truncated.""" - self.bot.get_channel.return_value = self.channel - await self.cog.send_log_message( - icon_url="foo", - colour=discord.Colour.blue(), - title="bar", - text="foo bar" * 3000 - ) - embed = self.channel.send.call_args[1]["embed"] - self.assertEqual( - embed.description, ("foo bar" * 3000)[:2045] + "..." - ) diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py deleted file mode 100644 index ab3d0742a..000000000 --- a/tests/bot/cogs/moderation/test_silence.py +++ /dev/null @@ -1,261 +0,0 @@ -import unittest -from unittest import mock -from unittest.mock import MagicMock, Mock - -from discord import PermissionOverwrite - -from bot.cogs.moderation.silence import Silence, SilenceNotifier -from bot.constants import Channels, Emojis, Guild, Roles -from tests.helpers import MockBot, MockContext, MockTextChannel - - -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - self.alert_channel = MockTextChannel() - self.notifier = SilenceNotifier(self.alert_channel) - self.notifier.stop = self.notifier_stop_mock = Mock() - self.notifier.start = self.notifier_start_mock = Mock() - - def test_add_channel_adds_channel(self): - """Channel in FirstHash with current loop is added to internal set.""" - channel = Mock() - with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: - self.notifier.add_channel(channel) - silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) - - def test_add_channel_starts_loop(self): - """Loop is started if `_silenced_channels` was empty.""" - self.notifier.add_channel(Mock()) - self.notifier_start_mock.assert_called_once() - - def test_add_channel_skips_start_with_channels(self): - """Loop start is not called when `_silenced_channels` is not empty.""" - with mock.patch.object(self.notifier, "_silenced_channels"): - self.notifier.add_channel(Mock()) - self.notifier_start_mock.assert_not_called() - - def test_remove_channel_removes_channel(self): - """Channel in FirstHash is removed from `_silenced_channels`.""" - channel = Mock() - with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: - self.notifier.remove_channel(channel) - silenced_channels.__delitem__.assert_called_with(channel) - - def test_remove_channel_stops_loop(self): - """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" - with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): - self.notifier.remove_channel(Mock()) - self.notifier_stop_mock.assert_called_once() - - def test_remove_channel_skips_stop_with_channels(self): - """Notifier loop is not stopped if `_silenced_channels` is not empty after remove.""" - self.notifier.remove_channel(Mock()) - self.notifier_stop_mock.assert_not_called() - - async def test_notifier_private_sends_alert(self): - """Alert is sent on 15 min intervals.""" - test_cases = (900, 1800, 2700) - for current_loop in test_cases: - with self.subTest(current_loop=current_loop): - with mock.patch.object(self.notifier, "_current_loop", new=current_loop): - await self.notifier._notifier() - self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") - self.alert_channel.send.reset_mock() - - async def test_notifier_skips_alert(self): - """Alert is skipped on first loop or not an increment of 900.""" - test_cases = (0, 15, 5000) - for current_loop in test_cases: - with self.subTest(current_loop=current_loop): - with mock.patch.object(self.notifier, "_current_loop", new=current_loop): - await self.notifier._notifier() - self.alert_channel.send.assert_not_called() - - -class SilenceTests(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - self.bot = MockBot() - self.cog = Silence(self.bot) - self.ctx = MockContext() - self.cog._verified_role = None - # Set event so command callbacks can continue. - self.cog._get_instance_vars_event.set() - - async def test_instance_vars_got_guild(self): - """Bot got guild after it became available.""" - await self.cog._get_instance_vars() - self.bot.wait_until_guild_available.assert_called_once() - self.bot.get_guild.assert_called_once_with(Guild.id) - - async def test_instance_vars_got_role(self): - """Got `Roles.verified` role from guild.""" - await self.cog._get_instance_vars() - guild = self.bot.get_guild() - guild.get_role.assert_called_once_with(Roles.verified) - - async def test_instance_vars_got_channels(self): - """Got channels from bot.""" - await self.cog._get_instance_vars() - self.bot.get_channel.called_once_with(Channels.mod_alerts) - self.bot.get_channel.called_once_with(Channels.mod_log) - - @mock.patch("bot.cogs.moderation.silence.SilenceNotifier") - async def test_instance_vars_got_notifier(self, notifier): - """Notifier was started with channel.""" - mod_log = MockTextChannel() - self.bot.get_channel.side_effect = (None, mod_log) - await self.cog._get_instance_vars() - notifier.assert_called_once_with(mod_log) - self.bot.get_channel.side_effect = None - - async def test_silence_sent_correct_discord_message(self): - """Check if proper message was sent when called with duration in channel with previous state.""" - test_cases = ( - (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), - (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), - (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), - ) - for duration, result_message, _silence_patch_return in test_cases: - with self.subTest( - silence_duration=duration, - result_message=result_message, - starting_unsilenced_state=_silence_patch_return - ): - with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): - await self.cog.silence.callback(self.cog, self.ctx, duration) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_unsilence_sent_correct_discord_message(self): - """Check if proper message was sent when unsilencing channel.""" - test_cases = ( - (True, f"{Emojis.check_mark} unsilenced current channel."), - (False, f"{Emojis.cross_mark} current channel was not silenced.") - ) - for _unsilence_patch_return, result_message in test_cases: - with self.subTest( - starting_silenced_state=_unsilence_patch_return, - result_message=result_message - ): - with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): - await self.cog.unsilence.callback(self.cog, self.ctx) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_silence_private_for_false(self): - """Permissions are not set and `False` is returned in an already silenced channel.""" - perm_overwrite = Mock(send_messages=False) - channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - - self.assertFalse(await self.cog._silence(channel, True, None)) - channel.set_permissions.assert_not_called() - - async def test_silence_private_silenced_channel(self): - """Channel had `send_message` permissions revoked.""" - channel = MockTextChannel() - self.assertTrue(await self.cog._silence(channel, False, None)) - channel.set_permissions.assert_called_once() - self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) - - async def test_silence_private_preserves_permissions(self): - """Previous permissions were preserved when channel was silenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite() - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._silence(channel, False, None) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - async def test_silence_private_notifier(self): - """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" - channel = MockTextChannel() - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=True): - await self.cog._silence(channel, True, None) - self.cog.notifier.add_channel.assert_called_once() - - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=False): - await self.cog._silence(channel, False, None) - self.cog.notifier.add_channel.assert_not_called() - - async def test_silence_private_added_muted_channel(self): - """Channel was added to `muted_channels` on silence.""" - channel = MockTextChannel() - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._silence(channel, False, None) - muted_channels.add.assert_called_once_with(channel) - - async def test_unsilence_private_for_false(self): - """Permissions are not set and `False` is returned in an unsilenced channel.""" - channel = Mock() - self.assertFalse(await self.cog._unsilence(channel)) - channel.set_permissions.assert_not_called() - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_unsilenced_channel(self, _): - """Channel had `send_message` permissions restored""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - self.assertTrue(await self.cog._unsilence(channel)) - channel.set_permissions.assert_called_once() - self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_notifier(self, notifier): - """Channel was removed from `notifier` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - await self.cog._unsilence(channel) - notifier.remove_channel.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_muted_channel(self, _): - """Channel was removed from `muted_channels` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._unsilence(channel) - muted_channels.discard.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_preserves_permissions(self, _): - """Previous permissions were preserved when channel was unsilenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite(send_messages=False) - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._unsilence(channel) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - @mock.patch("bot.cogs.moderation.silence.asyncio") - @mock.patch.object(Silence, "_mod_alerts_channel", create=True) - def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): - """Task for sending an alert was created with present `muted_channels`.""" - with mock.patch.object(self.cog, "muted_channels"): - self.cog.cog_unload() - alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") - asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - - @mock.patch("bot.cogs.moderation.silence.asyncio") - def test_cog_unload_skips_task_start(self, asyncio_mock): - """No task created with no channels.""" - self.cog.cog_unload() - asyncio_mock.create_task.assert_not_called() - - @mock.patch("bot.cogs.moderation.silence.with_role_check") - @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/moderation/test_slowmode.py b/tests/bot/cogs/moderation/test_slowmode.py deleted file mode 100644 index f442814c8..000000000 --- a/tests/bot/cogs/moderation/test_slowmode.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest -from unittest import mock - -from dateutil.relativedelta import relativedelta - -from bot.cogs.moderation.slowmode import Slowmode -from bot.constants import Emojis -from tests.helpers import MockBot, MockContext, MockTextChannel - - -class SlowmodeTests(unittest.IsolatedAsyncioTestCase): - - def setUp(self) -> None: - self.bot = MockBot() - self.cog = Slowmode(self.bot) - self.ctx = MockContext() - - async def test_get_slowmode_no_channel(self) -> None: - """Get slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) - - await self.cog.get_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") - - async def test_get_slowmode_with_channel(self) -> None: - """Get slowmode with a given channel.""" - text_channel = MockTextChannel(name='python-language', slowmode_delay=2) - - await self.cog.get_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') - - async def test_set_slowmode_no_channel(self) -> None: - """Set slowmode without a given channel.""" - test_cases = ( - ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), - ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), - ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - self.ctx.channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) - - if edited: - self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - self.ctx.channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_set_slowmode_with_channel(self) -> None: - """Set slowmode with a given channel.""" - test_cases = ( - ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), - ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), - ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') - ) - - for channel_name, seconds, edited, result_msg in test_cases: - with self.subTest( - channel_mention=channel_name, - seconds=seconds, - edited=edited, - result_msg=result_msg - ): - text_channel = MockTextChannel(name=channel_name) - - await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) - - if edited: - text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) - else: - text_channel.edit.assert_not_called() - - self.ctx.send.assert_called_once_with(result_msg) - - self.ctx.reset_mock() - - async def test_reset_slowmode_no_channel(self) -> None: - """Reset slowmode without a given channel.""" - self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) - - await self.cog.reset_slowmode(self.cog, self.ctx, None) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' - ) - - async def test_reset_slowmode_with_channel(self) -> None: - """Reset slowmode with a given channel.""" - text_channel = MockTextChannel(name='meta', slowmode_delay=1) - - await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) - self.ctx.send.assert_called_once_with( - f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' - ) - - @mock.patch("bot.cogs.moderation.slowmode.with_role_check") - @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py deleted file mode 100644 index fdda59a8f..000000000 --- a/tests/bot/cogs/test_cogs.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Test suite for general tests which apply to all cogs.""" - -import importlib -import pkgutil -import typing as t -import unittest -from collections import defaultdict -from types import ModuleType -from unittest import mock - -from discord.ext import commands - -from bot import cogs - - -class CommandNameTests(unittest.TestCase): - """Tests for shadowing command names and aliases.""" - - @staticmethod - def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: - """An iterator that recursively walks through `cog`'s commands and subcommands.""" - # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. - for command in cog.__cog_commands__: - if command.parent is None: - yield command - if isinstance(command, commands.GroupMixin): - # Annoyingly it returns duplicates for each alias so use a set to fix that - yield from set(command.walk_commands()) - - @staticmethod - def walk_modules() -> t.Iterator[ModuleType]: - """Yield imported modules from the bot.cogs subpackage.""" - def on_error(name: str) -> t.NoReturn: - raise ImportError(name=name) # pragma: no cover - - # The mock prevents asyncio.get_event_loop() from being called. - with mock.patch("discord.ext.tasks.loop"): - for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): - if not module.ispkg: - yield importlib.import_module(module.name) - - @staticmethod - def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: - """Yield all cogs defined in an extension.""" - for obj in module.__dict__.values(): - # Check if it's a class type cause otherwise issubclass() may raise a TypeError. - is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) - if is_cog and obj.__module__ == module.__name__: - yield obj - - @staticmethod - def get_qualified_names(command: commands.Command) -> t.List[str]: - """Return a list of all qualified names, including aliases, for the `command`.""" - names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] - names.append(command.qualified_name) - - return names - - def get_all_commands(self) -> t.Iterator[commands.Command]: - """Yield all commands for all cogs in all extensions.""" - for module in self.walk_modules(): - for cog in self.walk_cogs(module): - for cmd in self.walk_commands(cog): - yield cmd - - def test_names_dont_shadow(self): - """Names and aliases of commands should be unique.""" - all_names = defaultdict(list) - for cmd in self.get_all_commands(): - func_name = f"{cmd.module}.{cmd.callback.__qualname__}" - - for name in self.get_qualified_names(cmd): - with self.subTest(cmd=func_name, name=name): - if name in all_names: # pragma: no cover - conflicts = ", ".join(all_names.get(name, "")) - self.fail( - f"Name '{name}' of the command {func_name} conflicts with {conflicts}." - ) - - all_names[name].append(func_name) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py deleted file mode 100644 index cfe10aebf..000000000 --- a/tests/bot/cogs/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.cogs import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.cogs.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger('bot.cogs.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/utils/__init__.py b/tests/bot/cogs/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/cogs/utils/test_jams.py b/tests/bot/cogs/utils/test_jams.py deleted file mode 100644 index 299f436ba..000000000 --- a/tests/bot/cogs/utils/test_jams.py +++ /dev/null @@ -1,173 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, create_autospec - -from discord import CategoryChannel - -from bot.cogs.utils import jams -from bot.constants import Roles -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel - - -def get_mock_category(channel_count: int, name: str) -> CategoryChannel: - """Return a mocked code jam category.""" - category = create_autospec(CategoryChannel, spec_set=True, instance=True) - category.name = name - category.channels = [MockTextChannel() for _ in range(channel_count)] - - return category - - -class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): - """Tests for `createteam` command.""" - - def setUp(self): - self.bot = MockBot() - self.admin_role = MockRole(name="Admins", id=Roles.admins) - self.command_user = MockMember([self.admin_role]) - self.guild = MockGuild([self.admin_role]) - self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) - self.cog = jams.CodeJams(self.bot) - - async def test_too_small_amount_of_team_members_passed(self): - """Should `ctx.send` and exit early when too small amount of members.""" - for case in (1, 2): - with self.subTest(amount_of_members=case): - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - self.ctx.reset_mock() - members = (MockMember() for _ in range(case)) - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_duplicate_members_provided(self): - """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - member = MockMember() - await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) - - self.ctx.send.assert_awaited_once() - self.cog.create_channels.assert_not_awaited() - self.cog.add_roles.assert_not_awaited() - - async def test_result_sending(self): - """Should call `ctx.send` when everything goes right.""" - self.cog.create_channels = AsyncMock() - self.cog.add_roles = AsyncMock() - - members = [MockMember() for _ in range(5)] - await self.cog.createteam(self.cog, self.ctx, "foo", members) - - self.cog.create_channels.assert_awaited_once() - self.cog.add_roles.assert_awaited_once() - self.ctx.send.assert_awaited_once() - - async def test_category_doesnt_exist(self): - """Should create a new code jam category.""" - subtests = ( - [], - [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], - [get_mock_category(jams.MAX_CHANNELS - 2, "other")], - ) - - for categories in subtests: - self.guild.reset_mock() - self.guild.categories = categories - - with self.subTest(categories=categories): - actual_category = await self.cog.get_category(self.guild) - - self.guild.create_category_channel.assert_awaited_once() - category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - - self.assertFalse(category_overwrites[self.guild.default_role].read_messages) - self.assertTrue(category_overwrites[self.guild.me].read_messages) - self.assertEqual(self.guild.create_category_channel.return_value, actual_category) - - async def test_category_channel_exist(self): - """Should not try to create category channel.""" - expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) - self.guild.categories = [ - get_mock_category(jams.MAX_CHANNELS - 2, "other"), - expected_category, - get_mock_category(0, jams.CATEGORY_NAME), - ] - - actual_category = await self.cog.get_category(self.guild) - self.assertEqual(expected_category, actual_category) - - async def test_channel_overwrites(self): - """Should have correct permission overwrites for users and roles.""" - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - overwrites = self.cog.get_overwrites(members, self.guild) - - # Leader permission overwrites - self.assertTrue(overwrites[leader].manage_messages) - self.assertTrue(overwrites[leader].read_messages) - self.assertTrue(overwrites[leader].manage_webhooks) - self.assertTrue(overwrites[leader].connect) - - # Other members permission overwrites - for member in members[1:]: - self.assertTrue(overwrites[member].read_messages) - self.assertTrue(overwrites[member].connect) - - # Everyone and verified role overwrite - self.assertFalse(overwrites[self.guild.default_role].read_messages) - self.assertFalse(overwrites[self.guild.default_role].connect) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) - self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) - - async def test_team_channels_creation(self): - """Should create new voice and text channel for team.""" - members = [MockMember() for _ in range(5)] - - self.cog.get_overwrites = MagicMock() - self.cog.get_category = AsyncMock() - self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") - actual = await self.cog.create_channels(self.guild, "my-team", members) - - self.assertEqual("foobar-channel", actual) - self.cog.get_overwrites.assert_called_once_with(members, self.guild) - self.cog.get_category.assert_awaited_once_with(self.guild) - - self.guild.create_text_channel.assert_awaited_once_with( - "my-team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - self.guild.create_voice_channel.assert_awaited_once_with( - "My Team", - overwrites=self.cog.get_overwrites.return_value, - category=self.cog.get_category.return_value - ) - - async def test_jam_roles_adding(self): - """Should add team leader role to leader and jam role to every team member.""" - leader_role = MockRole(name="Team Leader") - jam_role = MockRole(name="Jammer") - self.guild.get_role.side_effect = [leader_role, jam_role] - - leader = MockMember() - members = [leader] + [MockMember() for _ in range(4)] - await self.cog.add_roles(self.guild, members) - - leader.add_roles.assert_any_await(leader_role) - for member in members: - member.add_roles.assert_any_await(jam_role) - - -class CodeJamSetup(unittest.TestCase): - """Test for `setup` function of `CodeJam` cog.""" - - def test_setup(self): - """Should call `bot.add_cog`.""" - bot = MockBot() - jams.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/utils/test_snekbox.py b/tests/bot/cogs/utils/test_snekbox.py deleted file mode 100644 index 3e447f319..000000000 --- a/tests/bot/cogs/utils/test_snekbox.py +++ /dev/null @@ -1,409 +0,0 @@ -import asyncio -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch - -from discord.ext import commands - -from bot import constants -from bot.cogs.utils import snekbox -from bot.cogs.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser - - -class SnekboxTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Add mocked bot and cog to the instance.""" - self.bot = MockBot() - self.cog = Snekbox(bot=self.bot) - - async def test_post_eval(self): - """Post the eval code to the URLs.snekbox_eval_api endpoint.""" - resp = MagicMock() - resp.json = AsyncMock(return_value="return") - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual(await self.cog.post_eval("import random"), "return") - self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json={"input": "import random"}, - raise_for_status=True - ) - resp.json.assert_awaited_once() - - async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) - self.assertEqual(result, "too long to upload") - - async def test_upload_output(self): - """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True - ) - - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - log = logging.getLogger("bot.cogs.utils.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - - def test_prepare_input(self): - cases = ( - ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), - ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), - ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), - ) - for case, expected, testname in cases: - with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) - - def test_get_results_message(self): - """Return error and message according to the eval result.""" - cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - @patch('bot.cogs.utils.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_signals: Mock): - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127', '') - ) - - @patch('bot.cogs.utils.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = 'SIGTEST' - self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127 (SIGTEST)', '') - ) - - def test_get_status_emoji(self): - """Return emoji according to the eval result.""" - cases = ( - (' ', -1, ':warning:'), - ('Hello world!', 0, ':white_check_mark:'), - ('Invalid beard size', -1, ':x:') - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) - - async def test_format_output(self): - """Test output formatting.""" - self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') - - too_many_lines = ( - '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' - '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' - ) - too_long_too_many_lines = ( - "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" - ) - - cases = ( - ('', ('[No output]', None), 'No output'), - ('My awesome output', ('My awesome output', None), 'One line output'), - ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), - ('" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, asyncio.TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): # pragma: no cover + with self.subTest(size=size): + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover + with self.subTest(msg=msg): + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py new file mode 100644 index 000000000..1b89564f2 --- /dev/null +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -0,0 +1,416 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.exts.backend import sync +from bot.exts.backend.sync._cog import Sync +from bot.exts.backend.sync._syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + self.role_syncer_patcher = mock.patch( + "bot.exts.backend.sync._syncers.RoleSyncer", + autospec=Syncer, + spec_set=True + ) + self.user_syncer_patcher = mock.patch( + "bot.exts.backend.sync._syncers.UserSyncer", + autospec=Syncer, + spec_set=True + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task = mock.MagicMock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + + self.guild_id_patcher = mock.patch("bot.exts.backend.sync._cog.constants.Guild.id", 5) + self.guild_id = self.guild_id_patcher.start() + + self.guild = helpers.MockGuild(id=self.guild_id) + self.other_guild = helpers.MockGuild(id=0) + + def tearDown(self): + self.guild_id_patcher.stop() + + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data, guild=self.guild) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + async def test_sync_cog_on_guild_role_create_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_create(role) + self.bot.api_client.post.assert_not_awaited() + + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99, guild=self.guild) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_delete(role) + self.bot.api_client.delete.assert_not_awaited() + + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data, guild=self.guild) + after_role = helpers.MockRole(**after_role_data, guild=self.guild) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + async def test_sync_cog_on_guild_role_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + role = helpers.MockRole(guild=self.other_guild) + await self.cog.on_guild_role_update(role, role) + self.bot.api_client.put.assert_not_awaited() + + async def test_sync_cog_on_member_remove(self): + """Member should be patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember(guild=self.guild) + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + json={"in_guild": False} + ) + + async def test_sync_cog_on_member_remove_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_remove(member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles, guild=self.guild) + after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, json=data) + + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) + after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + async def test_sync_cog_on_member_update_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_update(member, member) + self.cog.patch_user.assert_not_awaited() + + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args.args[0], after_user.id) + self.assertIn("json", call_args.kwargs) + + self.assertIn("ignore_404", call_args.kwargs) + self.assertTrue(call_args.kwargs["ignore_404"]) + + json = call_args.kwargs["json"] + self.assertIn(api_field, json) + self.assertEqual(json[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + guild=self.guild, + ) + + data = { + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_ignores_guilds(self): + """Events from other guilds should be ignored.""" + member = helpers.MockMember(guild=self.other_guild) + await self.cog.on_member_join(member) + self.bot.api_client.post.assert_not_awaited() + self.bot.api_client.put.assert_not_awaited() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py new file mode 100644 index 000000000..7b9f40cad --- /dev/null +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -0,0 +1,157 @@ +import unittest +from unittest import mock + +import discord + +from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py new file mode 100644 index 000000000..c0a1da35c --- /dev/null +++ b/tests/bot/exts/backend/sync/test_users.py @@ -0,0 +1,158 @@ +import unittest +from unittest import mock + +from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User +from tests import helpers + + +def fake_user(**kwargs): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_sets_in_guild_false_for_leaving_users(self): + """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/exts/backend/test_logging.py b/tests/bot/exts/backend/test_logging.py new file mode 100644 index 000000000..466f207d9 --- /dev/null +++ b/tests/bot/exts/backend/test_logging.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import patch + +from bot import constants +from bot.exts.backend.logging import Logging +from tests.helpers import MockBot, MockTextChannel + + +class LoggingTests(unittest.IsolatedAsyncioTestCase): + """Test cases for connected login.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Logging(self.bot) + self.dev_log = MockTextChannel(id=1234, name="dev-log") + + @patch("bot.exts.backend.logging.DEBUG_MODE", False) + async def test_debug_mode_false(self): + """Should send connected message to dev-log.""" + self.bot.get_channel.return_value = self.dev_log + + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) + self.dev_log.send.assert_awaited_once() + + @patch("bot.exts.backend.logging.DEBUG_MODE", True) + async def test_debug_mode_true(self): + """Should not send anything to dev-log.""" + await self.cog.startup_greeting() + self.bot.wait_until_guild_available.assert_awaited_once_with() + self.bot.get_channel.assert_not_called() diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py new file mode 100644 index 000000000..960894e5c --- /dev/null +++ b/tests/bot/exts/filters/test_antimalware.py @@ -0,0 +1,165 @@ +import unittest +from unittest.mock import AsyncMock, Mock + +from discord import NotFound + +from bot.constants import Channels, STAFF_ROLES +from bot.exts.filters import antimalware +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + + +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.bot.filter_list_cache = { + "FILE_FORMAT.True": { + ".first": {}, + ".second": {}, + ".third": {}, + } + } + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] + + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename="python.first") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.guild = None + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_called_once() + + async def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extension_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + + async def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + self.message.delete.assert_called_once() + + async def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (self.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py new file mode 100644 index 000000000..6a0e4fded --- /dev/null +++ b/tests/bot/exts/filters/test_antispam.py @@ -0,0 +1,35 @@ +import unittest + +from bot.exts.filters import antispam + + +class AntispamConfigurationValidationTests(unittest.TestCase): + """Tests validation of the antispam cog configuration.""" + + def test_default_antispam_config_is_valid(self): + """The default antispam configuration is valid.""" + validation_errors = antispam.validate_config() + self.assertEqual(validation_errors, {}) + + def test_unknown_rule_returns_error(self): + """Configuring an unknown rule returns an error.""" + self.assertEqual( + antispam.validate_config({'invalid-rule': {}}), + {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} + ) + + def test_missing_keys_returns_error(self): + """Not configuring required keys returns an error.""" + keys = (('interval', 'max'), ('max', 'interval')) + for configured_key, unconfigured_key in keys: + with self.subTest( + configured_key=configured_key, + unconfigured_key=unconfigured_key + ): + config = {'burst': {configured_key: 10}} + error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" + + self.assertEqual( + antispam.validate_config(config), + {'burst': error} + ) diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py new file mode 100644 index 000000000..c0c3baa42 --- /dev/null +++ b/tests/bot/exts/filters/test_security.py @@ -0,0 +1,54 @@ +import unittest +from unittest.mock import MagicMock + +from discord.ext.commands import NoPrivateMessage + +from bot.exts.filters import security +from tests.helpers import MockBot, MockContext + + +class SecurityCogTests(unittest.TestCase): + """Tests the `Security` cog.""" + + def setUp(self): + """Attach an instance of the cog to the class for tests.""" + self.bot = MockBot() + self.cog = security.Security(self.bot) + self.ctx = MockContext() + + def test_check_additions(self): + """The cog should add its checks after initialization.""" + self.bot.check.assert_any_call(self.cog.check_on_guild) + self.bot.check.assert_any_call(self.cog.check_not_bot) + + def test_check_not_bot_returns_false_for_humans(self): + """The bot check should return `True` when invoked with human authors.""" + self.ctx.author.bot = False + self.assertTrue(self.cog.check_not_bot(self.ctx)) + + def test_check_not_bot_returns_true_for_robots(self): + """The bot check should return `False` when invoked with robotic authors.""" + self.ctx.author.bot = True + self.assertFalse(self.cog.check_not_bot(self.ctx)) + + def test_check_on_guild_raises_when_outside_of_guild(self): + """When invoked outside of a guild, `check_on_guild` should cause an error.""" + self.ctx.guild = None + + with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): + self.cog.check_on_guild(self.ctx) + + def test_check_on_guild_returns_true_inside_of_guild(self): + """When invoked inside of a guild, `check_on_guild` should return `True`.""" + self.ctx.guild = "lemon's lemonade stand" + self.assertTrue(self.cog.check_on_guild(self.ctx)) + + +class SecurityCogLoadTests(unittest.TestCase): + """Tests loading the `Security` cog.""" + + def test_security_cog_load(self): + """Setup of the extension should call add_cog.""" + bot = MagicMock() + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py new file mode 100644 index 000000000..a0ff8a877 --- /dev/null +++ b/tests/bot/exts/filters/test_token_remover.py @@ -0,0 +1,310 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock + +from discord import Colour, NotFound + +from bot import constants +from bot.exts.filters import token_remover +from bot.exts.filters.token_remover import Token, TokenRemover +from bot.exts.moderation.modlog import ModLog +from tests.helpers import MockBot, MockMessage, autospec + + +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): + """Tests the `TokenRemover` cog.""" + + def setUp(self): + """Adds the cog, a bot, and a message to the instance for usage in tests.""" + self.bot = MockBot() + self.cog = TokenRemover(bot=self.bot) + + self.msg = MockMessage(id=555, content="hello world") + self.msg.channel.mention = "#lemonade-stand" + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + self.msg.author.avatar_url_as.return_value = "picture-lemon.png" + + def test_is_valid_user_id_valid(self): + """Should consider user IDs valid if they decode entirely to ASCII digits.""" + ids = ( + "NDcyMjY1OTQzMDYyNDEzMzMy", + "NDc1MDczNjI5Mzk5NTQ3OTA0", + "NDY3MjIzMjMwNjUwNzc3NjQx", + ) + + for user_id in ids: + with self.subTest(user_id=user_id): + result = TokenRemover.is_valid_user_id(user_id) + self.assertTrue(result) + + def test_is_valid_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for user_id, msg in ids: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_user_id(user_id) + self.assertFalse(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for timestamp, msg in timestamps: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertFalse(result) + + def test_mod_log_property(self): + """The `mod_log` property should ask the bot to return the `ModLog` cog.""" + self.bot.get_cog.return_value = 'lemon' + self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) + self.bot.get_cog.assert_called_once_with('ModLog') + + async def test_on_message_edit_uses_on_message(self): + """The edit listener should delegate handling of the message to the normal listener.""" + self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + + await self.cog.on_message_edit(MockMessage(), self.msg) + self.cog.on_message.assert_awaited_once_with(self.msg) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_takes_action(self, find_token_in_message, take_action): + """Should take action if a valid token is found when a message is sent.""" + cog = TokenRemover(self.bot) + found_token = "foobar" + find_token_in_message.return_value = found_token + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_awaited_once_with(cog, self.msg, found_token) + + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): + """Shouldn't take action if a valid token isn't found when a message is sent.""" + cog = TokenRemover(self.bot) + find_token_in_message.return_value = False + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_not_awaited() + + @autospec(TokenRemover, "find_token_in_message") + async def test_on_message_ignores_dms_bots(self, find_token_in_message): + """Shouldn't parse a message if it is a DM or authored by a bot.""" + cog = TokenRemover(self.bot) + dm_msg = MockMessage(guild=None) + bot_msg = MockMessage(author=MagicMock(bot=True)) + + for msg in (dm_msg, bot_msg): + await cog.on_message(msg) + find_token_in_message.assert_not_called() + + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") + def test_find_token_no_matches(self, token_re): + """None should be returned if the regex matches no tokens in a message.""" + token_re.finditer.return_value = () + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.exts.filters.token_remover", "Token") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") + def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """The first match with a valid user ID and timestamp should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertEqual(tokens[1], return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.exts.filters.token_remover", "Token") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") + def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """None should be returned if no matches have valid user IDs or timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + is_valid_id.return_value = False + is_valid_timestamp.return_value = False + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.findall(token) + self.assertEqual(len(results), 0) + + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, these tokens have been invalidated. + tokens = ( + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for token in tokens: + with self.subTest(token=token): + results = token_remover.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" + + results = token_remover.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual((token_1, token_2), results) + + @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") + def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + log_message.format.return_value = "Howdy" + + return_value = TokenRemover.format_log_message(self.msg, token) + + self.assertEqual(return_value, log_message.format.return_value) + log_message.format.assert_called_once_with( + author=self.msg.author, + author_id=self.msg.author.id, + channel=self.msg.channel.mention, + user_id=token.user_id, + timestamp=token.timestamp, + hmac="x" * len(token.hmac), + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + @autospec("bot.exts.filters.token_remover", "log") + @autospec(TokenRemover, "format_log_message") + async def test_take_action(self, format_log_message, logger, mod_log_property): + """Should delete the message and send a mod log.""" + cog = TokenRemover(self.bot) + mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) + token = mock.create_autospec(Token, spec_set=True, instance=True) + log_msg = "testing123" + + mod_log_property.return_value = mod_log + format_log_message.return_value = log_msg + + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + + format_log_message.assert_called_once_with(self.msg, token) + logger.debug.assert_called_with(log_msg) + self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + + mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) + mod_log.send_log_message.assert_called_once_with( + icon_url=constants.Icons.token_removed, + colour=Colour(constants.Colours.soft_red), + title="Token removed!", + text=log_msg, + thumbnail=self.msg.author.avatar_url_as.return_value, + channel_id=constants.Channels.mod_alerts + ) + + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + async def test_take_action_delete_failure(self, mod_log_property): + """Shouldn't send any messages if the token message can't be deleted.""" + cog = TokenRemover(self.bot) + mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) + self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) + + token = mock.create_autospec(Token, spec_set=True, instance=True) + await cog.take_action(self.msg, token) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_not_awaited() + + +class TokenRemoverExtensionTests(unittest.TestCase): + """Tests for the token_remover extension.""" + + @autospec("bot.exts.filters.token_remover", "TokenRemover") + def test_extension_setup(self, cog): + """The TokenRemover cog should be added.""" + bot = MockBot() + token_remover.setup(bot) + + cog.assert_called_once_with(bot) + bot.add_cog.assert_called_once() + self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/info/__init__.py b/tests/bot/exts/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py new file mode 100644 index 000000000..be47d42ef --- /dev/null +++ b/tests/bot/exts/info/test_information.py @@ -0,0 +1,584 @@ +import asyncio +import textwrap +import unittest +import unittest.mock + +import discord + +from bot import constants +from bot.exts.info import information +from bot.utils.checks import InWhitelistCheckFailure +from tests import helpers + +COG_PATH = "bot.exts.info.information.Information" + + +class InformationCogTests(unittest.TestCase): + """Tests the Information cog.""" + + @classmethod + def setUpClass(cls): + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = helpers.MockBot() + + self.cog = information.Information(self.bot) + + self.ctx = helpers.MockContext() + self.ctx.author.roles.append(self.moderator_role) + + def test_roles_command_command(self): + """Test if the `role_info` command correctly returns the `moderator_role`.""" + self.ctx.guild.roles.append(self.moderator_role) + + self.cog.roles_info.can_run = unittest.mock.AsyncMock() + self.cog.roles_info.can_run.return_value = True + + coroutine = self.cog.roles_info.callback(self.cog, self.ctx) + + self.assertIsNone(asyncio.run(coroutine)) + self.ctx.send.assert_called_once() + + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + + self.assertEqual(embed.title, "Role information (Total 1 role)") + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") + + def test_role_info_command(self): + """Tests the `role info` command.""" + dummy_role = helpers.MockRole( + name="Dummy", + id=112233445566778899, + colour=discord.Colour.blurple(), + position=10, + members=[self.ctx.author], + permissions=discord.Permissions(0) + ) + + admin_role = helpers.MockRole( + name="Admins", + id=998877665544332211, + colour=discord.Colour.red(), + position=3, + members=[self.ctx.author], + permissions=discord.Permissions(0), + ) + + self.ctx.guild.roles.append([dummy_role, admin_role]) + + self.cog.role_info.can_run = unittest.mock.AsyncMock() + self.cog.role_info.can_run.return_value = True + + coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) + + self.assertIsNone(asyncio.run(coroutine)) + + self.assertEqual(self.ctx.send.call_count, 2) + + (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list + + dummy_embed = dummy_kwargs["embed"] + admin_embed = admin_kwargs["embed"] + + self.assertEqual(dummy_embed.title, "Dummy info") + self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) + + self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) + self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") + self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") + self.assertEqual(dummy_embed.fields[3].value, "1") + self.assertEqual(dummy_embed.fields[4].value, "10") + self.assertEqual(dummy_embed.fields[5].value, "0") + + self.assertEqual(admin_embed.title, "Admins info") + self.assertEqual(admin_embed.colour, discord.Colour.red()) + + @unittest.mock.patch('bot.exts.info.information.time_since') + def test_server_info_command(self, time_since_patch): + time_since_patch.return_value = '2 days ago' + + self.ctx.guild = helpers.MockGuild( + features=('lemons', 'apples'), + region="The Moon", + roles=[self.moderator_role], + channels=[ + discord.TextChannel( + state={}, + guild=self.ctx.guild, + data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} + ), + discord.CategoryChannel( + state={}, + guild=self.ctx.guild, + data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} + ), + discord.VoiceChannel( + state={}, + guild=self.ctx.guild, + data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} + ) + ], + members=[ + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), + ], + member_count=1_234, + icon_url='a-lemon.jpg', + ) + + coroutine = self.cog.server_info.callback(self.cog, self.ctx) + self.assertIsNone(asyncio.run(coroutine)) + + time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') + _, kwargs = self.ctx.send.call_args + embed = kwargs.pop('embed') + self.assertEqual(embed.colour, discord.Colour.blurple()) + self.assertEqual( + embed.description, + textwrap.dedent( + f""" + **Server information** + Created: {time_since_patch.return_value} + Voice region: {self.ctx.guild.region} + Features: {', '.join(self.ctx.guild.features)} + + **Channel counts** + Category channels: 1 + Text channels: 1 + Voice channels: 1 + Staff channels: 0 + + **Member counts** + Members: {self.ctx.guild.member_count:,} + Staff members: 0 + Roles: {len(self.ctx.guild.roles)} + + **Member statuses** + {constants.Emojis.status_online} 2 + {constants.Emojis.status_idle} 1 + {constants.Emojis.status_dnd} 4 + {constants.Emojis.status_offline} 3 + """ + ) + ) + self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') + + +class UserInfractionHelperMethodTests(unittest.TestCase): + """Tests for the helper methods of the `!user` command.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + self.member = helpers.MockMember(id=1234) + + def test_user_command_helper_method_get_requests(self): + """The helper methods should form the correct get requests.""" + test_values = ( + { + "helper_method": self.cog.basic_user_infraction_counts, + "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.expanded_user_infraction_counts, + "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}), + }, + { + "helper_method": self.cog.user_nomination_counts, + "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}), + }, + ) + + for test_value in test_values: + helper_method = test_value["helper_method"] + endpoint, params = test_value["expected_args"] + + with self.subTest(method=helper_method, endpoint=endpoint, params=params): + asyncio.run(helper_method(self.member)) + self.bot.api_client.get.assert_called_once_with(endpoint, params=params) + self.bot.api_client.get.reset_mock() + + def _method_subtests(self, method, test_values, default_header): + """Helper method that runs the subtests for the different helper methods.""" + for test_value in test_values: + api_response = test_value["api response"] + expected_lines = test_value["expected_lines"] + + with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): + self.bot.api_client.get.return_value = api_response + + expected_output = "\n".join(default_header + expected_lines) + actual_output = asyncio.run(method(self.member)) + + self.assertEqual(expected_output, actual_output) + + def test_basic_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list both the total and active number of non-hidden infractions.""" + test_values = ( + # No infractions means zero counts + { + "api response": [], + "expected_lines": ["Total: 0", "Active: 0"], + }, + # Simple, single-infraction dictionaries + { + "api response": [{"type": "ban", "active": True}], + "expected_lines": ["Total: 1", "Active: 1"], + }, + { + "api response": [{"type": "ban", "active": False}], + "expected_lines": ["Total: 1", "Active: 0"], + }, + # Multiple infractions with various `active` status + { + "api response": [ + {"type": "ban", "active": True}, + {"type": "kick", "active": False}, + {"type": "ban", "active": True}, + {"type": "ban", "active": False}, + ], + "expected_lines": ["Total: 4", "Active: 2"], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) + + def test_expanded_user_infraction_counts_returns_correct_strings(self): + """The method should correctly list the total and active number of all infractions split by infraction type.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never received an infraction."], + }, + # Shows non-hidden inactive infraction as expected + { + "api response": [{"type": "kick", "active": False, "hidden": False}], + "expected_lines": ["Kicks: 1"], + }, + # Shows non-hidden active infraction as expected + { + "api response": [{"type": "mute", "active": True, "hidden": False}], + "expected_lines": ["Mutes: 1 (1 active)"], + }, + # Shows hidden inactive infraction as expected + { + "api response": [{"type": "superstar", "active": False, "hidden": True}], + "expected_lines": ["Superstars: 1"], + }, + # Shows hidden active infraction as expected + { + "api response": [{"type": "ban", "active": True, "hidden": True}], + "expected_lines": ["Bans: 1 (1 active)"], + }, + # Correctly displays tally of multiple infractions of mixed properties in alphabetical order + { + "api response": [ + {"type": "kick", "active": False, "hidden": True}, + {"type": "ban", "active": True, "hidden": True}, + {"type": "superstar", "active": True, "hidden": True}, + {"type": "mute", "active": True, "hidden": True}, + {"type": "ban", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + {"type": "note", "active": False, "hidden": True}, + {"type": "warn", "active": False, "hidden": False}, + {"type": "note", "active": False, "hidden": True}, + ], + "expected_lines": [ + "Bans: 2 (1 active)", + "Kicks: 1", + "Mutes: 1 (1 active)", + "Notes: 3", + "Superstars: 1 (1 active)", + "Warns: 1", + ], + }, + ) + + header = ["**Infractions**"] + + self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) + + def test_user_nomination_counts_returns_correct_strings(self): + """The method should list the number of active and historical nominations for the user.""" + test_values = ( + { + "api response": [], + "expected_lines": ["This user has never been nominated."], + }, + { + "api response": [{'active': True}], + "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], + }, + { + "api response": [{'active': True}, {'active': False}], + "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], + }, + { + "api response": [{'active': False}], + "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."], + }, + { + "api response": [{'active': False}, {'active': False}], + "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."], + }, + + ) + + header = ["**Nominations**"] + + self._method_subtests(self.cog.user_nomination_counts, test_values, header) + + +@unittest.mock.patch("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) +@unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50]) +class UserEmbedTests(unittest.TestCase): + """Tests for the creation of the `!user` embed.""" + + def setUp(self): + """Common set-up steps done before for each test.""" + self.bot = helpers.MockBot() + self.bot.api_client.get = unittest.mock.AsyncMock() + self.cog = information.Information(self.bot) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): + """The embed should use the string representation of the user if they don't have a nick.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = None + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Mr. Hemlock") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_nick_in_title_if_available(self): + """The embed should use the nick if it's available.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + user = helpers.MockMember() + user.nick = "Cat lover" + user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_ignores_everyone_role(self): + """Created `!user` embeds should not contain mention of the @everyone-role.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) + admins_role = helpers.MockRole(name='Admins') + admins_role.colour = 100 + + # A `MockMember` has the @Everyone role by default; we add the Admins to that. + user = helpers.MockMember(roles=[admins_role], top_role=admins_role) + + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertIn("&Admins", embed.description) + self.assertNotIn("&Everyone", embed.description) + + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): + """The embed should contain expanded infractions and nomination info in mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "expanded infractions info" + nomination_counts.return_value = "nomination info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + nomination_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + expanded infractions info + + nomination info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): + """The embed should contain only basic infraction data outside of mod channels.""" + ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + infraction_counts.return_value = "basic infractions info" + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + infraction_counts.assert_called_once_with(user) + + self.assertEqual( + textwrap.dedent(f""" + **User Information** + Created: {"1 year ago"} + Profile: {user.mention} + ID: {user.id} + + **Member Information** + Joined: {"1 year ago"} + Roles: &Moderators + + basic infractions info + """).strip(), + embed.description + ) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): + """The embed should be created with the colour of the top role, if a top role is available.""" + ctx = helpers.MockContext() + + moderators_role = helpers.MockRole(name='Moderators') + moderators_role.colour = 100 + + user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): + """The embed should be created with a blurple colour if the user has no assigned roles.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + self.assertEqual(embed.colour, discord.Colour.blurple()) + + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): + """The embed thumbnail should be set to the user's avatar in `png` format.""" + ctx = helpers.MockContext() + + user = helpers.MockMember(id=217) + user.avatar_url_as.return_value = "avatar url" + embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + + user.avatar_url_as.assert_called_once_with(static_format="png") + self.assertEqual(embed.thumbnail.url, "avatar url") + + +@unittest.mock.patch("bot.exts.info.information.constants") +class UserCommandTests(unittest.TestCase): + """Tests for the `!user` command.""" + + def setUp(self): + """Set up steps executed before each test is run.""" + self.bot = helpers.MockBot() + self.cog = information.Information(self.bot) + + self.moderator_role = helpers.MockRole(name="Moderators", id=2, position=10) + self.flautist_role = helpers.MockRole(name="Flautists", id=3, position=2) + self.bassist_role = helpers.MockRole(name="Bassists", id=4, position=3) + + self.author = helpers.MockMember(id=1, name="syntaxaire") + self.moderator = helpers.MockMember(id=2, name="riffautae", roles=[self.moderator_role]) + self.target = helpers.MockMember(id=3, name="__fluzz__") + + def test_regular_member_cannot_target_another_member(self, constants): + """A regular user should not be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.author) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") + + def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): + """A regular user should not be able to use this command outside of bot-commands.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) + + msg = "Sorry, but you may only use this command within <#50>." + with self.assertRaises(InWhitelistCheckFailure, msg=msg): + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): + """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): + """A user should target itself with `!user` when a `user` argument was not provided.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) + + create_embed.assert_called_once_with(ctx, self.author) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): + """Staff members should be able to bypass the bot-commands channel restriction.""" + constants.STAFF_ROLES = [self.moderator_role.id] + constants.Channels.bot_commands = 50 + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + + create_embed.assert_called_once_with(ctx, self.moderator) + ctx.send.assert_called_once() + + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") + def test_moderators_can_target_another_member(self, create_embed, constants): + """A moderator should be able to use `!user` targeting another user.""" + constants.MODERATION_ROLES = [self.moderator_role.id] + constants.STAFF_ROLES = [self.moderator_role.id] + + ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) + + asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + + create_embed.assert_called_once_with(ctx, self.target) + ctx.send.assert_called_once() diff --git a/tests/bot/exts/moderation/__init__.py b/tests/bot/exts/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/moderation/infraction/__init__.py b/tests/bot/exts/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py new file mode 100644 index 000000000..be1b649e1 --- /dev/null +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.exts.moderation.infraction.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason truncation.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Infractions(self.bot) + self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) + self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) + self.guild = MockGuild(id=4567) + self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + + @patch("bot.exts.moderation.infraction._utils.get_active_infraction") + @patch("bot.exts.moderation.infraction._utils.post_infraction") + async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): + """Should truncate reason for `ctx.guild.ban`.""" + get_active_mock.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.bot.get_cog.return_value = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.ctx.guild.ban = Mock() + + await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) + self.ctx.guild.ban.assert_called_once_with( + self.target, + reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), + delete_message_days=0 + ) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value + ) + + @patch("bot.exts.moderation.infraction._utils.post_infraction") + async def test_apply_kick_reason_truncation(self, post_infraction_mock): + """Should truncate reason for `Member.kick`.""" + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.target.kick = Mock() + + await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) + self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + ) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py new file mode 100644 index 000000000..cbf7f7bcf --- /dev/null +++ b/tests/bot/exts/moderation/test_incidents.py @@ -0,0 +1,770 @@ +import asyncio +import enum +import logging +import typing as t +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +import aiohttp +import discord + +from bot.constants import Colours +from bot.exts.moderation import incidents +from tests.helpers import ( + MockAsyncWebhook, + MockAttachment, + MockBot, + MockMember, + MockMessage, + MockReaction, + MockRole, + MockTextChannel, + MockUser, +) + + +class MockAsyncIterable: + """ + Helper for mocking asynchronous for loops. + + It does not appear that the `unittest` library currently provides anything that would + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. + + We therefore write our own helper to wrap a regular synchronous iterable, and feed + its values via `__anext__` rather than `__next__`. + + This class was written for the purposes of testing the `Incidents` cog - it may not + be generic enough to be placed in the `tests.helpers` module. + """ + + def __init__(self, messages: t.Iterable): + """Take a sync iterable to be wrapped.""" + self.iter_messages = iter(messages) + + def __aiter__(self): + """Return `self` as we provide the `__anext__` method.""" + return self + + async def __anext__(self): + """ + Feed the next item, or raise `StopAsyncIteration`. + + Since we're wrapping a sync iterator, it will communicate that it has been depleted + by raising a `StopIteration`. The `async for` construct does not expect it, and we + therefore need to substitute it for the appropriate exception type. + """ + try: + return next(self.iter_messages) + except StopIteration: + raise StopAsyncIteration + + +class MockSignal(enum.Enum): + A = "A" + B = "B" + + +mock_404 = discord.NotFound( + response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response + message="Not found", +) + + +class TestDownloadFile(unittest.IsolatedAsyncioTestCase): + """Collection of tests for the `download_file` helper function.""" + + async def test_download_file_success(self): + """If `to_file` succeeds, function returns the acquired `discord.File`.""" + file = MagicMock(discord.File, filename="bigbadlemon.jpg") + attachment = MockAttachment(to_file=AsyncMock(return_value=file)) + + acquired_file = await incidents.download_file(attachment) + self.assertIs(file, acquired_file) + + async def test_download_file_404(self): + """If `to_file` encounters a 404, function handles the exception & returns None.""" + attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) + + acquired_file = await incidents.download_file(attachment) + self.assertIsNone(acquired_file) + + async def test_download_file_fail(self): + """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" + arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) + + with self.assertLogs(logger=incidents.log, level=logging.ERROR): + acquired_file = await incidents.download_file(attachment) + + self.assertIsNone(acquired_file) + + +class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): + """Collection of tests for the `make_embed` helper function.""" + + async def test_make_embed_actioned(self): + """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" + embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + + self.assertEqual(embed.colour.value, Colours.soft_green) + self.assertIn("Actioned", embed.footer.text) + + async def test_make_embed_not_actioned(self): + """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" + embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + + self.assertEqual(embed.colour.value, Colours.soft_red) + self.assertIn("Rejected", embed.footer.text) + + async def test_make_embed_content(self): + """Incident content appears as embed description.""" + incident = MockMessage(content="this is an incident") + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertEqual(incident.content, embed.description) + + async def test_make_embed_with_attachment_succeeds(self): + """Incident's attachment is downloaded and displayed in the embed's image field.""" + file = MagicMock(discord.File, filename="bigbadjoe.jpg") + attachment = MockAttachment(filename="bigbadjoe.jpg") + incident = MockMessage(content="this is an incident", attachments=[attachment]) + + # Patch `download_file` to return our `file` + with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): + embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertIs(file, returned_file) + self.assertEqual("attachment://bigbadjoe.jpg", embed.image.url) + + async def test_make_embed_with_attachment_fails(self): + """Incident's attachment fails to download, proxy url is linked instead.""" + attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") + incident = MockMessage(content="this is an incident", attachments=[attachment]) + + # Patch `download_file` to return None as if the download failed + with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): + embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + + self.assertIsNone(returned_file) + + # The author name field is simply expected to have something in it, we do not assert the message + self.assertGreater(len(embed.author.name), 0) + self.assertEqual(embed.author.url, "discord.com/bigbadjoe.jpg") # However, it should link the exact url + + +@patch("bot.constants.Channels.incidents", 123) +class TestIsIncident(unittest.TestCase): + """ + Collection of tests for the `is_incident` helper function. + + In `setUp`, we will create a mock message which should qualify as an incident. Each + test case will then mutate this instance to make it **not** qualify, in various ways. + + Notice that we patch the #incidents channel id globally for this class. + """ + + def setUp(self) -> None: + """Prepare a mock message which should qualify as an incident.""" + self.incident = MockMessage( + channel=MockTextChannel(id=123), + content="this is an incident", + author=MockUser(bot=False), + pinned=False, + ) + + def test_is_incident_true(self): + """Message qualifies as an incident if unchanged.""" + self.assertTrue(incidents.is_incident(self.incident)) + + def check_false(self): + """Assert that `self.incident` does **not** qualify as an incident.""" + self.assertFalse(incidents.is_incident(self.incident)) + + def test_is_incident_false_channel(self): + """Message doesn't qualify if sent outside of #incidents.""" + self.incident.channel = MockTextChannel(id=456) + self.check_false() + + def test_is_incident_false_content(self): + """Message doesn't qualify if content begins with hash symbol.""" + self.incident.content = "# this is a comment message" + self.check_false() + + def test_is_incident_false_author(self): + """Message doesn't qualify if author is a bot.""" + self.incident.author = MockUser(bot=True) + self.check_false() + + def test_is_incident_false_pinned(self): + """Message doesn't qualify if it is pinned.""" + self.incident.pinned = True + self.check_false() + + +class TestOwnReactions(unittest.TestCase): + """Assertions for the `own_reactions` function.""" + + def test_own_reactions(self): + """Only bot's own emoji are extracted from the input incident.""" + reactions = ( + MockReaction(emoji="A", me=True), + MockReaction(emoji="B", me=True), + MockReaction(emoji="C", me=False), + ) + message = MockMessage(reactions=reactions) + self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) + + +@patch("bot.exts.moderation.incidents.ALL_SIGNALS", {"A", "B"}) +class TestHasSignals(unittest.TestCase): + """ + Assertions for the `has_signals` function. + + We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` + as appropriate. + """ + + def test_has_signals_true(self): + """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" + message = MockMessage() + own_reactions = MagicMock(return_value={"A", "B"}) + + with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): + self.assertTrue(incidents.has_signals(message)) + + def test_has_signals_false(self): + """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" + message = MockMessage() + own_reactions = MagicMock(return_value={"A", "C"}) + + with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): + self.assertFalse(incidents.has_signals(message)) + + +@patch("bot.exts.moderation.incidents.Signal", MockSignal) +class TestAddSignals(unittest.IsolatedAsyncioTestCase): + """ + Assertions for the `add_signals` coroutine. + + These are all fairly similar and could go into a single test function, but I found the + patching & sub-testing fairly awkward in that case and decided to split them up + to avoid unnecessary syntax noise. + """ + + def setUp(self): + """Prepare a mock incident message for tests to use.""" + self.incident = MockMessage() + + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value=set())) + async def test_add_signals_missing(self): + """All emoji are added when none are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) + + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) + async def test_add_signals_partial(self): + """Only missing emoji are added when some are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_has_calls([call("B")]) + + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) + async def test_add_signals_present(self): + """No emoji are added when all are present.""" + await incidents.add_signals(self.incident) + self.incident.add_reaction.assert_not_called() + + +class TestIncidents(unittest.IsolatedAsyncioTestCase): + """ + Tests for bound methods of the `Incidents` cog. + + Use this as a base class for `Incidents` tests - it will prepare a fresh instance + for each test function, but not make any assertions on its own. Tests can mutate + the instance as they wish. + """ + + def setUp(self): + """ + Prepare a fresh `Incidents` instance for each test. + + Note that this will not schedule `crawl_incidents` in the background, as everything + is being mocked. The `crawl_task` attribute will end up being None. + """ + self.cog_instance = incidents.Incidents(MockBot()) + + +@patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test +class TestCrawlIncidents(TestIncidents): + """ + Tests for the `Incidents.crawl_incidents` coroutine. + + Apart from `test_crawl_incidents_waits_until_cache_ready`, all tests in this class + will patch the return values of `is_incident` and `has_signal` and then observe + whether the `AsyncMock` for `add_signals` was awaited or not. + + The `add_signals` mock is added by each test separately to ensure it is clean (has not + been awaited by another test yet). The mock can be reset, but this appears to be the + cleaner way. + + For each test, we inject a mock channel with a history of 1 message only (see: `setUp`). + """ + + def setUp(self): + """For each test, ensure `bot.get_channel` returns a channel with 1 arbitrary message.""" + super().setUp() # First ensure we get `cog_instance` from parent + + incidents_history = MagicMock(return_value=MockAsyncIterable([MockMessage()])) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(history=incidents_history)) + + async def test_crawl_incidents_waits_until_cache_ready(self): + """ + The coroutine will await the `wait_until_guild_available` event. + + Since this task is schedule in the `__init__`, it is critical that it waits for the + cache to be ready, so that it can safely get the #incidents channel. + """ + await self.cog_instance.crawl_incidents() + self.cog_instance.bot.wait_until_guild_available.assert_awaited() + + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) + async def test_crawl_incidents_noop_if_is_not_incident(self): + """Signals are not added for a non-incident message.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_not_awaited() + + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals + async def test_crawl_incidents_noop_if_message_already_has_signals(self): + """Signals are not added for messages which already have them.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_not_awaited() + + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals + async def test_crawl_incidents_add_signals_called(self): + """Message has signals added as it does not have them yet and qualifies as an incident.""" + await self.cog_instance.crawl_incidents() + incidents.add_signals.assert_awaited_once() + + +class TestArchive(TestIncidents): + """Tests for the `Incidents.archive` coroutine.""" + + async def test_archive_webhook_not_found(self): + """ + Method recovers and returns False when the webhook is not found. + + Implicitly, this also tests that the error is handled internally and doesn't + propagate out of the method, which is just as important. + """ + self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) + self.assertFalse( + await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + ) + + async def test_archive_relays_incident(self): + """ + If webhook is found, method relays `incident` properly. + + This test will assert that the fetched webhook's `send` method is fed the correct arguments, + and that the `archive` method returns True. + """ + webhook = MockAsyncWebhook() + self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook + + # Define our own `incident` to be archived + incident = MockMessage( + content="this is an incident", + author=MockUser(name="author_name", avatar_url="author_avatar"), + id=123, + ) + built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this + + with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): + archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) + + # Now we check that the webhook was given the correct args, and that `archive` returned True + webhook.send.assert_called_once_with( + embed=built_embed, + username="author_name", + avatar_url="author_avatar", + file=None, + ) + self.assertTrue(archive_return) + + async def test_archive_clyde_username(self): + """ + The archive webhook username is cleansed using `sub_clyde`. + + Discord will reject any webhook with "clyde" in the username field, as it impersonates + the official Clyde bot. Since we do not control what the username will be (the incident + author name is used), we must ensure the name is cleansed, otherwise the relay may fail. + + This test assumes the username is passed as a kwarg. If this test fails, please review + whether the passed argument is being retrieved correctly. + """ + webhook = MockAsyncWebhook() + self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) + + message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) + + self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) + + +class TestMakeConfirmationTask(TestIncidents): + """ + Tests for the `Incidents.make_confirmation_task` method. + + Writing tests for this method is difficult, as it mostly just delegates the provided + information elsewhere. There is very little internal logic. Whether our approach + works conceptually is difficult to prove using unit tests. + """ + + def test_make_confirmation_task_check(self): + """ + The internal check will recognize the passed incident. + + This is a little tricky - we first pass a message with a specific `id` in, and then + retrieve the built check from the `call_args` of the `wait_for` method. This relies + on the check being passed as a kwarg. + + Once the check is retrieved, we assert that it gives True for our incident's `id`, + and False for any other. + + If this function begins to fail, first check that `created_check` is being retrieved + correctly. It should be the function that is built locally in the tested method. + """ + self.cog_instance.make_confirmation_task(MockMessage(id=123)) + + self.cog_instance.bot.wait_for.assert_called_once() + created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"] + + # The `message_id` matches the `id` of our incident + self.assertTrue(created_check(payload=MagicMock(message_id=123))) + + # This `message_id` does not match + self.assertFalse(created_check(payload=MagicMock(message_id=0))) + + +@patch("bot.exts.moderation.incidents.ALLOWED_ROLES", {1, 2}) +@patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable +class TestProcessEvent(TestIncidents): + """Tests for the `Incidents.process_event` coroutine.""" + + async def test_process_event_bad_role(self): + """The reaction is removed when the author lacks all allowed roles.""" + incident = MockMessage() + member = MockMember(roles=[MockRole(id=0)]) # Must have role 1 or 2 + + await self.cog_instance.process_event("reaction", incident, member) + incident.remove_reaction.assert_called_once_with("reaction", member) + + async def test_process_event_bad_emoji(self): + """ + The reaction is removed when an invalid emoji is used. + + This requires that we pass in a `member` with valid roles, as we need the role check + to succeed. + """ + incident = MockMessage() + member = MockMember(roles=[MockRole(id=1)]) # Member has allowed role + + await self.cog_instance.process_event("invalid_signal", incident, member) + incident.remove_reaction.assert_called_once_with("invalid_signal", member) + + async def test_process_event_no_archive_on_investigating(self): + """Message is not archived on `Signal.INVESTIGATING`.""" + with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: + await self.cog_instance.process_event( + reaction=incidents.Signal.INVESTIGATING.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]), + ) + + mocked_archive.assert_not_called() + + async def test_process_event_no_delete_if_archive_fails(self): + """ + Original message is not deleted when `Incidents.archive` returns False. + + This is the way of signaling that the relay failed, and we should not remove the original, + as that would result in losing the incident record. + """ + incident = MockMessage() + + with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=incident, + member=MockMember(roles=[MockRole(id=1)]) + ) + + incident.delete.assert_not_called() + + async def test_process_event_confirmation_task_is_awaited(self): + """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" + mock_task = AsyncMock() + + with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]) + ) + + mock_task.assert_awaited() + + async def test_process_event_confirmation_task_timeout_is_handled(self): + """ + Confirmation task `asyncio.TimeoutError` is handled gracefully. + + We have `make_confirmation_task` return a mock with a side effect, and then catch the + exception should it propagate out of `process_event`. This is so that we can then manually + fail the test with a more informative message than just the plain traceback. + """ + mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) + + try: + with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): + await self.cog_instance.process_event( + reaction=incidents.Signal.ACTIONED.value, + incident=MockMessage(), + member=MockMember(roles=[MockRole(id=1)]) + ) + except asyncio.TimeoutError: + self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!") + + +class TestResolveMessage(TestIncidents): + """Tests for the `Incidents.resolve_message` coroutine.""" + + async def test_resolve_message_pass_message_id(self): + """Method will call `_get_message` with the passed `message_id`.""" + await self.cog_instance.resolve_message(123) + self.cog_instance.bot._connection._get_message.assert_called_once_with(123) + + async def test_resolve_message_in_cache(self): + """ + No API call is made if the queried message exists in the cache. + + We mock the `_get_message` return value regardless of input. Whether it finds the message + internally is considered d.py's responsibility, not ours. + """ + cached_message = MockMessage(id=123) + self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message) + + return_value = await self.cog_instance.resolve_message(123) + + self.assertIs(return_value, cached_message) + self.cog_instance.bot.get_channel.assert_not_called() # The `fetch_message` line was never hit + + async def test_resolve_message_not_in_cache(self): + """ + The message is retrieved from the API if it isn't cached. + + This is desired behaviour for messages which exist, but were sent before the bot's + current session. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + # API returns our message + uncached_message = MockMessage() + fetch_message = AsyncMock(return_value=uncached_message) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + retrieved_message = await self.cog_instance.resolve_message(123) + self.assertIs(retrieved_message, uncached_message) + + async def test_resolve_message_doesnt_exist(self): + """ + If the API returns a 404, the function handles it gracefully and returns None. + + This is an edge-case happening with racing events - event A will relay the message + to the archive and delete the original. Once event B acquires the `event_lock`, + it will not find the message in the cache, and will ask the API. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + fetch_message = AsyncMock(side_effect=mock_404) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + self.assertIsNone(await self.cog_instance.resolve_message(123)) + + async def test_resolve_message_fetch_fails(self): + """ + Non-404 errors are handled, logged & None is returned. + + In contrast with a 404, this should make an error-level log. We assert that at least + one such log was made - we do not make any assertions about the log's message. + """ + self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None + + arbitrary_error = discord.HTTPException( + response=MagicMock(aiohttp.ClientResponse), + message="Arbitrary error", + ) + fetch_message = AsyncMock(side_effect=arbitrary_error) + self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + + with self.assertLogs(logger=incidents.log, level=logging.ERROR): + self.assertIsNone(await self.cog_instance.resolve_message(123)) + + +@patch("bot.constants.Channels.incidents", 123) +class TestOnRawReactionAdd(TestIncidents): + """ + Tests for the `Incidents.on_raw_reaction_add` listener. + + Writing tests for this listener comes with additional complexity due to the listener + awaiting the `crawl_task` task. See `asyncSetUp` for further details, which attempts + to make unit testing this function possible. + """ + + def setUp(self): + """ + Prepare & assign `payload` attribute. + + This attribute represents an *ideal* payload which will not be rejected by the + listener. As each test will receive a fresh instance, it can be mutated to + observe how the listener's behaviour changes with different attributes on + the passed payload. + """ + super().setUp() # Ensure `cog_instance` is assigned + + self.payload = MagicMock( + discord.RawReactionActionEvent, + channel_id=123, # Patched at class level + message_id=456, + member=MockMember(bot=False), + emoji="reaction", + ) + + async def asyncSetUp(self): # noqa: N802 + """ + Prepare an empty task and assign it as `crawl_task`. + + It appears that the `unittest` framework does not provide anything for mocking + asyncio tasks. An `AsyncMock` instance can be called and then awaited, however, + it does not provide the `done` method or any other parts of the `asyncio.Task` + interface. + + Although we do not need to make any assertions about the task itself while + testing the listener, the code will still await it and call the `done` method, + and so we must inject something that will not fail on either action. + + Note that this is done in an `asyncSetUp`, which runs after `setUp`. + The justification is that creating an actual task requires the event + loop to be ready, which is not the case in the `setUp`. + """ + mock_task = asyncio.create_task(AsyncMock()()) # Mock async func, then a coro + self.cog_instance.crawl_task = mock_task + + async def test_on_raw_reaction_add_wrong_channel(self): + """ + Events outside of #incidents will be ignored. + + We check this by asserting that `resolve_message` was never queried. + """ + self.payload.channel_id = 0 + self.cog_instance.resolve_message = AsyncMock() + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.resolve_message.assert_not_called() + + async def test_on_raw_reaction_add_user_is_bot(self): + """ + Events dispatched by bot accounts will be ignored. + + We check this by asserting that `resolve_message` was never queried. + """ + self.payload.member = MockMember(bot=True) + self.cog_instance.resolve_message = AsyncMock() + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.resolve_message.assert_not_called() + + async def test_on_raw_reaction_add_message_doesnt_exist(self): + """ + Listener gracefully handles the case where `resolve_message` gives None. + + We check this by asserting that `process_event` was never called. + """ + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=None) + + await self.cog_instance.on_raw_reaction_add(self.payload) + self.cog_instance.process_event.assert_not_called() + + async def test_on_raw_reaction_add_message_is_not_an_incident(self): + """ + The event won't be processed if the related message is not an incident. + + This is an edge-case that can happen if someone manually leaves a reaction + on a pinned message, or a comment. + + We check this by asserting that `process_event` was never called. + """ + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) + + with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)): + await self.cog_instance.on_raw_reaction_add(self.payload) + + self.cog_instance.process_event.assert_not_called() + + async def test_on_raw_reaction_add_valid_event_is_processed(self): + """ + If the reaction event is valid, it is passed to `process_event`. + + This is the case when everything goes right: + * The reaction was placed in #incidents, and not by a bot + * The message was found successfully + * The message qualifies as an incident + + Additionally, we check that all arguments were passed as expected. + """ + incident = MockMessage(id=1) + + self.cog_instance.process_event = AsyncMock() + self.cog_instance.resolve_message = AsyncMock(return_value=incident) + + with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)): + await self.cog_instance.on_raw_reaction_add(self.payload) + + self.cog_instance.process_event.assert_called_with( + "reaction", # Defined in `self.payload` + incident, + self.payload.member, + ) + + +class TestOnMessage(TestIncidents): + """ + Tests for the `Incidents.on_message` listener. + + Notice the decorators mocking the `is_incident` return value. The `is_incidents` + function is tested in `TestIsIncident` - here we do not worry about it. + """ + + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) + async def test_on_message_incident(self): + """Messages qualifying as incidents are passed to `add_signals`.""" + incident = MockMessage() + + with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + await self.cog_instance.on_message(incident) + + mock_add_signals.assert_called_once_with(incident) + + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) + async def test_on_message_non_incident(self): + """Messages not qualifying as incidents are ignored.""" + with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + await self.cog_instance.on_message(MockMessage()) + + mock_add_signals.assert_not_called() diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py new file mode 100644 index 000000000..f8f142484 --- /dev/null +++ b/tests/bot/exts/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.exts.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): + """Tests for moderation logs.""" + + def setUp(self): + self.bot = MockBot() + self.cog = ModLog(self.bot) + self.channel = MockTextChannel() + + async def test_log_entry_description_truncation(self): + """Test that embed description for ModLog entry is truncated.""" + self.bot.get_channel.return_value = self.channel + await self.cog.send_log_message( + icon_url="foo", + colour=discord.Colour.blue(), + title="bar", + text="foo bar" * 3000 + ) + embed = self.channel.send.call_args[1]["embed"] + self.assertEqual( + embed.description, ("foo bar" * 3000)[:2045] + "..." + ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py new file mode 100644 index 000000000..8c4fb764a --- /dev/null +++ b/tests/bot/exts/moderation/test_silence.py @@ -0,0 +1,261 @@ +import unittest +from unittest import mock +from unittest.mock import MagicMock, Mock + +from discord import PermissionOverwrite + +from bot.constants import Channels, Emojis, Guild, Roles +from bot.exts.moderation.silence import Silence, SilenceNotifier +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.alert_channel = MockTextChannel() + self.notifier = SilenceNotifier(self.alert_channel) + self.notifier.stop = self.notifier_stop_mock = Mock() + self.notifier.start = self.notifier_start_mock = Mock() + + def test_add_channel_adds_channel(self): + """Channel in FirstHash with current loop is added to internal set.""" + channel = Mock() + with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: + self.notifier.add_channel(channel) + silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) + + def test_add_channel_starts_loop(self): + """Loop is started if `_silenced_channels` was empty.""" + self.notifier.add_channel(Mock()) + self.notifier_start_mock.assert_called_once() + + def test_add_channel_skips_start_with_channels(self): + """Loop start is not called when `_silenced_channels` is not empty.""" + with mock.patch.object(self.notifier, "_silenced_channels"): + self.notifier.add_channel(Mock()) + self.notifier_start_mock.assert_not_called() + + def test_remove_channel_removes_channel(self): + """Channel in FirstHash is removed from `_silenced_channels`.""" + channel = Mock() + with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: + self.notifier.remove_channel(channel) + silenced_channels.__delitem__.assert_called_with(channel) + + def test_remove_channel_stops_loop(self): + """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" + with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): + self.notifier.remove_channel(Mock()) + self.notifier_stop_mock.assert_called_once() + + def test_remove_channel_skips_stop_with_channels(self): + """Notifier loop is not stopped if `_silenced_channels` is not empty after remove.""" + self.notifier.remove_channel(Mock()) + self.notifier_stop_mock.assert_not_called() + + async def test_notifier_private_sends_alert(self): + """Alert is sent on 15 min intervals.""" + test_cases = (900, 1800, 2700) + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with mock.patch.object(self.notifier, "_current_loop", new=current_loop): + await self.notifier._notifier() + self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") + self.alert_channel.send.reset_mock() + + async def test_notifier_skips_alert(self): + """Alert is skipped on first loop or not an increment of 900.""" + test_cases = (0, 15, 5000) + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with mock.patch.object(self.notifier, "_current_loop", new=current_loop): + await self.notifier._notifier() + self.alert_channel.send.assert_not_called() + + +class SilenceTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Silence(self.bot) + self.ctx = MockContext() + self.cog._verified_role = None + # Set event so command callbacks can continue. + self.cog._get_instance_vars_event.set() + + async def test_instance_vars_got_guild(self): + """Bot got guild after it became available.""" + await self.cog._get_instance_vars() + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(Guild.id) + + async def test_instance_vars_got_role(self): + """Got `Roles.verified` role from guild.""" + await self.cog._get_instance_vars() + guild = self.bot.get_guild() + guild.get_role.assert_called_once_with(Roles.verified) + + async def test_instance_vars_got_channels(self): + """Got channels from bot.""" + await self.cog._get_instance_vars() + self.bot.get_channel.called_once_with(Channels.mod_alerts) + self.bot.get_channel.called_once_with(Channels.mod_log) + + @mock.patch("bot.exts.moderation.silence.SilenceNotifier") + async def test_instance_vars_got_notifier(self, notifier): + """Notifier was started with channel.""" + mod_log = MockTextChannel() + self.bot.get_channel.side_effect = (None, mod_log) + await self.cog._get_instance_vars() + notifier.assert_called_once_with(mod_log) + self.bot.get_channel.side_effect = None + + async def test_silence_sent_correct_discord_message(self): + """Check if proper message was sent when called with duration in channel with previous state.""" + test_cases = ( + (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), + (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), + (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), + ) + for duration, result_message, _silence_patch_return in test_cases: + with self.subTest( + silence_duration=duration, + result_message=result_message, + starting_unsilenced_state=_silence_patch_return + ): + with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): + await self.cog.silence.callback(self.cog, self.ctx, duration) + self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() + + async def test_unsilence_sent_correct_discord_message(self): + """Check if proper message was sent when unsilencing channel.""" + test_cases = ( + (True, f"{Emojis.check_mark} unsilenced current channel."), + (False, f"{Emojis.cross_mark} current channel was not silenced.") + ) + for _unsilence_patch_return, result_message in test_cases: + with self.subTest( + starting_silenced_state=_unsilence_patch_return, + result_message=result_message + ): + with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): + await self.cog.unsilence.callback(self.cog, self.ctx) + self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() + + async def test_silence_private_for_false(self): + """Permissions are not set and `False` is returned in an already silenced channel.""" + perm_overwrite = Mock(send_messages=False) + channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) + + self.assertFalse(await self.cog._silence(channel, True, None)) + channel.set_permissions.assert_not_called() + + async def test_silence_private_silenced_channel(self): + """Channel had `send_message` permissions revoked.""" + channel = MockTextChannel() + self.assertTrue(await self.cog._silence(channel, False, None)) + channel.set_permissions.assert_called_once() + self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) + + async def test_silence_private_preserves_permissions(self): + """Previous permissions were preserved when channel was silenced.""" + channel = MockTextChannel() + # Set up mock channel permission state. + mock_permissions = PermissionOverwrite() + mock_permissions_dict = dict(mock_permissions) + channel.overwrites_for.return_value = mock_permissions + await self.cog._silence(channel, False, None) + new_permissions = channel.set_permissions.call_args.kwargs + # Remove 'send_messages' key because it got changed in the method. + del new_permissions['send_messages'] + del mock_permissions_dict['send_messages'] + self.assertDictEqual(mock_permissions_dict, new_permissions) + + async def test_silence_private_notifier(self): + """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" + channel = MockTextChannel() + with mock.patch.object(self.cog, "notifier", create=True): + with self.subTest(persistent=True): + await self.cog._silence(channel, True, None) + self.cog.notifier.add_channel.assert_called_once() + + with mock.patch.object(self.cog, "notifier", create=True): + with self.subTest(persistent=False): + await self.cog._silence(channel, False, None) + self.cog.notifier.add_channel.assert_not_called() + + async def test_silence_private_added_muted_channel(self): + """Channel was added to `muted_channels` on silence.""" + channel = MockTextChannel() + with mock.patch.object(self.cog, "muted_channels") as muted_channels: + await self.cog._silence(channel, False, None) + muted_channels.add.assert_called_once_with(channel) + + async def test_unsilence_private_for_false(self): + """Permissions are not set and `False` is returned in an unsilenced channel.""" + channel = Mock() + self.assertFalse(await self.cog._unsilence(channel)) + channel.set_permissions.assert_not_called() + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_unsilenced_channel(self, _): + """Channel had `send_message` permissions restored""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + self.assertTrue(await self.cog._unsilence(channel)) + channel.set_permissions.assert_called_once() + self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_removed_notifier(self, notifier): + """Channel was removed from `notifier` on unsilence.""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + await self.cog._unsilence(channel) + notifier.remove_channel.assert_called_once_with(channel) + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_removed_muted_channel(self, _): + """Channel was removed from `muted_channels` on unsilence.""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + with mock.patch.object(self.cog, "muted_channels") as muted_channels: + await self.cog._unsilence(channel) + muted_channels.discard.assert_called_once_with(channel) + + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_preserves_permissions(self, _): + """Previous permissions were preserved when channel was unsilenced.""" + channel = MockTextChannel() + # Set up mock channel permission state. + mock_permissions = PermissionOverwrite(send_messages=False) + mock_permissions_dict = dict(mock_permissions) + channel.overwrites_for.return_value = mock_permissions + await self.cog._unsilence(channel) + new_permissions = channel.set_permissions.call_args.kwargs + # Remove 'send_messages' key because it got changed in the method. + del new_permissions['send_messages'] + del mock_permissions_dict['send_messages'] + self.assertDictEqual(mock_permissions_dict, new_permissions) + + @mock.patch("bot.exts.moderation.silence.asyncio") + @mock.patch.object(Silence, "_mod_alerts_channel", create=True) + def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): + """Task for sending an alert was created with present `muted_channels`.""" + with mock.patch.object(self.cog, "muted_channels"): + self.cog.cog_unload() + alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") + asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) + + @mock.patch("bot.exts.moderation.silence.asyncio") + def test_cog_unload_skips_task_start(self, asyncio_mock): + """No task created with no channels.""" + self.cog.cog_unload() + asyncio_mock.create_task.assert_not_called() + + @mock.patch("bot.exts.moderation.silence.with_role_check") + @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py new file mode 100644 index 000000000..e90394ab9 --- /dev/null +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -0,0 +1,111 @@ +import unittest +from unittest import mock + +from dateutil.relativedelta import relativedelta + +from bot.constants import Emojis +from bot.exts.moderation.slowmode import Slowmode +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SlowmodeTests(unittest.IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = Slowmode(self.bot) + self.ctx = MockContext() + + async def test_get_slowmode_no_channel(self) -> None: + """Get slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) + + await self.cog.get_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") + + async def test_get_slowmode_with_channel(self) -> None: + """Get slowmode with a given channel.""" + text_channel = MockTextChannel(name='python-language', slowmode_delay=2) + + await self.cog.get_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + + async def test_set_slowmode_no_channel(self) -> None: + """Set slowmode without a given channel.""" + test_cases = ( + ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), + ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), + ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + self.ctx.channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + + if edited: + self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + self.ctx.channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_set_slowmode_with_channel(self) -> None: + """Set slowmode with a given channel.""" + test_cases = ( + ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), + ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), + ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') + ) + + for channel_name, seconds, edited, result_msg in test_cases: + with self.subTest( + channel_mention=channel_name, + seconds=seconds, + edited=edited, + result_msg=result_msg + ): + text_channel = MockTextChannel(name=channel_name) + + await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + + if edited: + text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) + else: + text_channel.edit.assert_not_called() + + self.ctx.send.assert_called_once_with(result_msg) + + self.ctx.reset_mock() + + async def test_reset_slowmode_no_channel(self) -> None: + """Reset slowmode without a given channel.""" + self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) + + await self.cog.reset_slowmode(self.cog, self.ctx, None) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' + ) + + async def test_reset_slowmode_with_channel(self) -> None: + """Reset slowmode with a given channel.""" + text_channel = MockTextChannel(name='meta', slowmode_delay=1) + + await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) + self.ctx.send.assert_called_once_with( + f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' + ) + + @mock.patch("bot.exts.moderation.slowmode.with_role_check") + @mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py new file mode 100644 index 000000000..775c40722 --- /dev/null +++ b/tests/bot/exts/test_cogs.py @@ -0,0 +1,81 @@ +"""Test suite for general tests which apply to all cogs.""" + +import importlib +import pkgutil +import typing as t +import unittest +from collections import defaultdict +from types import ModuleType +from unittest import mock + +from discord.ext import commands + +from bot import exts + + +class CommandNameTests(unittest.TestCase): + """Tests for shadowing command names and aliases.""" + + @staticmethod + def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: + """An iterator that recursively walks through `cog`'s commands and subcommands.""" + # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. + for command in cog.__cog_commands__: + if command.parent is None: + yield command + if isinstance(command, commands.GroupMixin): + # Annoyingly it returns duplicates for each alias so use a set to fix that + yield from set(command.walk_commands()) + + @staticmethod + def walk_modules() -> t.Iterator[ModuleType]: + """Yield imported modules from the bot.exts subpackage.""" + def on_error(name: str) -> t.NoReturn: + raise ImportError(name=name) # pragma: no cover + + # The mock prevents asyncio.get_event_loop() from being called. + with mock.patch("discord.ext.tasks.loop"): + prefix = f"{exts.__name__}." + for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): + if not module.ispkg: + yield importlib.import_module(module.name) + + @staticmethod + def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: + """Yield all cogs defined in an extension.""" + for obj in module.__dict__.values(): + # Check if it's a class type cause otherwise issubclass() may raise a TypeError. + is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) + if is_cog and obj.__module__ == module.__name__: + yield obj + + @staticmethod + def get_qualified_names(command: commands.Command) -> t.List[str]: + """Return a list of all qualified names, including aliases, for the `command`.""" + names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] + names.append(command.qualified_name) + + return names + + def get_all_commands(self) -> t.Iterator[commands.Command]: + """Yield all commands for all cogs in all extensions.""" + for module in self.walk_modules(): + for cog in self.walk_cogs(module): + for cmd in self.walk_commands(cog): + yield cmd + + def test_names_dont_shadow(self): + """Names and aliases of commands should be unique.""" + all_names = defaultdict(list) + for cmd in self.get_all_commands(): + func_name = f"{cmd.module}.{cmd.callback.__qualname__}" + + for name in self.get_qualified_names(cmd): + with self.subTest(cmd=func_name, name=name): + if name in all_names: # pragma: no cover + conflicts = ", ".join(all_names.get(name, "")) + self.fail( + f"Name '{name}' of the command {func_name} conflicts with {conflicts}." + ) + + all_names[name].append(func_name) diff --git a/tests/bot/exts/test_duck_pond.py b/tests/bot/exts/test_duck_pond.py new file mode 100644 index 000000000..f6d977482 --- /dev/null +++ b/tests/bot/exts/test_duck_pond.py @@ -0,0 +1,548 @@ +import asyncio +import logging +import typing +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import discord + +from bot import constants +from bot.exts import duck_pond +from tests import base +from tests import helpers + +MODULE_PATH = "bot.exts.duck_pond" + + +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) + + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) + + def setUp(self): + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) + self.cog = duck_pond.DuckPond(bot=self.bot) + + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger('bot.exts.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) + ) + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) + ) + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), + 3 + 4 + ), + ) + + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + author = MagicMock( + display_name="x", + avatar_url="https://" + ) + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), + (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), + ) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.exts.duck_pond") + + for side_effect in side_effects: # pragma: no cover + send_attachments.side_effect = side_effect + with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.exts.duck_pond") + + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message(message) + + send_webhook.assert_called_once_with( + webhook=self.cog.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), + ) + + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + relay_message.assert_called_once_with(message) + + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) + + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + for duck_count, should_re_add_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) + + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) + + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) + + if should_re_add_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" + ) + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) + + channel.fetch_message.assert_not_called() + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = helpers.MockBot() + duck_pond.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/utils/__init__.py b/tests/bot/exts/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/utils/test_jams.py new file mode 100644 index 000000000..45e7b5b51 --- /dev/null +++ b/tests/bot/exts/utils/test_jams.py @@ -0,0 +1,173 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, create_autospec + +from discord import CategoryChannel + +from bot.constants import Roles +from bot.exts.utils import jams +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel + + +def get_mock_category(channel_count: int, name: str) -> CategoryChannel: + """Return a mocked code jam category.""" + category = create_autospec(CategoryChannel, spec_set=True, instance=True) + category.name = name + category.channels = [MockTextChannel() for _ in range(channel_count)] + + return category + + +class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): + """Tests for `createteam` command.""" + + def setUp(self): + self.bot = MockBot() + self.admin_role = MockRole(name="Admins", id=Roles.admins) + self.command_user = MockMember([self.admin_role]) + self.guild = MockGuild([self.admin_role]) + self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) + self.cog = jams.CodeJams(self.bot) + + async def test_too_small_amount_of_team_members_passed(self): + """Should `ctx.send` and exit early when too small amount of members.""" + for case in (1, 2): + with self.subTest(amount_of_members=case): + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + self.ctx.reset_mock() + members = (MockMember() for _ in range(case)) + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_duplicate_members_provided(self): + """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + member = MockMember() + await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + + self.ctx.send.assert_awaited_once() + self.cog.create_channels.assert_not_awaited() + self.cog.add_roles.assert_not_awaited() + + async def test_result_sending(self): + """Should call `ctx.send` when everything goes right.""" + self.cog.create_channels = AsyncMock() + self.cog.add_roles = AsyncMock() + + members = [MockMember() for _ in range(5)] + await self.cog.createteam(self.cog, self.ctx, "foo", members) + + self.cog.create_channels.assert_awaited_once() + self.cog.add_roles.assert_awaited_once() + self.ctx.send.assert_awaited_once() + + async def test_category_doesnt_exist(self): + """Should create a new code jam category.""" + subtests = ( + [], + [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], + [get_mock_category(jams.MAX_CHANNELS - 2, "other")], + ) + + for categories in subtests: + self.guild.reset_mock() + self.guild.categories = categories + + with self.subTest(categories=categories): + actual_category = await self.cog.get_category(self.guild) + + self.guild.create_category_channel.assert_awaited_once() + category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + + self.assertFalse(category_overwrites[self.guild.default_role].read_messages) + self.assertTrue(category_overwrites[self.guild.me].read_messages) + self.assertEqual(self.guild.create_category_channel.return_value, actual_category) + + async def test_category_channel_exist(self): + """Should not try to create category channel.""" + expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) + self.guild.categories = [ + get_mock_category(jams.MAX_CHANNELS - 2, "other"), + expected_category, + get_mock_category(0, jams.CATEGORY_NAME), + ] + + actual_category = await self.cog.get_category(self.guild) + self.assertEqual(expected_category, actual_category) + + async def test_channel_overwrites(self): + """Should have correct permission overwrites for users and roles.""" + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + overwrites = self.cog.get_overwrites(members, self.guild) + + # Leader permission overwrites + self.assertTrue(overwrites[leader].manage_messages) + self.assertTrue(overwrites[leader].read_messages) + self.assertTrue(overwrites[leader].manage_webhooks) + self.assertTrue(overwrites[leader].connect) + + # Other members permission overwrites + for member in members[1:]: + self.assertTrue(overwrites[member].read_messages) + self.assertTrue(overwrites[member].connect) + + # Everyone and verified role overwrite + self.assertFalse(overwrites[self.guild.default_role].read_messages) + self.assertFalse(overwrites[self.guild.default_role].connect) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) + self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) + + async def test_team_channels_creation(self): + """Should create new voice and text channel for team.""" + members = [MockMember() for _ in range(5)] + + self.cog.get_overwrites = MagicMock() + self.cog.get_category = AsyncMock() + self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") + actual = await self.cog.create_channels(self.guild, "my-team", members) + + self.assertEqual("foobar-channel", actual) + self.cog.get_overwrites.assert_called_once_with(members, self.guild) + self.cog.get_category.assert_awaited_once_with(self.guild) + + self.guild.create_text_channel.assert_awaited_once_with( + "my-team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + self.guild.create_voice_channel.assert_awaited_once_with( + "My Team", + overwrites=self.cog.get_overwrites.return_value, + category=self.cog.get_category.return_value + ) + + async def test_jam_roles_adding(self): + """Should add team leader role to leader and jam role to every team member.""" + leader_role = MockRole(name="Team Leader") + jam_role = MockRole(name="Jammer") + self.guild.get_role.side_effect = [leader_role, jam_role] + + leader = MockMember() + members = [leader] + [MockMember() for _ in range(4)] + await self.cog.add_roles(self.guild, members) + + leader.add_roles.assert_any_await(leader_role) + for member in members: + member.add_roles.assert_any_await(jam_role) + + +class CodeJamSetup(unittest.TestCase): + """Test for `setup` function of `CodeJam` cog.""" + + def test_setup(self): + """Should call `bot.add_cog`.""" + bot = MockBot() + jams.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py new file mode 100644 index 000000000..f7b861035 --- /dev/null +++ b/tests/bot/exts/utils/test_snekbox.py @@ -0,0 +1,409 @@ +import asyncio +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch + +from discord.ext import commands + +from bot import constants +from bot.exts.utils import snekbox +from bot.exts.utils.snekbox import Snekbox +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + + async def test_post_eval(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( + constants.URLs.snekbox_eval_api, + json={"input": "import random"}, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LEN.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + self.assertEqual(result, "too long to upload") + + async def test_upload_output(self): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + self.assertEqual( + await self.cog.upload_output("My awesome output"), + constants.URLs.paste_service.format(key=key) + ) + self.bot.http_session.post.assert_called_with( + constants.URLs.paste_service.format(key="documents"), + data="My awesome output", + raise_for_status=True + ) + + async def test_upload_output_gracefully_fallback_if_exception_during_request(self): + """Output upload gracefully fallback if the upload fail.""" + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + log = logging.getLogger("bot.exts.utils.snekbox") + with self.assertLogs(logger=log, level='ERROR'): + await self.cog.upload_output('My awesome output!') + + async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): + """Output upload gracefully fallback if there is no key entry in the response body.""" + self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + + def test_prepare_input(self): + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual(self.cog.prepare_input(case), expected) + + def test_get_results_message(self): + """Return error and message according to the eval result.""" + cases = ( + ('ERROR', None, ('Your eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), + ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) + def test_get_results_message_invalid_signal(self, mock_signals: Mock): + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127', '') + ) + + @patch('bot.exts.utils.snekbox.Signals') + def test_get_results_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = 'SIGTEST' + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127 (SIGTEST)', '') + ) + + def test_get_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... (truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + (' Date: Fri, 14 Aug 2020 21:19:04 +0100 Subject: Update tests for user commands --- tests/bot/cogs/test_information.py | 87 ++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 32 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 79c0e0ad3..77b0ddf17 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -215,10 +215,10 @@ class UserInfractionHelperMethodTests(unittest.TestCase): with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): self.bot.api_client.get.return_value = api_response - expected_output = "\n".join(default_header + expected_lines) + expected_output = "\n".join(expected_lines) actual_output = asyncio.run(method(self.member)) - self.assertEqual(expected_output, actual_output) + self.assertEqual((default_header, expected_output), actual_output) def test_basic_user_infraction_counts_returns_correct_strings(self): """The method should correctly list both the total and active number of non-hidden infractions.""" @@ -249,7 +249,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): }, ) - header = ["**Infractions**"] + header = "Infractions" self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) @@ -258,7 +258,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): test_values = ( { "api response": [], - "expected_lines": ["This user has never received an infraction."], + "expected_lines": ["No infractions"], }, # Shows non-hidden inactive infraction as expected { @@ -304,7 +304,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): }, ) - header = ["**Infractions**"] + header = "Infractions" self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) @@ -313,15 +313,15 @@ class UserInfractionHelperMethodTests(unittest.TestCase): test_values = ( { "api response": [], - "expected_lines": ["This user has never been nominated."], + "expected_lines": ["No nominations"], }, { "api response": [{'active': True}], - "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], + "expected_lines": ["This user is **currently** nominated", "(1 nomination in total)"], }, { "api response": [{'active': True}, {'active': False}], - "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], + "expected_lines": ["This user is **currently** nominated", "(2 nominations in total)"], }, { "api response": [{'active': False}], @@ -334,7 +334,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): ) - header = ["**Nominations**"] + header = "Nominations" self._method_subtests(self.cog.user_nomination_counts, test_values, header) @@ -350,7 +350,10 @@ class UserEmbedTests(unittest.TestCase): self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -362,7 +365,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Mr. Hemlock") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -374,7 +380,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -386,8 +395,8 @@ class UserEmbedTests(unittest.TestCase): embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - self.assertIn("&Admins", embed.description) - self.assertNotIn("&Everyone", embed.description) + self.assertIn("&Admins", embed.fields[1].value) + self.assertNotIn("&Everyone", embed.fields[1].value) @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) @@ -398,8 +407,8 @@ class UserEmbedTests(unittest.TestCase): moderators_role = helpers.MockRole(name='Moderators') moderators_role.colour = 100 - infraction_counts.return_value = "expanded infractions info" - nomination_counts.return_value = "nomination info" + infraction_counts.return_value = ("Infractions", "expanded infractions info") + nomination_counts.return_value = ("Nominations", "nomination info") user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) embed = asyncio.run(self.cog.create_user_embed(ctx, user)) @@ -409,20 +418,19 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual( textwrap.dedent(f""" - **User Information** Created: {"1 year ago"} Profile: {user.mention} ID: {user.id} + """).strip(), + embed.fields[0].value + ) - **Member Information** + self.assertEqual( + textwrap.dedent(f""" Joined: {"1 year ago"} Roles: &Moderators - - expanded infractions info - - nomination info """).strip(), - embed.description + embed.fields[1].value ) @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @@ -433,7 +441,7 @@ class UserEmbedTests(unittest.TestCase): moderators_role = helpers.MockRole(name='Moderators') moderators_role.colour = 100 - infraction_counts.return_value = "basic infractions info" + infraction_counts.return_value = ("Infractions", "basic infractions info") user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) embed = asyncio.run(self.cog.create_user_embed(ctx, user)) @@ -442,21 +450,30 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual( textwrap.dedent(f""" - **User Information** Created: {"1 year ago"} Profile: {user.mention} ID: {user.id} + """).strip(), + embed.fields[0].value + ) - **Member Information** + self.assertEqual( + textwrap.dedent(f""" Joined: {"1 year ago"} Roles: &Moderators - - basic infractions info """).strip(), - embed.description + embed.fields[1].value + ) + + self.assertEqual( + "basic infractions info", + embed.fields[3].value ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -469,7 +486,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): """The embed should be created with a blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() @@ -479,7 +499,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour.blurple()) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() -- cgit v1.2.3 From 07084103cabb95f1af25890e0059a93244088010 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 19 Aug 2020 13:34:34 -0700 Subject: Categorise most of the uncategorised extensions --- bot/exts/alias.py | 153 ---------- bot/exts/backend/alias.py | 153 ++++++++++ bot/exts/dm_relay.py | 124 -------- bot/exts/duck_pond.py | 166 ----------- bot/exts/fun/__init__.py | 0 bot/exts/fun/duck_pond.py | 166 +++++++++++ bot/exts/fun/off_topic_names.py | 162 +++++++++++ bot/exts/moderation/dm_relay.py | 124 ++++++++ bot/exts/off_topic_names.py | 162 ----------- tests/bot/exts/fun/__init__.py | 0 tests/bot/exts/fun/test_duck_pond.py | 548 +++++++++++++++++++++++++++++++++++ tests/bot/exts/test_duck_pond.py | 548 ----------------------------------- 12 files changed, 1153 insertions(+), 1153 deletions(-) delete mode 100644 bot/exts/alias.py create mode 100644 bot/exts/backend/alias.py delete mode 100644 bot/exts/dm_relay.py delete mode 100644 bot/exts/duck_pond.py create mode 100644 bot/exts/fun/__init__.py create mode 100644 bot/exts/fun/duck_pond.py create mode 100644 bot/exts/fun/off_topic_names.py create mode 100644 bot/exts/moderation/dm_relay.py delete mode 100644 bot/exts/off_topic_names.py create mode 100644 tests/bot/exts/fun/__init__.py create mode 100644 tests/bot/exts/fun/test_duck_pond.py delete mode 100644 tests/bot/exts/test_duck_pond.py (limited to 'tests') diff --git a/bot/exts/alias.py b/bot/exts/alias.py deleted file mode 100644 index 77867b933..000000000 --- a/bot/exts/alias.py +++ /dev/null @@ -1,153 +0,0 @@ -import inspect -import logging - -from discord import Colour, Embed -from discord.ext.commands import ( - Cog, Command, Context, Greedy, - clean_content, command, group, -) - -from bot.bot import Bot -from bot.converters import FetchedMember, TagNameConverter -from bot.exts.utils.extensions import Extension -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - - -class Alias (Cog): - """Aliases for commonly used commands.""" - - def __init__(self, bot: Bot): - self.bot = bot - - async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: - """Invokes a command with args and kwargs.""" - log.debug(f"{cmd_name} was invoked through an alias") - cmd = self.bot.get_command(cmd_name) - if not cmd: - return log.info(f'Did not find command "{cmd_name}" to invoke.') - elif not await cmd.can_run(ctx): - return log.info( - f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' - ) - - await ctx.invoke(cmd, *args, **kwargs) - - @command(name='aliases') - async def aliases_command(self, ctx: Context) -> None: - """Show configured aliases on the bot.""" - embed = Embed( - title='Configured aliases', - colour=Colour.blue() - ) - await LinePaginator.paginate( - ( - f"• `{ctx.prefix}{value.name}` " - f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" - for name, value in inspect.getmembers(self) - if isinstance(value, Command) and name.endswith('_alias') - ), - ctx, embed, empty=False, max_lines=20 - ) - - @command(name="resources", aliases=("resource",), hidden=True) - async def site_resources_alias(self, ctx: Context) -> None: - """Alias for invoking site resources.""" - await self.invoke(ctx, "site resources") - - @command(name="tools", hidden=True) - async def site_tools_alias(self, ctx: Context) -> None: - """Alias for invoking site tools.""" - await self.invoke(ctx, "site tools") - - @command(name="watch", hidden=True) - async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother watch [user] [reason].""" - await self.invoke(ctx, "bigbrother watch", user, reason=reason) - - @command(name="unwatch", hidden=True) - async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking bigbrother unwatch [user] [reason].""" - await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) - - @command(name="home", hidden=True) - async def site_home_alias(self, ctx: Context) -> None: - """Alias for invoking site home.""" - await self.invoke(ctx, "site home") - - @command(name="faq", hidden=True) - async def site_faq_alias(self, ctx: Context) -> None: - """Alias for invoking site faq.""" - await self.invoke(ctx, "site faq") - - @command(name="rules", aliases=("rule",), hidden=True) - async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: - """Alias for invoking site rules.""" - await self.invoke(ctx, "site rules", *rules) - - @command(name="reload", hidden=True) - async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: - """Alias for invoking extensions reload [extensions...].""" - await self.invoke(ctx, "extensions reload", *extensions) - - @command(name="defon", hidden=True) - async def defcon_enable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon enable.""" - await self.invoke(ctx, "defcon enable") - - @command(name="defoff", hidden=True) - async def defcon_disable_alias(self, ctx: Context) -> None: - """Alias for invoking defcon disable.""" - await self.invoke(ctx, "defcon disable") - - @command(name="exception", hidden=True) - async def tags_get_traceback_alias(self, ctx: Context) -> None: - """Alias for invoking tags get traceback.""" - await self.invoke(ctx, "tags get", tag_name="traceback") - - @group(name="get", - aliases=("show", "g"), - hidden=True, - invoke_without_command=True) - async def get_group_alias(self, ctx: Context) -> None: - """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" - pass - - @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) - async def tags_get_alias( - self, ctx: Context, *, tag_name: TagNameConverter = None - ) -> None: - """ - Alias for invoking tags get [tag_name]. - - tag_name: str - tag to be viewed. - """ - await self.invoke(ctx, "tags get", tag_name=tag_name) - - @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) - async def docs_get_alias( - self, ctx: Context, symbol: clean_content = None - ) -> None: - """Alias for invoking docs get [symbol].""" - await self.invoke(ctx, "docs get", symbol) - - @command(name="nominate", hidden=True) - async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking talentpool add [user] [reason].""" - await self.invoke(ctx, "talentpool add", user, reason=reason) - - @command(name="unnominate", hidden=True) - async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: - """Alias for invoking nomination end [user] [reason].""" - await self.invoke(ctx, "nomination end", user, reason=reason) - - @command(name="nominees", hidden=True) - async def nominees_alias(self, ctx: Context) -> None: - """Alias for invoking tp watched.""" - await self.invoke(ctx, "talentpool watched") - - -def setup(bot: Bot) -> None: - """Load the Alias cog.""" - bot.add_cog(Alias(bot)) diff --git a/bot/exts/backend/alias.py b/bot/exts/backend/alias.py new file mode 100644 index 000000000..77867b933 --- /dev/null +++ b/bot/exts/backend/alias.py @@ -0,0 +1,153 @@ +import inspect +import logging + +from discord import Colour, Embed +from discord.ext.commands import ( + Cog, Command, Context, Greedy, + clean_content, command, group, +) + +from bot.bot import Bot +from bot.converters import FetchedMember, TagNameConverter +from bot.exts.utils.extensions import Extension +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +class Alias (Cog): + """Aliases for commonly used commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: + """Invokes a command with args and kwargs.""" + log.debug(f"{cmd_name} was invoked through an alias") + cmd = self.bot.get_command(cmd_name) + if not cmd: + return log.info(f'Did not find command "{cmd_name}" to invoke.') + elif not await cmd.can_run(ctx): + return log.info( + f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' + ) + + await ctx.invoke(cmd, *args, **kwargs) + + @command(name='aliases') + async def aliases_command(self, ctx: Context) -> None: + """Show configured aliases on the bot.""" + embed = Embed( + title='Configured aliases', + colour=Colour.blue() + ) + await LinePaginator.paginate( + ( + f"• `{ctx.prefix}{value.name}` " + f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" + for name, value in inspect.getmembers(self) + if isinstance(value, Command) and name.endswith('_alias') + ), + ctx, embed, empty=False, max_lines=20 + ) + + @command(name="resources", aliases=("resource",), hidden=True) + async def site_resources_alias(self, ctx: Context) -> None: + """Alias for invoking site resources.""" + await self.invoke(ctx, "site resources") + + @command(name="tools", hidden=True) + async def site_tools_alias(self, ctx: Context) -> None: + """Alias for invoking site tools.""" + await self.invoke(ctx, "site tools") + + @command(name="watch", hidden=True) + async def bigbrother_watch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother watch [user] [reason].""" + await self.invoke(ctx, "bigbrother watch", user, reason=reason) + + @command(name="unwatch", hidden=True) + async def bigbrother_unwatch_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking bigbrother unwatch [user] [reason].""" + await self.invoke(ctx, "bigbrother unwatch", user, reason=reason) + + @command(name="home", hidden=True) + async def site_home_alias(self, ctx: Context) -> None: + """Alias for invoking site home.""" + await self.invoke(ctx, "site home") + + @command(name="faq", hidden=True) + async def site_faq_alias(self, ctx: Context) -> None: + """Alias for invoking site faq.""" + await self.invoke(ctx, "site faq") + + @command(name="rules", aliases=("rule",), hidden=True) + async def site_rules_alias(self, ctx: Context, rules: Greedy[int], *_: str) -> None: + """Alias for invoking site rules.""" + await self.invoke(ctx, "site rules", *rules) + + @command(name="reload", hidden=True) + async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: + """Alias for invoking extensions reload [extensions...].""" + await self.invoke(ctx, "extensions reload", *extensions) + + @command(name="defon", hidden=True) + async def defcon_enable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon enable.""" + await self.invoke(ctx, "defcon enable") + + @command(name="defoff", hidden=True) + async def defcon_disable_alias(self, ctx: Context) -> None: + """Alias for invoking defcon disable.""" + await self.invoke(ctx, "defcon disable") + + @command(name="exception", hidden=True) + async def tags_get_traceback_alias(self, ctx: Context) -> None: + """Alias for invoking tags get traceback.""" + await self.invoke(ctx, "tags get", tag_name="traceback") + + @group(name="get", + aliases=("show", "g"), + hidden=True, + invoke_without_command=True) + async def get_group_alias(self, ctx: Context) -> None: + """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" + pass + + @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) + async def tags_get_alias( + self, ctx: Context, *, tag_name: TagNameConverter = None + ) -> None: + """ + Alias for invoking tags get [tag_name]. + + tag_name: str - tag to be viewed. + """ + await self.invoke(ctx, "tags get", tag_name=tag_name) + + @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) + async def docs_get_alias( + self, ctx: Context, symbol: clean_content = None + ) -> None: + """Alias for invoking docs get [symbol].""" + await self.invoke(ctx, "docs get", symbol) + + @command(name="nominate", hidden=True) + async def nomination_add_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking talentpool add [user] [reason].""" + await self.invoke(ctx, "talentpool add", user, reason=reason) + + @command(name="unnominate", hidden=True) + async def nomination_end_alias(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: + """Alias for invoking nomination end [user] [reason].""" + await self.invoke(ctx, "nomination end", user, reason=reason) + + @command(name="nominees", hidden=True) + async def nominees_alias(self, ctx: Context) -> None: + """Alias for invoking tp watched.""" + await self.invoke(ctx, "talentpool watched") + + +def setup(bot: Bot) -> None: + """Load the Alias cog.""" + bot.add_cog(Alias(bot)) diff --git a/bot/exts/dm_relay.py b/bot/exts/dm_relay.py deleted file mode 100644 index 0d8f340b4..000000000 --- a/bot/exts/dm_relay.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -from typing import Optional - -import discord -from discord import Color -from discord.ext import commands -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.converters import UserMentionOrID -from bot.utils import RedisCache -from bot.utils.checks import in_whitelist_check, with_role_check -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DMRelay(Cog): - """Relay direct messages to and from the bot.""" - - # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] - dm_cache = RedisCache() - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.dm_log - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - @commands.command(aliases=("reply",)) - async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: - """ - Allows you to send a DM to a user from the bot. - - If `member` is not provided, it will send to the last user who DM'd the bot. - - This feature should be used extremely sparingly. Use ModMail if you need to have a serious - conversation with a user. This is just for responding to extraordinary DMs, having a little - fun with users, and telling people they are DMing the wrong bot. - - NOTE: This feature will be removed if it is overused. - """ - if not member: - user_id = await self.dm_cache.get("last_user") - member = ctx.guild.get_member(user_id) if user_id else None - - # If we still don't have a Member at this point, give up - if not member: - log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") - await ctx.message.add_reaction("❌") - return - - try: - await member.send(message) - except discord.errors.Forbidden: - log.debug("User has disabled DMs.") - await ctx.message.add_reaction("❌") - else: - await ctx.message.add_reaction("✅") - self.bot.stats.incr("dm_relay.dm_sent") - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @Cog.listener() - async def on_message(self, message: discord.Message) -> None: - """Relays the message's content and attachments to the dm_log channel.""" - # Only relay DMs from humans - if message.author.bot or message.guild or self.webhook is None: - return - - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - await self.dm_cache.set("last_user", message.author.id) - self.bot.stats.incr("dm_relay.dm_received") - - # Handle any attachments - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (discord.errors.Forbidden, discord.errors.NotFound): - e = discord.Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=f"{message.author.display_name} ({message.author.id})", - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - def cog_check(self, ctx: commands.Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - checks = [ - with_role_check(ctx, *constants.MODERATION_ROLES), - in_whitelist_check( - ctx, - channels=[constants.Channels.dm_log], - redirect=None, - fail_silently=True, - ) - ] - return all(checks) - - -def setup(bot: Bot) -> None: - """Load the DMRelay cog.""" - bot.add_cog(DMRelay(bot)) diff --git a/bot/exts/duck_pond.py b/bot/exts/duck_pond.py deleted file mode 100644 index 7021069fa..000000000 --- a/bot/exts/duck_pond.py +++ /dev/null @@ -1,166 +0,0 @@ -import logging -from typing import Union - -import discord -from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Cog - -from bot import constants -from bot.bot import Bot -from bot.utils.messages import send_attachments -from bot.utils.webhooks import send_webhook - -log = logging.getLogger(__name__) - - -class DuckPond(Cog): - """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.webhook_id = constants.Webhooks.duck_pond - self.webhook = None - self.bot.loop.create_task(self.fetch_webhook()) - - async def fetch_webhook(self) -> None: - """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_guild_available() - - try: - self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: - log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") - - @staticmethod - def is_staff(member: Union[User, Member]) -> bool: - """Check if a specific member or user is staff.""" - if hasattr(member, "roles"): - for role in member.roles: - if role.id in constants.STAFF_ROLES: - return True - return False - - async def has_green_checkmark(self, message: Message) -> bool: - """Check if the message has a green checkmark reaction.""" - for reaction in message.reactions: - if reaction.emoji == "✅": - async for user in reaction.users(): - if user == self.bot.user: - return True - return False - - async def count_ducks(self, message: Message) -> int: - """ - Count the number of ducks in the reactions of a specific message. - - Only counts ducks added by staff members. - """ - duck_count = 0 - duck_reactors = [] - - for reaction in message.reactions: - async for user in reaction.users(): - - # Is the user a staff member and not already counted as reactor? - if not self.is_staff(user) or user.id in duck_reactors: - continue - - # Is the emoji a duck? - if hasattr(reaction.emoji, "id"): - if reaction.emoji.id in constants.DuckPond.custom_emojis: - duck_count += 1 - duck_reactors.append(user.id) - elif isinstance(reaction.emoji, str): - if reaction.emoji == "🦆": - duck_count += 1 - duck_reactors.append(user.id) - return duck_count - - async def relay_message(self, message: Message) -> None: - """Relays the message's content and attachments to the duck pond channel.""" - if message.clean_content: - await send_webhook( - webhook=self.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await send_webhook( - webhook=self.webhook, - embed=e, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception("Failed to send an attachment to the webhook") - - await message.add_reaction("✅") - - @staticmethod - def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: - """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" - if payload.emoji.is_custom_emoji(): - if payload.emoji.id in constants.DuckPond.custom_emojis: - return True - elif payload.emoji.name == "🦆": - return True - - return False - - @Cog.listener() - async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: - """ - Determine if a message should be sent to the duck pond. - - This will count the number of duck reactions on the message, and if this amount meets the - amount of ducks specified in the config under duck_pond/threshold, it will - send the message off to the duck pond. - """ - # Is the emoji in the reaction a duck? - if not self._payload_has_duckpond_emoji(payload): - return - - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - message = await channel.fetch_message(payload.message_id) - member = discord.utils.get(message.guild.members, id=payload.user_id) - - # Is the member a human and a staff member? - if not self.is_staff(member) or member.bot: - return - - # Does the message already have a green checkmark? - if await self.has_green_checkmark(message): - return - - # Time to count our ducks! - duck_count = await self.count_ducks(message) - - # If we've got more than the required amount of ducks, send the message to the duck_pond. - if duck_count >= constants.DuckPond.threshold: - await self.relay_message(message) - - @Cog.listener() - async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: - """Ensure that people don't remove the green checkmark from duck ponded messages.""" - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - - # Prevent the green checkmark from being removed - if payload.emoji.name == "✅": - message = await channel.fetch_message(payload.message_id) - duck_count = await self.count_ducks(message) - if duck_count >= constants.DuckPond.threshold: - await message.add_reaction("✅") - - -def setup(bot: Bot) -> None: - """Load the DuckPond cog.""" - bot.add_cog(DuckPond(bot)) diff --git a/bot/exts/fun/__init__.py b/bot/exts/fun/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py new file mode 100644 index 000000000..7021069fa --- /dev/null +++ b/bot/exts/fun/duck_pond.py @@ -0,0 +1,166 @@ +import logging +from typing import Union + +import discord +from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DuckPond(Cog): + """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.duck_pond + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @staticmethod + def is_staff(member: Union[User, Member]) -> bool: + """Check if a specific member or user is staff.""" + if hasattr(member, "roles"): + for role in member.roles: + if role.id in constants.STAFF_ROLES: + return True + return False + + async def has_green_checkmark(self, message: Message) -> bool: + """Check if the message has a green checkmark reaction.""" + for reaction in message.reactions: + if reaction.emoji == "✅": + async for user in reaction.users(): + if user == self.bot.user: + return True + return False + + async def count_ducks(self, message: Message) -> int: + """ + Count the number of ducks in the reactions of a specific message. + + Only counts ducks added by staff members. + """ + duck_count = 0 + duck_reactors = [] + + for reaction in message.reactions: + async for user in reaction.users(): + + # Is the user a staff member and not already counted as reactor? + if not self.is_staff(user) or user.id in duck_reactors: + continue + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + return duck_count + + async def relay_message(self, message: Message) -> None: + """Relays the message's content and attachments to the duck pond channel.""" + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + await message.add_reaction("✅") + + @staticmethod + def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" + if payload.emoji.is_custom_emoji(): + if payload.emoji.id in constants.DuckPond.custom_emojis: + return True + elif payload.emoji.name == "🦆": + return True + + return False + + @Cog.listener() + async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: + """ + Determine if a message should be sent to the duck pond. + + This will count the number of duck reactions on the message, and if this amount meets the + amount of ducks specified in the config under duck_pond/threshold, it will + send the message off to the duck pond. + """ + # Is the emoji in the reaction a duck? + if not self._payload_has_duckpond_emoji(payload): + return + + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) + + # Is the member a human and a staff member? + if not self.is_staff(member) or member.bot: + return + + # Does the message already have a green checkmark? + if await self.has_green_checkmark(message): + return + + # Time to count our ducks! + duck_count = await self.count_ducks(message) + + # If we've got more than the required amount of ducks, send the message to the duck_pond. + if duck_count >= constants.DuckPond.threshold: + await self.relay_message(message) + + @Cog.listener() + async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: + """Ensure that people don't remove the green checkmark from duck ponded messages.""" + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + + # Prevent the green checkmark from being removed + if payload.emoji.name == "✅": + message = await channel.fetch_message(payload.message_id) + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.threshold: + await message.add_reaction("✅") + + +def setup(bot: Bot) -> None: + """Load the DuckPond cog.""" + bot.add_cog(DuckPond(bot)) diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py new file mode 100644 index 000000000..ce95450e0 --- /dev/null +++ b/bot/exts/fun/off_topic_names.py @@ -0,0 +1,162 @@ +import asyncio +import difflib +import logging +from datetime import datetime, timedelta + +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES +from bot.converters import OffTopicName +from bot.decorators import with_role +from bot.pagination import LinePaginator + +CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) +log = logging.getLogger(__name__) + + +async def update_names(bot: Bot) -> None: + """Background updater task that performs the daily channel name update.""" + while True: + # Since we truncate the compute timedelta to seconds, we add one second to ensure + # we go past midnight in the `seconds_to_sleep` set below. + today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) + next_midnight = today_at_midnight + timedelta(days=1) + seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 + await asyncio.sleep(seconds_to_sleep) + + try: + channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( + 'bot/off-topic-channel-names', params={'random_items': 3} + ) + except ResponseCodeError as e: + log.error(f"Failed to get new off topic channel names: code {e.response.status}") + continue + channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) + + await channel_0.edit(name=f'ot0-{channel_0_name}') + await channel_1.edit(name=f'ot1-{channel_1_name}') + await channel_2.edit(name=f'ot2-{channel_2_name}') + log.debug( + "Updated off-topic channel names to" + f" {channel_0_name}, {channel_1_name} and {channel_2_name}" + ) + + +class OffTopicNames(Cog): + """Commands related to managing the off-topic category channel names.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.updater_task = None + + self.bot.loop.create_task(self.init_offtopic_updater()) + + def cog_unload(self) -> None: + """Cancel any running updater tasks on cog unload.""" + if self.updater_task is not None: + self.updater_task.cancel() + + async def init_offtopic_updater(self) -> None: + """Start off-topic channel updating event loop if it hasn't already started.""" + await self.bot.wait_until_guild_available() + if self.updater_task is None: + coro = update_names(self.bot) + self.updater_task = self.bot.loop.create_task(coro) + + @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) + @with_role(*MODERATION_ROLES) + async def otname_group(self, ctx: Context) -> None: + """Add or list items from the off-topic channel name rotation.""" + await ctx.send_help(ctx.command) + + @otname_group.command(name='add', aliases=('a',)) + @with_role(*MODERATION_ROLES) + async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """ + Adds a new off-topic name to the rotation. + + The name is not added if it is too similar to an existing name. + """ + existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') + close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) + + if close_match: + match = close_match[0] + log.info( + f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" + ) + await ctx.send( + f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " + "Use `!otn forceadd` to override this check." + ) + else: + await self._add_name(ctx, name) + + @otname_group.command(name='forceadd', aliases=('fa',)) + @with_role(*MODERATION_ROLES) + async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Forcefully adds a new off-topic name to the rotation.""" + await self._add_name(ctx, name) + + async def _add_name(self, ctx: Context, name: str) -> None: + """Adds an off-topic channel name to the site storage.""" + await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) + + log.info(f"{ctx.author} added the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Added `{name}` to the names list.") + + @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) + @with_role(*MODERATION_ROLES) + async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: + """Removes a off-topic name from the rotation.""" + await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') + + log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") + await ctx.send(f":ok_hand: Removed `{name}` from the names list.") + + @otname_group.command(name='list', aliases=('l',)) + @with_role(*MODERATION_ROLES) + async def list_command(self, ctx: Context) -> None: + """ + Lists all currently known off-topic channel names in a paginator. + + Restricted to Moderator and above to not spoil the surprise. + """ + result = await self.bot.api_client.get('bot/off-topic-channel-names') + lines = sorted(f"• {name}" for name in result) + embed = Embed( + title=f"Known off-topic names (`{len(result)}` total)", + colour=Colour.blue() + ) + if result: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + + @otname_group.command(name='search', aliases=('s',)) + @with_role(*MODERATION_ROLES) + async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: + """Search for an off-topic name.""" + result = await self.bot.api_client.get('bot/off-topic-channel-names') + in_matches = {name for name in result if query in name} + close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) + lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) + embed = Embed( + title="Query results", + colour=Colour.blue() + ) + + if lines: + await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) + else: + embed.description = "Nothing found." + await ctx.send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the OffTopicNames cog.""" + bot.add_cog(OffTopicNames(bot)) diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py new file mode 100644 index 000000000..0d8f340b4 --- /dev/null +++ b/bot/exts/moderation/dm_relay.py @@ -0,0 +1,124 @@ +import logging +from typing import Optional + +import discord +from discord import Color +from discord.ext import commands +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.converters import UserMentionOrID +from bot.utils import RedisCache +from bot.utils.checks import in_whitelist_check, with_role_check +from bot.utils.messages import send_attachments +from bot.utils.webhooks import send_webhook + +log = logging.getLogger(__name__) + + +class DMRelay(Cog): + """Relay direct messages to and from the bot.""" + + # RedisCache[str, t.Union[discord.User.id, discord.Member.id]] + dm_cache = RedisCache() + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.dm_log + self.webhook = None + self.bot.loop.create_task(self.fetch_webhook()) + + @commands.command(aliases=("reply",)) + async def send_dm(self, ctx: commands.Context, member: Optional[UserMentionOrID], *, message: str) -> None: + """ + Allows you to send a DM to a user from the bot. + + If `member` is not provided, it will send to the last user who DM'd the bot. + + This feature should be used extremely sparingly. Use ModMail if you need to have a serious + conversation with a user. This is just for responding to extraordinary DMs, having a little + fun with users, and telling people they are DMing the wrong bot. + + NOTE: This feature will be removed if it is overused. + """ + if not member: + user_id = await self.dm_cache.get("last_user") + member = ctx.guild.get_member(user_id) if user_id else None + + # If we still don't have a Member at this point, give up + if not member: + log.debug("This bot has never gotten a DM, or the RedisCache has been cleared.") + await ctx.message.add_reaction("❌") + return + + try: + await member.send(message) + except discord.errors.Forbidden: + log.debug("User has disabled DMs.") + await ctx.message.add_reaction("❌") + else: + await ctx.message.add_reaction("✅") + self.bot.stats.incr("dm_relay.dm_sent") + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_guild_available() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Relays the message's content and attachments to the dm_log channel.""" + # Only relay DMs from humans + if message.author.bot or message.guild or self.webhook is None: + return + + if message.clean_content: + await send_webhook( + webhook=self.webhook, + content=message.clean_content, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + await self.dm_cache.set("last_user", message.author.id) + self.bot.stats.incr("dm_relay.dm_received") + + # Handle any attachments + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (discord.errors.Forbidden, discord.errors.NotFound): + e = discord.Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await send_webhook( + webhook=self.webhook, + embed=e, + username=f"{message.author.display_name} ({message.author.id})", + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception("Failed to send an attachment to the webhook") + + def cog_check(self, ctx: commands.Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_whitelist_check( + ctx, + channels=[constants.Channels.dm_log], + redirect=None, + fail_silently=True, + ) + ] + return all(checks) + + +def setup(bot: Bot) -> None: + """Load the DMRelay cog.""" + bot.add_cog(DMRelay(bot)) diff --git a/bot/exts/off_topic_names.py b/bot/exts/off_topic_names.py deleted file mode 100644 index ce95450e0..000000000 --- a/bot/exts/off_topic_names.py +++ /dev/null @@ -1,162 +0,0 @@ -import asyncio -import difflib -import logging -from datetime import datetime, timedelta - -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group - -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES -from bot.converters import OffTopicName -from bot.decorators import with_role -from bot.pagination import LinePaginator - -CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) -log = logging.getLogger(__name__) - - -async def update_names(bot: Bot) -> None: - """Background updater task that performs the daily channel name update.""" - while True: - # Since we truncate the compute timedelta to seconds, we add one second to ensure - # we go past midnight in the `seconds_to_sleep` set below. - today_at_midnight = datetime.utcnow().replace(microsecond=0, second=0, minute=0, hour=0) - next_midnight = today_at_midnight + timedelta(days=1) - seconds_to_sleep = (next_midnight - datetime.utcnow()).seconds + 1 - await asyncio.sleep(seconds_to_sleep) - - try: - channel_0_name, channel_1_name, channel_2_name = await bot.api_client.get( - 'bot/off-topic-channel-names', params={'random_items': 3} - ) - except ResponseCodeError as e: - log.error(f"Failed to get new off topic channel names: code {e.response.status}") - continue - channel_0, channel_1, channel_2 = (bot.get_channel(channel_id) for channel_id in CHANNELS) - - await channel_0.edit(name=f'ot0-{channel_0_name}') - await channel_1.edit(name=f'ot1-{channel_1_name}') - await channel_2.edit(name=f'ot2-{channel_2_name}') - log.debug( - "Updated off-topic channel names to" - f" {channel_0_name}, {channel_1_name} and {channel_2_name}" - ) - - -class OffTopicNames(Cog): - """Commands related to managing the off-topic category channel names.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.updater_task = None - - self.bot.loop.create_task(self.init_offtopic_updater()) - - def cog_unload(self) -> None: - """Cancel any running updater tasks on cog unload.""" - if self.updater_task is not None: - self.updater_task.cancel() - - async def init_offtopic_updater(self) -> None: - """Start off-topic channel updating event loop if it hasn't already started.""" - await self.bot.wait_until_guild_available() - if self.updater_task is None: - coro = update_names(self.bot) - self.updater_task = self.bot.loop.create_task(coro) - - @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) - @with_role(*MODERATION_ROLES) - async def otname_group(self, ctx: Context) -> None: - """Add or list items from the off-topic channel name rotation.""" - await ctx.send_help(ctx.command) - - @otname_group.command(name='add', aliases=('a',)) - @with_role(*MODERATION_ROLES) - async def add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """ - Adds a new off-topic name to the rotation. - - The name is not added if it is too similar to an existing name. - """ - existing_names = await self.bot.api_client.get('bot/off-topic-channel-names') - close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) - - if close_match: - match = close_match[0] - log.info( - f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'" - ) - await ctx.send( - f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. " - "Use `!otn forceadd` to override this check." - ) - else: - await self._add_name(ctx, name) - - @otname_group.command(name='forceadd', aliases=('fa',)) - @with_role(*MODERATION_ROLES) - async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Forcefully adds a new off-topic name to the rotation.""" - await self._add_name(ctx, name) - - async def _add_name(self, ctx: Context, name: str) -> None: - """Adds an off-topic channel name to the site storage.""" - await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name}) - - log.info(f"{ctx.author} added the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Added `{name}` to the names list.") - - @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) - @with_role(*MODERATION_ROLES) - async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None: - """Removes a off-topic name from the rotation.""" - await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') - - log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") - await ctx.send(f":ok_hand: Removed `{name}` from the names list.") - - @otname_group.command(name='list', aliases=('l',)) - @with_role(*MODERATION_ROLES) - async def list_command(self, ctx: Context) -> None: - """ - Lists all currently known off-topic channel names in a paginator. - - Restricted to Moderator and above to not spoil the surprise. - """ - result = await self.bot.api_client.get('bot/off-topic-channel-names') - lines = sorted(f"• {name}" for name in result) - embed = Embed( - title=f"Known off-topic names (`{len(result)}` total)", - colour=Colour.blue() - ) - if result: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - - @otname_group.command(name='search', aliases=('s',)) - @with_role(*MODERATION_ROLES) - async def search_command(self, ctx: Context, *, query: OffTopicName) -> None: - """Search for an off-topic name.""" - result = await self.bot.api_client.get('bot/off-topic-channel-names') - in_matches = {name for name in result if query in name} - close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70) - lines = sorted(f"• {name}" for name in in_matches.union(close_matches)) - embed = Embed( - title="Query results", - colour=Colour.blue() - ) - - if lines: - await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) - else: - embed.description = "Nothing found." - await ctx.send(embed=embed) - - -def setup(bot: Bot) -> None: - """Load the OffTopicNames cog.""" - bot.add_cog(OffTopicNames(bot)) diff --git a/tests/bot/exts/fun/__init__.py b/tests/bot/exts/fun/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/fun/test_duck_pond.py b/tests/bot/exts/fun/test_duck_pond.py new file mode 100644 index 000000000..704b08066 --- /dev/null +++ b/tests/bot/exts/fun/test_duck_pond.py @@ -0,0 +1,548 @@ +import asyncio +import logging +import typing +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import discord + +from bot import constants +from bot.exts.fun import duck_pond +from tests import base +from tests import helpers + +MODULE_PATH = "bot.exts.fun.duck_pond" + + +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) + + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) + + def setUp(self): + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) + self.cog = duck_pond.DuckPond(bot=self.bot) + + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger(MODULE_PATH) + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) + ) + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) + ) + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), + 3 + 4 + ), + ) + + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + author = MagicMock( + display_name="x", + avatar_url="https://" + ) + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), + (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), + ) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger(MODULE_PATH) + + for side_effect in side_effects: # pragma: no cover + send_attachments.side_effect = side_effect + with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger(MODULE_PATH) + + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message(message) + + send_webhook.assert_called_once_with( + webhook=self.cog.webhook, + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), + ) + + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + relay_message.assert_called_once_with(message) + + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) + + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + for duck_count, should_re_add_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) + + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) + + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) + + if should_re_add_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" + ) + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) + + channel.fetch_message.assert_not_called() + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = helpers.MockBot() + duck_pond.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/test_duck_pond.py b/tests/bot/exts/test_duck_pond.py deleted file mode 100644 index f6d977482..000000000 --- a/tests/bot/exts/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.exts import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.exts.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger('bot.exts.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.exts.duck_pond") - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.exts.duck_pond") - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() -- cgit v1.2.3 From 520ac0f9871bf6775d76eea753ed2a940704e92d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 22 Aug 2020 20:44:48 -0700 Subject: Include root aliases in the command name conflict test --- tests/bot/cogs/test_cogs.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index fdda59a8f..30a04422a 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -53,6 +53,7 @@ class CommandNameTests(unittest.TestCase): """Return a list of all qualified names, including aliases, for the `command`.""" names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] names.append(command.qualified_name) + names += getattr(command, "root_aliases", []) return names -- cgit v1.2.3 From b7644aa822def549e2591b53c69af3cf44355ac9 Mon Sep 17 00:00:00 2001 From: Xithrius Date: Mon, 31 Aug 2020 19:56:24 -0700 Subject: Removed ImagePaginator testing. --- tests/bot/test_pagination.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index ce880d457..630f2516d 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -44,18 +44,3 @@ class LinePaginatorTests(TestCase): self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) # Note: item at index 1 is the truncated line, index 0 is prefix self.assertEqual(self.paginator._current_page[1], 'x' * self.paginator.scale_to_size) - - -class ImagePaginatorTests(TestCase): - """Tests functionality of the `ImagePaginator`.""" - - def setUp(self): - """Create a paginator for the test method.""" - self.paginator = pagination.ImagePaginator() - - def test_add_image_appends_image(self): - """`add_image` appends the image to the image list.""" - image = 'lemon' - self.paginator.add_image(image) - - assert self.paginator.images == [image] -- cgit v1.2.3 From 1a47f5d80f2f91c3da5a9626e9a6694381d49cd0 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Tue, 1 Sep 2020 12:22:43 +0100 Subject: Fixed old tests and added 2 new ones --- tests/bot/cogs/test_antimalware.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index ecb7abf00..f50c0492d 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -23,6 +23,8 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.message.webhook_id = None + self.message.author.bot = None self.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): @@ -48,6 +50,26 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.delete.assert_not_called() + async def test_webhook_message_with_illegal_extension(self): + """A webhook message containing an illegal extension should be ignored.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.webhook_id = 697140105563078727 + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_bot_message_with_illegal_extension(self): + """A bot message containing an illegal extension should be ignored.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.author.bot = 409107086526644234 + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + async def test_message_with_illegal_extension_gets_deleted(self): """A message containing an illegal extension should send an embed.""" attachment = MockAttachment(filename="python.disallowed") -- cgit v1.2.3