From 3f04611ddfc2e6d750d4c4e0a19d3cf154e7c5a9 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Thu, 20 May 2021 16:09:39 -0400 Subject: chore: Update tests to correspond with the timeit command --- tests/bot/exts/utils/test_snekbox.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 321a92445..1b3d61094 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -161,7 +161,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') + self.cog.send_eval.assert_called_once_with( + ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ) self.cog.continue_eval.assert_called_once_with(ctx, response) async def test_eval_command_evaluate_twice(self): @@ -171,11 +173,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + self.cog.continue_eval.side_effect = ('MyAwesomeFormattedCode', None) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') + self.cog.send_eval.assert_called_with( + ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ) self.cog.continue_eval.assert_called_with(ctx, response) async def test_eval_command_reject_two_eval_at_the_same_time(self): @@ -190,12 +194,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_eval_command_call_help(self): - """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext(command="sentinel") - await self.cog.eval_command(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with(ctx.command) - async def test_send_eval(self): """Test the send_eval function.""" ctx = MockContext() @@ -212,11 +210,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.format_output.assert_called_once_with('') @@ -237,12 +235,12 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.' '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.format_output.assert_called_once_with('Way too long beard') @@ -262,11 +260,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.format_output.assert_not_called() @@ -282,7 +280,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) actual = await self.cog.continue_eval(ctx, response) - self.cog.get_code.assert_awaited_once_with(new_msg) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) self.assertEqual(actual, expected) self.bot.wait_for.assert_has_awaits( ( @@ -327,7 +325,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.return_value = MockContext(command=command) message = MockMessage(content=content) - actual_code = await self.cog.get_code(message) + actual_code = await self.cog.get_code(message, self.cog.eval_command) self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) -- cgit v1.2.3 From 1e0c0cfe37eb9a868454508fcb813d7cf19e12cc Mon Sep 17 00:00:00 2001 From: Izan Date: Wed, 29 Dec 2021 15:09:22 +0000 Subject: Fix tests --- tests/bot/exts/moderation/test_incidents.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..ef33aa62b 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -1,4 +1,5 @@ import asyncio +import datetime import enum import logging import typing as t @@ -13,6 +14,7 @@ from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from bot.utils.time import TimestampFormats, discord_timestamp from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -114,10 +116,19 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_content(self): """Incident content appears as embed description.""" - incident = MockMessage(content="this is an incident") + current_time = datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) + incident = MockMessage(content="this is an incident", created_at=current_time) + + day_timestamp = discord_timestamp(current_time, TimestampFormats.DATE) + time_timestamp = discord_timestamp(current_time, TimestampFormats.TIME) + relative_timestamp = discord_timestamp(current_time, TimestampFormats.RELATIVE) + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - self.assertEqual(incident.content, embed.description) + self.assertEqual( + f"{incident.content}\n\n__*Reported {day_timestamp} at {time_timestamp} ({relative_timestamp}).*__", + embed.description + ) async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" @@ -391,7 +402,7 @@ class TestArchive(TestIncidents): # Define our own `incident` to be archived incident = MockMessage( content="this is an incident", - author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), + author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this @@ -422,7 +433,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great")) await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -521,12 +532,13 @@ class TestProcessEvent(TestIncidents): async def test_process_event_confirmation_task_is_awaited(self): """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" mock_task = AsyncMock() + mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)]) with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), - member=MockMember(roles=[MockRole(id=1)]) + incident=MockMessage(author=mock_member, id=123), + member=mock_member ) mock_task.assert_awaited() -- cgit v1.2.3 From b7e49a5fb1adb541db2cf5632a460a37ddda6d0a Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Thu, 13 Jan 2022 21:26:58 -0500 Subject: chore: Suppress output in the setup code, not the code that gets timed. If multiple formatted codeblocks are passed to the command, the first one will be used as the setup code that does not get timed. --- bot/exts/utils/snekbox.py | 70 +++++++++++++++++++++++++++--------- tests/bot/exts/utils/test_snekbox.py | 6 ++-- 2 files changed, 56 insertions(+), 20 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index bd521a4ee..0d8da5e56 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -2,9 +2,9 @@ import asyncio import contextlib import datetime import re -import textwrap from functools import partial from signal import Signals +from textwrap import dedent from typing import Awaitable, Callable, Optional, Tuple from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User @@ -36,13 +36,35 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) -TIMEIT_EVAL_WRAPPER = """ -from contextlib import redirect_stdout -from io import StringIO +TIMEIT_SETUP_WRAPPER = """ +import atexit +import sys +from collections import deque -with redirect_stdout(StringIO()): - del redirect_stdout, StringIO -{code} +if not hasattr(sys, "_setup_finished"): + class Writer(deque): + def __init__(self): + super().__init__(maxlen=1) + + def write(self, string): + if string.strip(): + self.append(string) + + def read(self): + return self.pop() + + def flush(self): + pass + + sys.stdout = Writer() + + def print_last_line(): + if sys.stdout: + print(sys.stdout.read(), file=sys.__stdout__) + + atexit.register(print_last_line) + sys._setup_finished = None +{setup} """ TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop") @@ -90,34 +112,37 @@ class Snekbox(Cog): return await send_to_paste_service(output, extension="txt") @staticmethod - def prepare_input(code: str) -> str: + def prepare_input(code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. If there is any code block, ignore text outside the code block. Use the first code block, but prefer a fenced code block. If there are several fenced code blocks, concatenate only the fenced code blocks. + + Retrun a list of code blocks if any, otherwise return a list with a single string of code. """ if match := list(FORMATTED_CODE_REGEX.finditer(code)): blocks = [block for block in match if block.group("block")] if len(blocks) > 1: - code = '\n'.join(block.group("code") for block in blocks) + codeblocks = [block.group("code") for block in blocks] info = "several code blocks" else: match = match[0] if len(blocks) == 0 else blocks[0] code, block, lang, delim = match.group("code", "block", "lang", "delim") + codeblocks = [dedent(code)] if block: info = (f"'{lang}' highlighted" if lang else "plain") + " code block" else: info = f"{delim}-enclosed inline code" else: - code = RAW_CODE_REGEX.fullmatch(code).group("code") + codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] info = "unformatted or badly formatted code" - code = textwrap.dedent(code) + code = "\n".join(codeblocks) log.trace(f"Extracted {info} for evaluation:\n{code}") - return code + return codeblocks @staticmethod def get_results_message(results: dict) -> Tuple[str, str]: @@ -248,7 +273,7 @@ class Snekbox(Cog): log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") return response - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]: """ Check if the eval session should continue. @@ -380,7 +405,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = self.prepare_input(code) + code = "\n".join(self.prepare_input(code)) await self.run_eval(ctx, code, format_func=self.format_output) @command(name="timeit", aliases=("ti",)) @@ -400,13 +425,24 @@ class Snekbox(Cog): block. Code can be re-evaluated by editing the original message within 10 seconds and clicking the reaction that subsequently appears. + If multiple formatted codeblocks are provided, the first one will be the setup code, which will + not be timed. The remaining codeblocks will be joined together and timed. + We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = self.prepare_input(code) + args = ["-m", "timeit"] + setup = "" + codeblocks = self.prepare_input(code) + + if len(codeblocks) > 1: + setup = codeblocks.pop(0) + + code = "\n".join(codeblocks) + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + await self.run_eval( - ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")), - format_func=self.format_timeit_output, args=["-m", "timeit"] + ctx, code=code, format_func=self.format_timeit_output, args=args ) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index cbffaa6b0..ebab71e71 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -61,7 +61,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) + self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -156,7 +156,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock(return_value=None) @@ -297,7 +297,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_eval(ctx, response) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, expected) + self.assertEqual(actual, [expected]) self.bot.wait_for.assert_has_awaits( ( call( -- cgit v1.2.3 From 8594a3535413e662ae519f31989459a953e8a726 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Mon, 17 Jan 2022 14:43:58 -0500 Subject: fix: Modify tests to correspond with Snekbox.continue_eval --- tests/bot/exts/utils/test_snekbox.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index ebab71e71..4245de8a3 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -156,32 +156,35 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() + ctx.command = MagicMock() + self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=None) + self.cog.continue_eval = AsyncMock(return_value=(None, None)) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') self.cog.send_eval.assert_called_once_with( ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output ) - self.cog.continue_eval.assert_called_once_with(ctx, response) + self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() + ctx.command = MagicMock() self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeFormattedCode', None) + self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) self.cog.send_eval.assert_called_with( ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output ) - self.cog.continue_eval.assert_called_with(ctx, response) + self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -295,9 +298,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response) + actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, [expected]) + self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( @@ -316,8 +319,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage()) - self.assertEqual(actual, None) + actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command) + self.assertEqual(actual, (None, None)) ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) async def test_get_code(self): -- cgit v1.2.3 From 54e4f3777372ef526667885f4392030bab1b5b07 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Mon, 17 Jan 2022 17:42:37 -0500 Subject: chore: Apply suggestions and adjust tests --- bot/exts/utils/snekbox.py | 52 +++++++++++++----------------------- tests/bot/exts/utils/test_snekbox.py | 24 ++++++++--------- 2 files changed, 29 insertions(+), 47 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 49f1be17b..1d9646113 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -36,6 +36,8 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) +# The timeit command should only output the very last line, so all other output should be suppressed. +# This will be used as the setup code along with any setup code provided. TIMEIT_SETUP_WRAPPER = """ import atexit import sys @@ -163,19 +165,19 @@ class Snekbox(Cog): return code, args @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: + def get_results_message(results: dict, job_name: str) -> Tuple[str, str]: """Return a user-friendly message and error corresponding to the process's return code.""" stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" + msg = f"Your {job_name} job has completed with return code {returncode}" error = "" if returncode is None: - msg = "Your eval job has failed" + msg = f"Your {job_name} job has failed" error = stdout.strip() elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" + msg = f"Your {job_name} job timed out or ran out of memory" elif returncode == 255: - msg = "Your eval job has failed" + msg = f"Your {job_name} job has failed" error = "A fatal NsJail error occurred" else: # Try to append signal's name if one exists @@ -249,7 +251,7 @@ class Snekbox(Cog): code: str, *, args: Optional[list[str]] = None, - format_func: FormatFunc + job_name: str ) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. @@ -258,13 +260,13 @@ class Snekbox(Cog): """ async with ctx.typing(): results = await self.post_eval(code, args=args) - msg, error = self.get_results_message(results) + msg, error = self.get_results_message(results, job_name) if error: output, paste_link = error, None else: log.trace("Formatting output...") - output, paste_link = await format_func(results["stdout"]) + output, paste_link = await self.format_output(results["stdout"]) icon = self.get_status_emoji(results) msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" @@ -288,12 +290,12 @@ class Snekbox(Cog): response = await ctx.send(msg, allowed_mentions=allowed_mentions) scheduling.create_task(wait_for_deletion(response, (ctx.author.id,)), event_loop=self.bot.loop) - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") return response async def continue_eval( self, ctx: Context, response: Message, command: Command - ) -> Optional[tuple[str, Optional[list[str]]]]: + ) -> tuple[Optional[str], Optional[list[str]]]: """ Check if the eval session should continue. @@ -355,19 +357,15 @@ class Snekbox(Cog): return code - async def run_eval( + async def run_job( self, + job_name: str, ctx: Context, code: str, - format_func: FormatFunc, *, args: Optional[list[str]] = None, ) -> None: - """ - Handles checks, stats and re-evaluation of an eval. - - `format_func` is an async callable that takes a string (the output) and formats it to show to the user. - """ + """Handles checks, stats and re-evaluation of a snekbox job.""" if ctx.author.id in self.jobs: await ctx.send( f"{ctx.author.mention} You've already got a job running - " @@ -392,7 +390,7 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() try: - response = await self.send_eval(ctx, code, args=args, format_func=format_func) + response = await self.send_eval(ctx, code, args=args, job_name=job_name) finally: del self.jobs[ctx.author.id] @@ -401,18 +399,6 @@ class Snekbox(Cog): break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") - async def format_timeit_output(self, output: str) -> tuple[str, str]: - """ - Parses the time from the end of the output given by timeit. - - If an error happened, then it won't contain the time and instead proceed with regular formatting. - """ - split_output = output.rstrip("\n").rsplit("\n", 1) - if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]): - return split_output[1], None - - return await self.format_output(output) - @command(name="eval", aliases=("e",)) @guild_only() @redirect_output( @@ -434,7 +420,7 @@ class Snekbox(Cog): issue with it! """ code = "\n".join(self.prepare_input(code)) - await self.run_eval(ctx, code, format_func=self.format_output) + await self.run_job("eval", ctx, code) @command(name="timeit", aliases=("ti",)) @guild_only() @@ -462,9 +448,7 @@ class Snekbox(Cog): codeblocks = self.prepare_input(code) code, args = self.prepare_timeit_input(codeblocks) - await self.run_eval( - ctx, code=code, format_func=self.format_timeit_output, args=args - ) + await self.run_job("timeit", ctx, code=code, args=args) def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 4245de8a3..339cdaaa4 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -72,13 +72,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), ('Your eval job has completed with return code 127', '') ) @@ -86,7 +86,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), ('Your eval job has completed with return code 127 (SIGTEST)', '') ) @@ -164,9 +164,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with( - ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output - ) + self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval') self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): @@ -182,7 +180,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) self.cog.send_eval.assert_called_with( - ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' ) self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) @@ -214,7 +212,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -227,7 +225,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') async def test_send_eval_with_paste_link(self): @@ -246,7 +244,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -257,7 +255,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_eval_with_non_zero_eval(self): @@ -275,7 +273,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -285,7 +283,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") -- cgit v1.2.3 From d0cf7f2f1573e883cf1f6aaf8c54b3701496722b Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Wed, 26 Jan 2022 16:12:15 -0500 Subject: chore: Remove the naming of 'eval' in certain places Since the !eval command is no longer the only snekbox command, make the naming more generic. --- bot/exts/filters/filtering.py | 4 +- bot/exts/utils/snekbox.py | 72 ++++++++++++++++++------------------ tests/bot/exts/utils/test_snekbox.py | 6 +-- 3 files changed, 41 insertions(+), 41 deletions(-) (limited to 'tests') diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index ad904d147..e49cf4f82 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -267,9 +267,9 @@ class Filtering(Cog): # Update time when alert sent await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - async def filter_eval(self, result: str, msg: Message) -> bool: + async def filter_snekbox_job(self, result: str, msg: Message) -> bool: """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. Also requires the original message, to check whether to filter and for mod logs. Returns whether a filter was triggered or not. diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index e7712eee5..86993b7f1 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -71,15 +71,15 @@ if not hasattr(sys, "_setup_finished"): MAX_PASTE_LEN = 10000 -# `!eval` command whitelists and blacklists. -NO_EVAL_CHANNELS = (Channels.python_general,) -NO_EVAL_CATEGORIES = () -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) +# The Snekbox commands' whitelists and blacklists. +NO_SNEKBOX_CHANNELS = (Channels.python_general,) +NO_SNEKBOX_CATEGORIES = () +SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) SIGKILL = 9 -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 +REDO_EMOJI = '\U0001f501' # :repeat: +REDO_TIMEOUT = 30 class Snekbox(Cog): @@ -89,7 +89,7 @@ class Snekbox(Cog): self.bot = bot self.jobs = {} - async def post_eval(self, code: str, *, args: Optional[list[str]] = None) -> dict: + async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: """Send a POST request to the Snekbox API to evaluate code and return the results.""" url = URLs.snekbox_eval_api data = {"input": code} @@ -101,7 +101,7 @@ class Snekbox(Cog): return await resp.json() async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" + """Upload the job's output to a paste service and return a URL to it if successful.""" log.trace("Uploading full output to paste service...") if len(output) > MAX_PASTE_LEN: @@ -241,7 +241,7 @@ class Snekbox(Cog): return output, paste_link - async def send_eval( + async def send_job( self, ctx: Context, code: str, @@ -255,7 +255,7 @@ class Snekbox(Cog): Return the bot response. """ async with ctx.typing(): - results = await self.post_eval(code, args=args) + results = await self.post_job(code, args=args) msg, error = self.get_results_message(results, job_name) if error: @@ -269,7 +269,7 @@ class Snekbox(Cog): if paste_link: msg = f"{msg}\nFull output: {paste_link}" - # Collect stats of eval fails + successes + # Collect stats of job fails + successes if icon == ":x:": self.bot.stats.incr("snekbox.python.fail") else: @@ -278,7 +278,7 @@ class Snekbox(Cog): filter_cog = self.bot.get_cog("Filtering") filter_triggered = False if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + filter_triggered = await filter_cog.filter_snekbox_job(msg, ctx.message) if filter_triggered: response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: @@ -289,26 +289,26 @@ class Snekbox(Cog): log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") return response - async def continue_eval( + async def continue_job( self, ctx: Context, response: Message, command: Command ) -> tuple[Optional[str], Optional[list[str]]]: """ - Check if the eval session should continue. + Check if the job's session should continue. If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command. - Otherwise return (None, None) if the eval session should be terminated. + Otherwise return (None, None) if the job's session should be terminated. """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + _predicate_message_edit = partial(predicate_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) with contextlib.suppress(NotFound): try: _, new_message = await self.bot.wait_for( 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT + check=_predicate_message_edit, + timeout=REDO_TIMEOUT ) - await ctx.message.add_reaction(REEVAL_EMOJI) + await ctx.message.add_reaction(REDO_EMOJI) await self.bot.wait_for( 'reaction_add', check=_predicate_emoji_reaction, @@ -316,7 +316,7 @@ class Snekbox(Cog): ) code = await self.get_code(new_message, ctx.command) - await ctx.message.clear_reaction(REEVAL_EMOJI) + await ctx.message.clear_reaction(REDO_EMOJI) with contextlib.suppress(HTTPException): await response.delete() @@ -324,7 +324,7 @@ class Snekbox(Cog): return None, None except asyncio.TimeoutError: - await ctx.message.clear_reaction(REEVAL_EMOJI) + await ctx.message.clear_reaction(REDO_EMOJI) return None, None codeblocks = self.prepare_input(code) @@ -347,11 +347,11 @@ class Snekbox(Cog): new_ctx = await self.bot.get_context(message) if new_ctx.command is command: - log.trace(f"Message {message.id} invokes eval command.") + log.trace(f"Message {message.id} invokes {command} command.") split = message.content.split(maxsplit=1) code = split[1] if len(split) > 1 else None else: - log.trace(f"Message {message.id} does not invoke eval command.") + log.trace(f"Message {message.id} does not invoke {command} command.") code = message.content return code @@ -389,11 +389,11 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() try: - response = await self.send_eval(ctx, code, args=args, job_name=job_name) + response = await self.send_job(ctx, code, args=args, job_name=job_name) finally: del self.jobs[ctx.author.id] - code, args = await self.continue_eval(ctx, response, ctx.command) + code, args = await self.continue_job(ctx, response, ctx.command) if not code: break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") @@ -402,9 +402,9 @@ class Snekbox(Cog): @guild_only() @redirect_output( destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, ping_user=False ) async def eval_command(self, ctx: Context, *, code: str) -> None: @@ -425,9 +425,9 @@ class Snekbox(Cog): @guild_only() @redirect_output( destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, ping_user=False ) async def timeit_command(self, ctx: Context, *, code: str) -> str: @@ -450,14 +450,14 @@ class Snekbox(Cog): await self.run_job("timeit", ctx, code=code, args=args) -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: +def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: """Return True if the edited message is the context message and the content was indeed modified.""" return new_msg.id == ctx.message.id and old_msg.content != new_msg.content -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI +def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 339cdaaa4..5d213a883 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -209,7 +209,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') @@ -241,7 +241,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') @@ -270,7 +270,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') -- cgit v1.2.3 From 85a6f430aa68f59ce6958ecb6450eca0736628e4 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Thu, 27 Jan 2022 10:31:12 -0500 Subject: chore: Switch Snekbox.prepare_input with a CodeblockConverter As per @Numerlor's suggestion --- bot/exts/filters/filtering.py | 2 +- bot/exts/utils/snekbox.py | 74 ++++++++++++------------ tests/bot/exts/utils/test_snekbox.py | 105 +++++++++++++++++------------------ 3 files changed, 91 insertions(+), 90 deletions(-) (limited to 'tests') diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 599302576..375e9dca8 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -267,7 +267,7 @@ class Filtering(Cog): # Update time when alert sent await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - async def filter_snekbox_job(self, result: str, msg: Message) -> bool: + async def filter_snekbox_output(self, result: str, msg: Message) -> bool: """ Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 15599208f..41f6bf8ad 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -9,7 +9,7 @@ from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Command, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs @@ -68,35 +68,11 @@ REDO_EMOJI = '\U0001f501' # :repeat: REDO_TIMEOUT = 30 -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - - if args is not None: - data["args"] = args +class CodeblockConverter(Converter): + """Attempts to extract code from a codeblock, if provided.""" - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the job's output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - return await send_to_paste_service(output, extension="txt") - - @staticmethod - def prepare_input(code: str) -> list[str]: + @classmethod + async def convert(cls, ctx: Context, code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. @@ -128,6 +104,34 @@ class Snekbox(Cog): log.trace(f"Extracted {info} for evaluation:\n{code}") return codeblocks + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + + if args is not None: + data["args"] = args + + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the job's output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + return await send_to_paste_service(output, extension="txt") + @staticmethod def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: """ @@ -313,7 +317,7 @@ class Snekbox(Cog): await ctx.message.clear_reaction(REDO_EMOJI) return None, None - codeblocks = self.prepare_input(code) + codeblocks = await CodeblockConverter.convert(ctx, code) if command is self.timeit_command: return self.prepare_timeit_input(codeblocks) @@ -393,7 +397,7 @@ class Snekbox(Cog): channels=NO_SNEKBOX_CHANNELS, ping_user=False ) - async def eval_command(self, ctx: Context, *, code: str) -> None: + async def eval_command(self, ctx: Context, *, code: CodeblockConverter) -> None: """ Run Python code and get the results. @@ -404,8 +408,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = "\n".join(self.prepare_input(code)) - await self.run_job("eval", ctx, code) + await self.run_job("eval", ctx, "\n".join(code)) @command(name="timeit", aliases=("ti",)) @guild_only() @@ -416,7 +419,7 @@ class Snekbox(Cog): channels=NO_SNEKBOX_CHANNELS, ping_user=False ) - async def timeit_command(self, ctx: Context, *, code: str) -> str: + async def timeit_command(self, ctx: Context, *, code: CodeblockConverter) -> str: """ Profile Python Code to find execution time. @@ -430,8 +433,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - codeblocks = self.prepare_input(code) - code, args = self.prepare_timeit_input(codeblocks) + code, args = self.prepare_timeit_input(code) await self.run_job("timeit", ctx, code=code, args=args) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 5d213a883..75da0c860 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -17,7 +17,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = Snekbox(bot=self.bot) - async def test_post_eval(self): + async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() resp.json = AsyncMock(return_value="return") @@ -26,7 +26,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_eval("import random"), "return") + self.assertEqual(await self.cog.post_job("import random"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -45,7 +45,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.upload_output("Test output.") mock_paste_util.assert_called_once_with("Test output.", extension="txt") - def test_prepare_input(self): + async def test_codeblock_converter(self): + ctx = MockContext() cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), @@ -61,7 +62,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) + self.assertEqual( + '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + ) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -158,31 +161,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): response = MockMessage() ctx.command = MagicMock() - self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=(None, None)) + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval') - self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) + await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() ctx.command = MagicMock() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock() + self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with( + await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + self.cog.send_job.assert_called_with( ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' ) - self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) + self.cog.continue_job.assert_called_with(ctx, response, ctx.command) async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -196,14 +195,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_send_eval(self): - """Test the send_eval function.""" + async def test_send_job(self): + """Test the send_job function.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) @@ -212,7 +211,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -223,19 +222,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') - async def test_send_eval_with_paste_link(self): - """Test the send_eval function with a too long output that generate a paste link.""" + async def test_send_job_with_paste_link(self): + """Test the send_job function with a too long output that generate a paste link.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) @@ -244,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -253,18 +252,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') - async def test_send_eval_with_non_zero_eval(self): - """Test the send_eval function with a code returning a non-zero code.""" + async def test_send_job_with_non_zero_eval(self): + """Test the send_job function with a code returning a non-zero code.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called @@ -273,7 +272,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -281,14 +280,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") - async def test_continue_eval_does_continue(self, partial_mock): - """Test that the continue_eval function does continue if required conditions are met.""" + async def test_continue_job_does_continue(self, partial_mock): + """Test that the continue_job function does continue if required conditions are met.""" ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) response = MockMessage(delete=AsyncMock()) new_msg = MockMessage() @@ -296,30 +295,30 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command) + actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( 'message_edit', - check=partial_mock(snekbox.predicate_eval_message_edit, ctx), - timeout=snekbox.REEVAL_TIMEOUT, + check=partial_mock(snekbox.predicate_message_edit, ctx), + timeout=snekbox.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) response.delete.assert_called_once() - async def test_continue_eval_does_not_continue(self): + async def test_continue_job_does_not_continue(self): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command) + actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) self.assertEqual(actual, (None, None)) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -347,8 +346,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) - def test_predicate_eval_message_edit(self): - """Test the predicate_eval_message_edit function.""" + def test_predicate_message_edit(self): + """Test the predicate_message_edit function.""" msg0 = MockMessage(id=1, content='abc') msg1 = MockMessage(id=2, content='abcdef') msg2 = MockMessage(id=1, content='abcdef') @@ -361,18 +360,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) - def test_predicate_eval_emoji_reaction(self): - """Test the predicate_eval_emoji_reaction function.""" + def test_predicate_emoji_reaction(self): + """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_reaction.__str__.return_value = snekbox.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -385,7 +384,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) -- cgit v1.2.3 From b689f059e45e68f75e3a97277ac3bf58263fa769 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Fri, 4 Feb 2022 20:54:49 -0500 Subject: fix: Use filter_snekbox_output rather than job --- bot/exts/utils/snekbox.py | 2 +- tests/bot/exts/utils/test_snekbox.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index ce3dd7c24..a932b96ff 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -268,7 +268,7 @@ class Snekbox(Cog): filter_cog = self.bot.get_cog("Filtering") filter_triggered = False if filter_cog: - filter_triggered = await filter_cog.filter_snekbox_job(msg, ctx.message) + filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) if filter_triggered: response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 75da0c860..2eaed0446 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -208,7 +208,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') @@ -240,7 +240,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') @@ -269,7 +269,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') -- cgit v1.2.3 From 7ef9b2163e7f0da5a7abd17decac2b0b2d53defb Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Sat, 5 Feb 2022 23:21:21 -0500 Subject: tests: Add a test for timeit command codeblock preparation --- tests/bot/exts/utils/test_snekbox.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 2eaed0446..f68a20089 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -66,6 +66,21 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected ) + def test_prepare_timeit_input(self): + """Test the prepare_timeit_input codeblock detection.""" + base_args = ('-m', 'timeit', '-s') + cases = ( + (['print("Hello World")'], '', 'single block of code'), + (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), + (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') + ) + + for case, setup_code, testname in cases: + setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + self.assertEqual(self.cog.prepare_timeit_input(case), expected) + def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( -- cgit v1.2.3 From 753e101df0ba4f51c85a426a7a9d9678f96e0fd7 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 14 Mar 2022 19:01:56 +0000 Subject: Revert "Update all references of discord.py to disnake" This reverts commit 960619c23300c56c8aaa454edc7241e2badf80ad. --- bot/__init__.py | 2 +- bot/bot.py | 20 ++-- bot/converters.py | 24 ++-- bot/decorators.py | 10 +- bot/errors.py | 2 +- bot/exts/backend/branding/_cog.py | 18 +-- bot/exts/backend/config_verifier.py | 2 +- bot/exts/backend/error_handler.py | 4 +- bot/exts/backend/logging.py | 4 +- bot/exts/backend/sync/_cog.py | 6 +- bot/exts/backend/sync/_syncers.py | 4 +- bot/exts/events/code_jams/_channels.py | 42 +++---- bot/exts/events/code_jams/_cog.py | 24 ++-- bot/exts/filters/antimalware.py | 4 +- bot/exts/filters/antispam.py | 8 +- bot/exts/filters/filter_lists.py | 4 +- bot/exts/filters/filtering.py | 26 ++--- bot/exts/filters/security.py | 2 +- bot/exts/filters/token_remover.py | 6 +- bot/exts/filters/webhook_remover.py | 4 +- bot/exts/fun/duck_pond.py | 22 ++-- bot/exts/fun/off_topic_names.py | 6 +- bot/exts/help_channels/_caches.py | 14 +-- bot/exts/help_channels/_channel.py | 18 +-- bot/exts/help_channels/_cog.py | 64 +++++------ bot/exts/help_channels/_message.py | 36 +++--- bot/exts/help_channels/_name.py | 6 +- bot/exts/info/code_snippets.py | 8 +- bot/exts/info/codeblock/_cog.py | 22 ++-- bot/exts/info/doc/_batch_parser.py | 4 +- bot/exts/info/doc/_cog.py | 18 +-- bot/exts/info/help.py | 4 +- bot/exts/info/information.py | 8 +- bot/exts/info/pep.py | 4 +- bot/exts/info/pypi.py | 6 +- bot/exts/info/python_news.py | 12 +- bot/exts/info/source.py | 4 +- bot/exts/info/stats.py | 6 +- bot/exts/info/subscribe.py | 26 ++--- bot/exts/info/tags.py | 16 +-- bot/exts/moderation/clean.py | 10 +- bot/exts/moderation/defcon.py | 6 +- bot/exts/moderation/dm_relay.py | 6 +- bot/exts/moderation/incidents.py | 100 ++++++++--------- bot/exts/moderation/infraction/_scheduler.py | 14 +-- bot/exts/moderation/infraction/_utils.py | 14 +-- bot/exts/moderation/infraction/infractions.py | 26 ++--- bot/exts/moderation/infraction/management.py | 36 +++--- bot/exts/moderation/infraction/superstarify.py | 6 +- bot/exts/moderation/metabase.py | 2 +- bot/exts/moderation/modlog.py | 72 ++++++------ bot/exts/moderation/modpings.py | 8 +- bot/exts/moderation/silence.py | 8 +- bot/exts/moderation/slowmode.py | 4 +- bot/exts/moderation/stream.py | 24 ++-- bot/exts/moderation/verification.py | 16 +-- bot/exts/moderation/voice_gate.py | 42 +++---- bot/exts/moderation/watchchannels/_watchchannel.py | 12 +- bot/exts/moderation/watchchannels/bigbrother.py | 4 +- bot/exts/recruitment/talentpool/_cog.py | 8 +- bot/exts/recruitment/talentpool/_review.py | 4 +- bot/exts/utils/bot.py | 4 +- bot/exts/utils/extensions.py | 6 +- bot/exts/utils/internal.py | 17 ++- bot/exts/utils/ping.py | 4 +- bot/exts/utils/reminders.py | 32 +++--- bot/exts/utils/snekbox.py | 4 +- bot/exts/utils/thread_bumper.py | 24 ++-- bot/exts/utils/utils.py | 6 +- bot/log.py | 2 +- bot/monkey_patches.py | 6 +- bot/pagination.py | 18 +-- bot/rules/attachments.py | 2 +- bot/rules/burst.py | 2 +- bot/rules/burst_shared.py | 2 +- bot/rules/chars.py | 2 +- bot/rules/discord_emojis.py | 2 +- bot/rules/duplicates.py | 2 +- bot/rules/links.py | 2 +- bot/rules/mentions.py | 2 +- bot/rules/newlines.py | 2 +- bot/rules/role_mentions.py | 2 +- bot/utils/channel.py | 16 +-- bot/utils/checks.py | 4 +- bot/utils/function.py | 6 +- bot/utils/helpers.py | 2 +- bot/utils/members.py | 18 +-- bot/utils/message_cache.py | 2 +- bot/utils/messages.py | 48 ++++---- bot/utils/webhooks.py | 10 +- poetry.lock | 12 +- pyproject.toml | 2 +- tests/README.md | 12 +- tests/base.py | 8 +- tests/bot/exts/backend/sync/test_cog.py | 6 +- tests/bot/exts/backend/sync/test_roles.py | 6 +- tests/bot/exts/backend/sync/test_users.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 2 +- tests/bot/exts/events/test_code_jams.py | 4 +- tests/bot/exts/filters/test_antimalware.py | 2 +- tests/bot/exts/filters/test_security.py | 2 +- tests/bot/exts/filters/test_token_remover.py | 2 +- tests/bot/exts/info/test_information.py | 22 ++-- .../exts/moderation/infraction/test_infractions.py | 2 +- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- tests/bot/exts/moderation/test_incidents.py | 20 ++-- tests/bot/exts/moderation/test_modlog.py | 4 +- tests/bot/exts/moderation/test_silence.py | 4 +- tests/bot/exts/test_cogs.py | 4 +- tests/bot/exts/utils/test_snekbox.py | 4 +- tests/bot/test_converters.py | 2 +- tests/bot/utils/test_checks.py | 2 +- tests/helpers.py | 124 ++++++++++----------- tests/test_helpers.py | 28 ++--- 114 files changed, 734 insertions(+), 735 deletions(-) (limited to 'tests') diff --git a/bot/__init__.py b/bot/__init__.py index b28513bff..17d99105a 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -3,7 +3,7 @@ import os from functools import partial, partialmethod from typing import TYPE_CHECKING -from disnake.ext import commands +from discord.ext import commands from bot import log, monkey_patches diff --git a/bot/bot.py b/bot/bot.py index 2769b7dda..94783a466 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -6,9 +6,9 @@ from contextlib import suppress from typing import Dict, List, Optional import aiohttp -import disnake +import discord from async_rediscache import RedisSession -from disnake.ext import commands +from discord.ext import commands from sentry_sdk import push_scope from bot import api, constants @@ -28,7 +28,7 @@ class StartupError(Exception): class Bot(commands.Bot): - """A subclass of `disnake.ext.commands.Bot` with an aiohttp session and an API client.""" + """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, redis_session: RedisSession, **kwargs): if "connector" in kwargs: @@ -109,9 +109,9 @@ class Bot(commands.Bot): def create(cls) -> "Bot": """Create and return an instance of a Bot.""" loop = asyncio.get_event_loop() - allowed_roles = list({disnake.Object(id_) for id_ in constants.MODERATION_ROLES}) + allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}) - intents = disnake.Intents.all() + intents = discord.Intents.all() intents.presences = False intents.dm_typing = False intents.dm_reactions = False @@ -123,10 +123,10 @@ class Bot(commands.Bot): redis_session=_create_redis_session(loop), loop=loop, command_prefix=commands.when_mentioned_or(constants.Bot.prefix), - activity=disnake.Game(name=f"Commands: {constants.Bot.prefix}help"), + activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), case_insensitive=True, max_messages=10_000, - allowed_mentions=disnake.AllowedMentions(everyone=False, roles=allowed_roles), + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), intents=intents, ) @@ -258,7 +258,7 @@ class Bot(commands.Bot): await self.stats.create_socket() await super().login(*args, **kwargs) - async def on_guild_available(self, guild: disnake.Guild) -> None: + async def on_guild_available(self, guild: discord.Guild) -> None: """ Set the internal guild available event when constants.Guild.id becomes available. @@ -274,7 +274,7 @@ class Bot(commands.Bot): try: webhook = await self.fetch_webhook(constants.Webhooks.dev_log) - except disnake.HTTPException as e: + except discord.HTTPException as e: log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") else: await webhook.send(f"<@&{constants.Roles.admin}> {msg}") @@ -283,7 +283,7 @@ class Bot(commands.Bot): self._guild_available.set() - async def on_guild_unavailable(self, guild: disnake.Guild) -> None: + async def on_guild_unavailable(self, guild: discord.Guild) -> None: """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" if guild.id != constants.Guild.id: return diff --git a/bot/converters.py b/bot/converters.py index 9d93428ca..3522a32aa 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -6,12 +6,12 @@ from datetime import datetime, timezone from ssl import CertificateError import dateutil.parser -import disnake +import discord from aiohttp import ClientConnectorError from botcore.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from disnake.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter -from disnake.utils import escape_markdown, snowflake_time +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.utils import escape_markdown, snowflake_time from bot import exts from bot.api import ResponseCodeError @@ -505,14 +505,14 @@ AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Us class UnambiguousUser(UserConverter): """ - Converts to a `disnake.User`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `UserConverter`, it doesn't allow conversion from a name. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> disnake.User: - """Convert the `argument` to a `disnake.User`.""" + async def convert(self, ctx: Context, argument: str) -> discord.User: + """Convert the `argument` to a `discord.User`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -521,14 +521,14 @@ class UnambiguousUser(UserConverter): class UnambiguousMember(MemberConverter): """ - Converts to a `disnake.Member`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> disnake.Member: - """Convert the `argument` to a `disnake.Member`.""" + async def convert(self, ctx: Context, argument: str) -> discord.Member: + """Convert the `argument` to a `discord.Member`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -588,10 +588,10 @@ if t.TYPE_CHECKING: OffTopicName = str # noqa: F811 ISODateTime = datetime # noqa: F811 HushDurationConverter = int # noqa: F811 - UnambiguousUser = disnake.User # noqa: F811 - UnambiguousMember = disnake.Member # noqa: F811 + UnambiguousUser = discord.User # noqa: F811 + UnambiguousMember = discord.Member # noqa: F811 Infraction = t.Optional[dict] # noqa: F811 Expiry = t.Union[Duration, ISODateTime] -MemberOrUser = t.Union[disnake.Member, disnake.User] +MemberOrUser = t.Union[discord.Member, discord.User] UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] diff --git a/bot/decorators.py b/bot/decorators.py index 9ae98442c..f4331264f 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -4,9 +4,9 @@ import types import typing as t from contextlib import suppress -from disnake import Member, NotFound -from disnake.ext import commands -from disnake.ext.commands import Cog, Context +from discord import Member, NotFound +from discord.ext import commands +from discord.ext.commands import Cog, Context from bot.constants import Channels, DEBUG_MODE, RedirectOutput from bot.log import get_logger @@ -179,7 +179,7 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: Ensure the highest role of the invoking member is greater than that of the target member. If the condition fails, a warning is sent to the invoking context. A target which is not an - instance of disnake.Member will always pass. + instance of discord.Member will always pass. `member_arg` is the keyword name or position index of the parameter of the decorated command whose value is the target member. @@ -195,7 +195,7 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: target = function.get_arg_value(member_arg, bound_args) if not isinstance(target, Member): - log.trace("The target is not a disnake.Member; skipping role hierarchy check.") + log.trace("The target is not a discord.Member; skipping role hierarchy check.") return await func(*args, **kwargs) ctx = function.get_arg_value(1, bound_args) diff --git a/bot/errors.py b/bot/errors.py index 298e7ac2d..078b645f1 100644 --- a/bot/errors.py +++ b/bot/errors.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Hashable, TYPE_CHECKING, Union -from disnake.ext.commands import ConversionError, Converter +from discord.ext.commands import ConversionError, Converter if TYPE_CHECKING: from bot.converters import MemberOrUser diff --git a/bot/exts/backend/branding/_cog.py b/bot/exts/backend/branding/_cog.py index a07e70d58..0c5839a7a 100644 --- a/bot/exts/backend/branding/_cog.py +++ b/bot/exts/backend/branding/_cog.py @@ -7,10 +7,10 @@ from enum import Enum from operator import attrgetter import async_timeout -import disnake +import discord from arrow import Arrow from async_rediscache import RedisCache -from disnake.ext import commands, tasks +from discord.ext import commands, tasks from bot.bot import Bot from bot.constants import Branding as BrandingConfig, Channels, Colours, Guild, MODERATION_ROLES @@ -42,7 +42,7 @@ def compound_hash(objects: t.Iterable[RemoteObject]) -> str: return "-".join(item.sha for item in objects) -def make_embed(title: str, description: str, *, success: bool) -> disnake.Embed: +def make_embed(title: str, description: str, *, success: bool) -> discord.Embed: """ Construct simple response embed. @@ -51,7 +51,7 @@ def make_embed(title: str, description: str, *, success: bool) -> disnake.Embed: For both `title` and `description`, empty string are valid values ~ fields will be empty. """ colour = Colours.soft_green if success else Colours.soft_red - return disnake.Embed(title=title[:256], description=description[:4096], colour=colour) + return discord.Embed(title=title[:256], description=description[:4096], colour=colour) def extract_event_duration(event: Event) -> str: @@ -147,13 +147,13 @@ class Branding(commands.Cog): return False await self.bot.wait_until_guild_available() - pydis: disnake.Guild = self.bot.get_guild(Guild.id) + pydis: discord.Guild = self.bot.get_guild(Guild.id) timeout = 10 # Seconds. try: with async_timeout.timeout(timeout): # Raise after `timeout` seconds. await pydis.edit(**{asset_type.value: file}) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Asset upload to Discord failed.") return False except asyncio.TimeoutError: @@ -277,7 +277,7 @@ class Branding(commands.Cog): log.debug(f"Sending event information event to channel: {channel_id} ({is_notification=}).") await self.bot.wait_until_guild_available() - channel: t.Optional[disnake.TextChannel] = self.bot.get_channel(channel_id) + channel: t.Optional[discord.TextChannel] = self.bot.get_channel(channel_id) if channel is None: log.warning(f"Cannot send event information: channel {channel_id} not found!") @@ -294,7 +294,7 @@ class Branding(commands.Cog): else: content = "Python Discord is entering a new event!" if is_notification else None - embed = disnake.Embed(description=description[:4096], colour=disnake.Colour.og_blurple()) + embed = discord.Embed(description=description[:4096], colour=discord.Colour.og_blurple()) embed.set_footer(text=duration[:4096]) await channel.send(content=content, embed=embed) @@ -573,7 +573,7 @@ class Branding(commands.Cog): await ctx.send(embed=resp) return - embed = disnake.Embed(title="Current event calendar", colour=disnake.Colour.og_blurple()) + embed = discord.Embed(title="Current event calendar", colour=discord.Colour.og_blurple()) # Because Discord embeds can only contain up to 25 fields, we only show the first 25. first_25 = list(available_events.items())[:25] diff --git a/bot/exts/backend/config_verifier.py b/bot/exts/backend/config_verifier.py index 1ade2bce7..dc85a65a2 100644 --- a/bot/exts/backend/config_verifier.py +++ b/bot/exts/backend/config_verifier.py @@ -1,4 +1,4 @@ -from disnake.ext.commands import Cog +from discord.ext.commands import Cog from bot import constants from bot.bot import Bot diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 953843a77..c79c7b2a7 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,7 +1,7 @@ import difflib -from disnake import Embed -from disnake.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors +from discord import Embed +from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from sentry_sdk import push_scope from bot.api import ResponseCodeError diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py index 040fb5d37..2d03cd580 100644 --- a/bot/exts/backend/logging.py +++ b/bot/exts/backend/logging.py @@ -1,5 +1,5 @@ -from disnake import Embed -from disnake.ext.commands import Cog +from discord import Embed +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index d08e56077..80f5750bc 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,8 +1,8 @@ from typing import Any, Dict -from disnake import Member, Role, User -from disnake.ext import commands -from disnake.ext.commands import Cog, Context +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context from bot import constants from bot.api import ResponseCodeError diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 48ee3c842..45301b098 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,8 +2,8 @@ import abc import typing as t from collections import namedtuple -from disnake import Guild -from disnake.ext.commands import Context +from discord import Guild +from discord.ext.commands import Context from more_itertools import chunked import bot diff --git a/bot/exts/events/code_jams/_channels.py b/bot/exts/events/code_jams/_channels.py index fc4693bd4..e8cf5f7bf 100644 --- a/bot/exts/events/code_jams/_channels.py +++ b/bot/exts/events/code_jams/_channels.py @@ -1,6 +1,6 @@ import typing as t -import disnake +import discord from bot.constants import Categories, Channels, Roles from bot.log import get_logger @@ -11,7 +11,7 @@ MAX_CHANNELS = 50 CATEGORY_NAME = "Code Jam" -async def _get_category(guild: disnake.Guild) -> disnake.CategoryChannel: +async def _get_category(guild: discord.Guild) -> discord.CategoryChannel: """ Return a code jam category. @@ -24,13 +24,13 @@ async def _get_category(guild: disnake.Guild) -> disnake.CategoryChannel: return await _create_category(guild) -async def _create_category(guild: disnake.Guild) -> disnake.CategoryChannel: +async def _create_category(guild: discord.Guild) -> discord.CategoryChannel: """Create a new code jam category and return it.""" log.info("Creating a new code jam category.") category_overwrites = { - guild.default_role: disnake.PermissionOverwrite(read_messages=False), - guild.me: disnake.PermissionOverwrite(read_messages=True) + guild.default_role: discord.PermissionOverwrite(read_messages=False), + guild.me: discord.PermissionOverwrite(read_messages=True) } category = await guild.create_category_channel( @@ -47,17 +47,17 @@ async def _create_category(guild: disnake.Guild) -> disnake.CategoryChannel: def _get_overwrites( - members: list[tuple[disnake.Member, bool]], - guild: disnake.Guild, -) -> dict[t.Union[disnake.Member, disnake.Role], disnake.PermissionOverwrite]: + members: list[tuple[discord.Member, bool]], + guild: discord.Guild, +) -> dict[t.Union[discord.Member, discord.Role], discord.PermissionOverwrite]: """Get code jam team channels permission overwrites.""" team_channel_overwrites = { - guild.default_role: disnake.PermissionOverwrite(read_messages=False), - guild.get_role(Roles.code_jam_event_team): disnake.PermissionOverwrite(read_messages=True) + guild.default_role: discord.PermissionOverwrite(read_messages=False), + guild.get_role(Roles.code_jam_event_team): discord.PermissionOverwrite(read_messages=True) } for member, _ in members: - team_channel_overwrites[member] = disnake.PermissionOverwrite( + team_channel_overwrites[member] = discord.PermissionOverwrite( read_messages=True ) @@ -65,10 +65,10 @@ def _get_overwrites( async def create_team_channel( - guild: disnake.Guild, + guild: discord.Guild, team_name: str, - members: list[tuple[disnake.Member, bool]], - team_leaders: disnake.Role + members: list[tuple[discord.Member, bool]], + team_leaders: discord.Role ) -> None: """Create the team's text channel.""" await _add_team_leader_roles(members, team_leaders) @@ -84,29 +84,29 @@ async def create_team_channel( ) -async def create_team_leader_channel(guild: disnake.Guild, team_leaders: disnake.Role) -> None: +async def create_team_leader_channel(guild: discord.Guild, team_leaders: discord.Role) -> None: """Create the Team Leader Chat channel for the Code Jam team leaders.""" - category: disnake.CategoryChannel = guild.get_channel(Categories.summer_code_jam) + category: discord.CategoryChannel = guild.get_channel(Categories.summer_code_jam) team_leaders_chat = await category.create_text_channel( name="team-leaders-chat", overwrites={ - guild.default_role: disnake.PermissionOverwrite(read_messages=False), - team_leaders: disnake.PermissionOverwrite(read_messages=True) + guild.default_role: discord.PermissionOverwrite(read_messages=False), + team_leaders: discord.PermissionOverwrite(read_messages=True) } ) await _send_status_update(guild, f"Created {team_leaders_chat.mention} in the {category} category.") -async def _send_status_update(guild: disnake.Guild, message: str) -> None: +async def _send_status_update(guild: discord.Guild, message: str) -> None: """Inform the events lead with a status update when the command is ran.""" - channel: disnake.TextChannel = guild.get_channel(Channels.code_jam_planning) + channel: discord.TextChannel = guild.get_channel(Channels.code_jam_planning) await channel.send(f"<@&{Roles.events_lead}>\n\n{message}") -async def _add_team_leader_roles(members: list[tuple[disnake.Member, bool]], team_leaders: disnake.Role) -> None: +async def _add_team_leader_roles(members: list[tuple[discord.Member, bool]], team_leaders: discord.Role) -> None: """Assign the team leader role to the team leaders.""" for member, is_leader in members: if is_leader: diff --git a/bot/exts/events/code_jams/_cog.py b/bot/exts/events/code_jams/_cog.py index 5cb11826d..452199f5f 100644 --- a/bot/exts/events/code_jams/_cog.py +++ b/bot/exts/events/code_jams/_cog.py @@ -3,9 +3,9 @@ import csv import typing as t from collections import defaultdict -import disnake -from disnake import Colour, Embed, Guild, Member -from disnake.ext import commands +import discord +from discord import Colour, Embed, Guild, Member +from discord.ext import commands from bot.bot import Bot from bot.constants import Emojis, Roles @@ -85,7 +85,7 @@ class CodeJams(commands.Cog): A confirmation message is displayed with the categories and channels to be deleted.. Pressing the added reaction deletes those channels. """ - def predicate_deletion_emoji_reaction(reaction: disnake.Reaction, user: disnake.User) -> bool: + def predicate_deletion_emoji_reaction(reaction: discord.Reaction, user: discord.User) -> bool: """Return True if the reaction :boom: was added by the context message author on this message.""" return ( reaction.message.id == message.id @@ -124,14 +124,14 @@ class CodeJams(commands.Cog): @staticmethod async def _build_confirmation_message( - categories: dict[disnake.CategoryChannel, list[disnake.abc.GuildChannel]] + categories: dict[discord.CategoryChannel, list[discord.abc.GuildChannel]] ) -> str: """Sends details of the channels to be deleted to the pasting service, and formats the confirmation message.""" - def channel_repr(channel: disnake.abc.GuildChannel) -> str: + def channel_repr(channel: discord.abc.GuildChannel) -> str: """Formats the channel name and ID and a readable format.""" return f"{channel.name} ({channel.id})" - def format_category_info(category: disnake.CategoryChannel, channels: list[disnake.abc.GuildChannel]) -> str: + def format_category_info(category: discord.CategoryChannel, channels: list[discord.abc.GuildChannel]) -> str: """Displays the category and the channels within it in a readable format.""" return f"{channel_repr(category)}:\n" + "\n".join(" - " + channel_repr(channel) for channel in channels) @@ -187,7 +187,7 @@ class CodeJams(commands.Cog): await old_team_channel.set_permissions(member, overwrite=None, reason=f"Participant moved to {new_team_name}") await new_team_channel.set_permissions( member, - overwrite=disnake.PermissionOverwrite(read_messages=True), + overwrite=discord.PermissionOverwrite(read_messages=True), reason=f"Participant moved from {old_team_channel.name}" ) @@ -212,16 +212,16 @@ class CodeJams(commands.Cog): await ctx.send(f"Removed the participant from `{self.team_name(channel)}`.") @staticmethod - def jam_categories(guild: Guild) -> list[disnake.CategoryChannel]: + def jam_categories(guild: Guild) -> list[discord.CategoryChannel]: """Get all the code jam team categories.""" return [category for category in guild.categories if category.name == _channels.CATEGORY_NAME] @staticmethod - def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[disnake.TextChannel]: + def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[discord.TextChannel]: """Get a team channel through either a participant or the team name.""" for category in CodeJams.jam_categories(guild): for channel in category.channels: - if isinstance(channel, disnake.TextChannel): + if isinstance(channel, discord.TextChannel): if ( # If it's a string. criterion == channel.name or criterion == CodeJams.team_name(channel) @@ -231,6 +231,6 @@ class CodeJams(commands.Cog): return channel @staticmethod - def team_name(channel: disnake.TextChannel) -> str: + def team_name(channel: discord.TextChannel) -> str: """Retrieves the team name from the given channel.""" return channel.name.replace("-", " ").title() diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py index e55ece910..6cccf3680 100644 --- a/bot/exts/filters/antimalware.py +++ b/bot/exts/filters/antimalware.py @@ -1,8 +1,8 @@ import typing as t from os.path import splitext -from disnake import Embed, Message, NotFound -from disnake.ext.commands import Cog +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, Filter, URLs diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index c887cf5fc..bcd845a43 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -8,8 +8,8 @@ from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set import arrow -from disnake import Colour, Member, Message, NotFound, Object, TextChannel -from disnake.ext.commands import Cog +from discord import Colour, Member, Message, NotFound, Object, TextChannel +from discord.ext.commands import Cog from bot import rules from bot.bot import Bot @@ -195,7 +195,7 @@ class AntiSpam(Cog): result = await rule_function(message, messages_for_rule, rule_config) # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[disnake.Member])` + # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` # which contains the reason for why the message violated the rule and # an iterable of all members that violated the rule. if result is not None: @@ -265,7 +265,7 @@ class AntiSpam(Cog): # In the rare case where we found messages matching the # spam filter across multiple channels, it is possible # that a single channel will only contain a single message - # to delete. If that should be the case, disnake will + # to delete. If that should be the case, discord.py will # use the "delete single message" endpoint instead of the # bulk delete endpoint, and the single message deletion # endpoint will complain if you give it that does not exist. diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py index 05910973a..a883ddf54 100644 --- a/bot/exts/filters/filter_lists.py +++ b/bot/exts/filters/filter_lists.py @@ -1,8 +1,8 @@ import re from typing import Optional -from disnake import Colour, Embed -from disnake.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role from bot import constants from bot.api import ResponseCodeError diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index e8c9bab62..f44b28125 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -6,15 +6,15 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union import arrow import dateutil.parser -import disnake.errors +import discord.errors import regex import tldextract from async_rediscache import RedisCache from botcore.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from disnake import Colour, HTTPException, Member, Message, NotFound, TextChannel -from disnake.ext.commands import Cog -from disnake.utils import escape_markdown +from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel +from discord.ext.commands import Cog +from discord.utils import escape_markdown from bot.api import ResponseCodeError from bot.bot import Bot @@ -63,14 +63,14 @@ AUTO_BAN_REASON = ( ) AUTO_BAN_DURATION = timedelta(days=4) -FilterMatch = Union[re.Match, dict, bool, List[disnake.Embed]] +FilterMatch = Union[re.Match, dict, bool, List[discord.Embed]] class Stats(NamedTuple): """Additional stats on a triggered filter to append to a mod log.""" message_content: str - additional_embeds: Optional[List[disnake.Embed]] + additional_embeds: Optional[List[discord.Embed]] class Filtering(Cog): @@ -339,7 +339,7 @@ class Filtering(Cog): match = result if match: - is_private = msg.channel.type is disnake.ChannelType.private + is_private = msg.channel.type is discord.ChannelType.private # If this is a filter (not a watchlist) and not in a DM, delete the message. if _filter["type"] == "filter" and not is_private: @@ -354,7 +354,7 @@ class Filtering(Cog): # In addition, to avoid sending two notifications to the user, the # logs, and mod_alert, we return if the message no longer exists. await msg.delete() - except disnake.errors.NotFound: + except discord.errors.NotFound: return # Notify the user if the filter specifies @@ -409,14 +409,14 @@ class Filtering(Cog): self, filter_name: str, _filter: Dict[str, Any], - msg: disnake.Message, + msg: discord.Message, stats: Stats, reason: Optional[str] = None, *, is_eval: bool = False, ) -> None: """Send a mod log for a triggered filter.""" - if msg.channel.type is disnake.ChannelType.private: + if msg.channel.type is discord.ChannelType.private: channel_str = "via DM" ping_everyone = False else: @@ -478,7 +478,7 @@ class Filtering(Cog): additional_embeds = [] for _, data in match.items(): reason = f"Reason: {data['reason']} | " if data.get('reason') else "" - embed = disnake.Embed(description=( + embed = discord.Embed(description=( f"**Members:**\n{data['members']}\n" f"**Active:**\n{data['active']}" )) @@ -626,7 +626,7 @@ class Filtering(Cog): return invite_data if invite_data else False @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[disnake.Embed]]: + async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" if msg.embeds: for embed in msg.embeds: @@ -662,7 +662,7 @@ class Filtering(Cog): """ try: await filtered_member.send(reason) - except disnake.errors.Forbidden: + except discord.errors.Forbidden: await channel.send(f"{filtered_member.mention} {reason}") def schedule_msg_delete(self, msg: dict) -> None: diff --git a/bot/exts/filters/security.py b/bot/exts/filters/security.py index bbb15542f..fe3918423 100644 --- a/bot/exts/filters/security.py +++ b/bot/exts/filters/security.py @@ -1,4 +1,4 @@ -from disnake.ext.commands import Cog, Context, NoPrivateMessage +from discord.ext.commands import Cog, Context, NoPrivateMessage from bot.bot import Bot from bot.log import get_logger diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py index da42bb0aa..520283ba3 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filters/token_remover.py @@ -3,8 +3,8 @@ import binascii import re import typing as t -from disnake import Colour, Message, NotFound -from disnake.ext.commands import Cog +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog from bot import utils from bot.bot import Bot @@ -53,7 +53,7 @@ class Token(t.NamedTuple): class TokenRemover(Cog): - """Scans messages for potential Discord bot tokens and removes them.""" + """Scans messages for potential discord.py bot tokens and removes them.""" def __init__(self, bot: Bot): self.bot = bot diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py index a5d51700c..96334317c 100644 --- a/bot/exts/filters/webhook_remover.py +++ b/bot/exts/filters/webhook_remover.py @@ -1,7 +1,7 @@ import re -from disnake import Colour, Message, NotFound -from disnake.ext.commands import Cog +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, Colours, Event, Icons diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index 55196cd65..c51656343 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -1,9 +1,9 @@ import asyncio from typing import Union -import disnake -from disnake import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors -from disnake.ext.commands import Cog, Context, command +import discord +from discord import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors +from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot @@ -34,7 +34,7 @@ class DuckPond(Cog): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except disnake.HTTPException: + except discord.HTTPException: log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") @staticmethod @@ -67,7 +67,7 @@ class DuckPond(Cog): return False @staticmethod - def _is_duck_emoji(emoji: Union[str, disnake.PartialEmoji, disnake.Emoji]) -> bool: + def _is_duck_emoji(emoji: Union[str, discord.PartialEmoji, discord.Emoji]) -> bool: """Check if the emoji is a valid duck emoji.""" if isinstance(emoji, str): return emoji == "🦆" @@ -111,7 +111,7 @@ class DuckPond(Cog): username=message.author.display_name, avatar_url=message.author.display_avatar.url ) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Failed to send an attachment to the webhook") async def locked_relay(self, message: Message) -> bool: @@ -133,7 +133,7 @@ class DuckPond(Cog): await message.add_reaction("✅") return True - def _payload_has_duckpond_emoji(self, emoji: disnake.PartialEmoji) -> bool: + def _payload_has_duckpond_emoji(self, emoji: discord.PartialEmoji) -> bool: """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" if emoji.is_unicode_emoji(): # For unicode PartialEmojis, the `name` attribute is just the string @@ -165,7 +165,7 @@ class DuckPond(Cog): if not self._payload_has_duckpond_emoji(payload.emoji): return - channel = disnake.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) if channel is None: return @@ -175,10 +175,10 @@ class DuckPond(Cog): try: message = await channel.fetch_message(payload.message_id) - except disnake.NotFound: + except discord.NotFound: return # Message was deleted. - member = disnake.utils.get(message.guild.members, id=payload.user_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) if not member: return # Member left or wasn't in the cache. @@ -205,7 +205,7 @@ class DuckPond(Cog): if payload.guild_id != constants.Guild.id: return - channel = disnake.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) if channel is None: return diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index d49f71320..7df1d172d 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,9 +2,9 @@ import difflib from datetime import timedelta import arrow -from disnake import Colour, Embed -from disnake.ext.commands import Cog, Context, group, has_any_role -from disnake.utils import sleep_until +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group, has_any_role +from discord.utils import sleep_until from bot.api import ResponseCodeError from bot.bot import Bot diff --git a/bot/exts/help_channels/_caches.py b/bot/exts/help_channels/_caches.py index f4eaf3291..8d45c2466 100644 --- a/bot/exts/help_channels/_caches.py +++ b/bot/exts/help_channels/_caches.py @@ -1,24 +1,24 @@ from async_rediscache import RedisCache # This dictionary maps a help channel to the time it was claimed -# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] +# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] claim_times = RedisCache(namespace="HelpChannels.claim_times") # This cache tracks which channels are claimed by which members. -# RedisCache[disnake.TextChannel.id, t.Union[disnake.User.id, disnake.Member.id]] +# RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] claimants = RedisCache(namespace="HelpChannels.help_channel_claimants") # Stores the timestamp of the last message from the claimant of a help channel -# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] +# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] claimant_last_message_times = RedisCache(namespace="HelpChannels.claimant_last_message_times") # This cache maps a help channel to the timestamp of the last non-claimant message. # This cache being empty for a given help channel indicates the question is unanswered. -# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] +# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] non_claimant_last_message_times = RedisCache(namespace="HelpChannels.non_claimant_last_message_times") # This cache maps a help channel to original question message in same channel. -# RedisCache[disnake.TextChannel.id, disnake.Message.id] +# RedisCache[discord.TextChannel.id, discord.Message.id] question_messages = RedisCache(namespace="HelpChannels.question_messages") # This cache keeps track of the dynamic message ID for @@ -26,10 +26,10 @@ question_messages = RedisCache(namespace="HelpChannels.question_messages") dynamic_message = RedisCache(namespace="HelpChannels.dynamic_message") # This cache keeps track of who has help-dms on. -# RedisCache[disnake.User.id, bool] +# RedisCache[discord.User.id, bool] help_dm = RedisCache(namespace="HelpChannels.help_dm") # This cache tracks member who are participating and opted in to help channel dms. # serialise the set as a comma separated string to allow usage with redis -# RedisCache[disnake.TextChannel.id, str[set[disnake.User.id]]] +# RedisCache[discord.TextChannel.id, str[set[discord.User.id]]] session_participants = RedisCache(namespace="HelpChannels.session_participants") diff --git a/bot/exts/help_channels/_channel.py b/bot/exts/help_channels/_channel.py index 3c4eaa2b2..d9cebf215 100644 --- a/bot/exts/help_channels/_channel.py +++ b/bot/exts/help_channels/_channel.py @@ -4,7 +4,7 @@ from datetime import timedelta from enum import Enum import arrow -import disnake +import discord from arrow import Arrow import bot @@ -31,7 +31,7 @@ class ClosingReason(Enum): CLEANUP = "auto.cleanup" -def get_category_channels(category: disnake.CategoryChannel) -> t.Iterable[disnake.TextChannel]: +def get_category_channels(category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: """Yield the text channels of the `category` in an unsorted manner.""" log.trace(f"Getting text channels in the category '{category}' ({category.id}).") @@ -41,7 +41,7 @@ def get_category_channels(category: disnake.CategoryChannel) -> t.Iterable[disna yield channel -async def get_closing_time(channel: disnake.TextChannel, init_done: bool) -> t.Tuple[Arrow, ClosingReason]: +async def get_closing_time(channel: discord.TextChannel, init_done: bool) -> t.Tuple[Arrow, ClosingReason]: """ Return the time at which the given help `channel` should be closed along with the reason. @@ -116,12 +116,12 @@ async def get_in_use_time(channel_id: int) -> t.Optional[timedelta]: return arrow.utcnow() - claimed -def is_excluded_channel(channel: disnake.abc.GuildChannel) -> bool: +def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: """Check if a channel should be excluded from the help channel system.""" - return not isinstance(channel, disnake.TextChannel) or channel.id in EXCLUDED_CHANNELS + return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS -async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **options) -> None: +async def move_to_bottom(channel: discord.TextChannel, category_id: int, **options) -> None: """ Move the `channel` to the bottom position of `category` and edit channel attributes. @@ -130,8 +130,8 @@ async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **optio really ends up at the bottom of the category. If `options` are provided, the channel will be edited after the move is completed. This is the - same order of operations that `disnake.TextChannel.edit` uses. For information on available - options, see the documentation on `disnake.TextChannel.edit`. While possible, position-related + same order of operations that `discord.TextChannel.edit` uses. For information on available + options, see the documentation on `discord.TextChannel.edit`. While possible, position-related options should be avoided, as it may interfere with the category move we perform. """ # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. @@ -161,7 +161,7 @@ async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **optio await channel.edit(**options) -async def ensure_cached_claimant(channel: disnake.TextChannel) -> None: +async def ensure_cached_claimant(channel: discord.TextChannel) -> None: """ Ensure there is a claimant cached for each help channel. diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index fc55fa1df..d3d70e252 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -5,9 +5,9 @@ from datetime import timedelta from operator import attrgetter import arrow -import disnake -import disnake.abc -from disnake.ext import commands +import discord +import discord.abc +from discord.ext import commands from bot import constants from bot.bot import Bot @@ -66,16 +66,16 @@ class HelpChannels(commands.Cog): self.bot = bot self.scheduler = scheduling.Scheduler(self.__class__.__name__) - self.guild: disnake.Guild = None - self.cooldown_role: disnake.Role = None + self.guild: discord.Guild = None + self.cooldown_role: discord.Role = None # Categories - self.available_category: disnake.CategoryChannel = None - self.in_use_category: disnake.CategoryChannel = None - self.dormant_category: disnake.CategoryChannel = None + self.available_category: discord.CategoryChannel = None + self.in_use_category: discord.CategoryChannel = None + self.dormant_category: discord.CategoryChannel = None # Queues - self.channel_queue: asyncio.Queue[disnake.TextChannel] = None + self.channel_queue: asyncio.Queue[discord.TextChannel] = None self.name_queue: t.Deque[str] = None # Notifications @@ -84,7 +84,7 @@ class HelpChannels(commands.Cog): self.last_running_low_notification = arrow.get('1815-12-10T18:00:00.00000+00:00') self.dynamic_message: t.Optional[int] = None - self.available_help_channels: t.Set[disnake.TextChannel] = set() + self.available_help_channels: t.Set[discord.TextChannel] = set() # Asyncio stuff self.queue_tasks: t.List[asyncio.Task] = [] @@ -104,7 +104,7 @@ class HelpChannels(commands.Cog): @lock.lock_arg(NAMESPACE, "message", attrgetter("channel.id")) @lock.lock_arg(NAMESPACE, "message", attrgetter("author.id")) @lock.lock_arg(f"{NAMESPACE}.unclaim", "message", attrgetter("author.id"), wait=True) - async def claim_channel(self, message: disnake.Message) -> None: + async def claim_channel(self, message: discord.Message) -> None: """ Claim the channel in which the question `message` was sent. @@ -116,7 +116,7 @@ class HelpChannels(commands.Cog): try: await self.move_to_in_use(message.channel) - except disnake.DiscordServerError: + except discord.DiscordServerError: try: await message.channel.send( "The bot encountered a Discord API error while trying to move this channel, please try again later." @@ -133,14 +133,14 @@ class HelpChannels(commands.Cog): self.bot.stats.incr("help.failed_claims.500_on_move") return - embed = disnake.Embed( + embed = discord.Embed( description=f"Channel claimed by {message.author.mention}.", color=constants.Colours.bright_green, ) await message.channel.send(embed=embed) - # Handle odd edge case of `message.author` not being a `disnake.Member` (see bot#1839) - if not isinstance(message.author, disnake.Member): + # Handle odd edge case of `message.author` not being a `discord.Member` (see bot#1839) + if not isinstance(message.author, discord.Member): log.debug(f"{message.author} ({message.author.id}) isn't a member. Not giving cooldown role or sending DM.") else: await members.handle_role_change(message.author, message.author.add_roles, self.cooldown_role) @@ -189,7 +189,7 @@ class HelpChannels(commands.Cog): return queue - async def create_dormant(self) -> t.Optional[disnake.TextChannel]: + async def create_dormant(self) -> t.Optional[discord.TextChannel]: """ Create and return a new channel in the Dormant category. @@ -234,12 +234,12 @@ class HelpChannels(commands.Cog): May only be invoked by the channel's claimant or by staff. """ - # Don't use a disnake check because the check needs to fail silently. + # Don't use a discord.py check because the check needs to fail silently. if await self.close_check(ctx): log.info(f"Close command invoked by {ctx.author} in #{ctx.channel}.") await self.unclaim_channel(ctx.channel, closed_on=_channel.ClosingReason.COMMAND) - async def get_available_candidate(self) -> disnake.TextChannel: + async def get_available_candidate(self) -> discord.TextChannel: """ Return a dormant channel to turn into an available channel. @@ -313,7 +313,7 @@ class HelpChannels(commands.Cog): self.dormant_category = await channel_utils.get_or_fetch_channel( constants.Categories.help_dormant ) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Failed to get a category; cog will be removed") self.bot.remove_cog(self.qualified_name) @@ -355,7 +355,7 @@ class HelpChannels(commands.Cog): log.info("Cog is ready!") - async def move_idle_channel(self, channel: disnake.TextChannel, has_task: bool = True) -> None: + async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: """ Make the `channel` dormant if idle or schedule the move if still active. @@ -416,7 +416,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() - async def move_to_dormant(self, channel: disnake.TextChannel) -> None: + async def move_to_dormant(self, channel: discord.TextChannel) -> None: """Make the `channel` dormant.""" log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") await _channel.move_to_bottom( @@ -425,7 +425,7 @@ class HelpChannels(commands.Cog): ) log.trace(f"Sending dormant message for #{channel} ({channel.id}).") - embed = disnake.Embed( + embed = discord.Embed( description=_message.DORMANT_MSG.format( dormant=self.dormant_category.name, available=self.available_category.name, @@ -439,7 +439,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() @lock.lock_arg(f"{NAMESPACE}.unclaim", "channel") - async def unclaim_channel(self, channel: disnake.TextChannel, *, closed_on: _channel.ClosingReason) -> None: + async def unclaim_channel(self, channel: discord.TextChannel, *, closed_on: _channel.ClosingReason) -> None: """ Unclaim an in-use help `channel` to make it dormant. @@ -462,7 +462,7 @@ class HelpChannels(commands.Cog): async def _unclaim_channel( self, - channel: disnake.TextChannel, + channel: discord.TextChannel, claimant_id: t.Optional[int], closed_on: _channel.ClosingReason ) -> None: @@ -488,7 +488,7 @@ class HelpChannels(commands.Cog): if closed_on == _channel.ClosingReason.COMMAND: self.scheduler.cancel(channel.id) - async def move_to_in_use(self, channel: disnake.TextChannel) -> None: + async def move_to_in_use(self, channel: discord.TextChannel) -> None: """Make a channel in-use and schedule it to be made dormant.""" log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") @@ -504,7 +504,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() @commands.Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """Move an available channel to the In Use category and replace it with a dormant one.""" if message.author.bot: return # Ignore messages sent by bots. @@ -520,7 +520,7 @@ class HelpChannels(commands.Cog): await _message.update_message_caches(message) @commands.Cog.listener() - async def on_message_delete(self, msg: disnake.Message) -> None: + async def on_message_delete(self, msg: discord.Message) -> None: """ Reschedule an in-use channel to become dormant sooner if the channel is empty. @@ -542,7 +542,7 @@ class HelpChannels(commands.Cog): delay = constants.HelpChannels.deleted_idle_minutes * 60 self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) - async def wait_for_dormant_channel(self) -> disnake.TextChannel: + async def wait_for_dormant_channel(self) -> discord.TextChannel: """Wait for a dormant channel to become available in the queue and return it.""" log.trace("Waiting for a dormant channel.") @@ -569,7 +569,7 @@ class HelpChannels(commands.Cog): await self.bot.http.edit_message( constants.Channels.how_to_get_help, self.dynamic_message, content=available_channels, files=None ) - except disnake.NotFound: + except discord.NotFound: pass else: return @@ -593,7 +593,7 @@ class HelpChannels(commands.Cog): @lock.lock_arg(NAMESPACE, "message", attrgetter("channel.id")) @lock.lock_arg(NAMESPACE, "message", attrgetter("author.id")) - async def notify_session_participants(self, message: disnake.Message) -> None: + async def notify_session_participants(self, message: discord.Message) -> None: """ Check if the message author meets the requirements to be notified. @@ -615,7 +615,7 @@ class HelpChannels(commands.Cog): if message.author.id not in session_participants: session_participants.add(message.author.id) - embed = disnake.Embed( + embed = discord.Embed( title="Currently Helping", description=f"You're currently helping in {message.channel.mention}", color=constants.Colours.bright_green, @@ -625,7 +625,7 @@ class HelpChannels(commands.Cog): try: await message.author.send(embed=embed) - except disnake.Forbidden: + except discord.Forbidden: log.trace( f"Failed to send helpdm message to {message.author.id}. DMs Closed/Blocked. " "Removing user from helpdm." diff --git a/bot/exts/help_channels/_message.py b/bot/exts/help_channels/_message.py index e08043694..7ceed9b4d 100644 --- a/bot/exts/help_channels/_message.py +++ b/bot/exts/help_channels/_message.py @@ -2,7 +2,7 @@ import textwrap import typing as t import arrow -import disnake +import discord from arrow import Arrow import bot @@ -41,7 +41,7 @@ through our guide for **[asking a good question]({ASKING_GUIDE_URL})**. """ -async def update_message_caches(message: disnake.Message) -> None: +async def update_message_caches(message: discord.Message) -> None: """Checks the source of new content in a help channel and updates the appropriate cache.""" channel = message.channel @@ -62,18 +62,18 @@ async def update_message_caches(message: disnake.Message) -> None: await _caches.non_claimant_last_message_times.set(channel.id, timestamp) -async def get_last_message(channel: disnake.TextChannel) -> t.Optional[disnake.Message]: +async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: """Return the last message sent in the channel or None if no messages exist.""" log.trace(f"Getting the last message in #{channel} ({channel.id}).") try: return await channel.history(limit=1).next() # noqa: B305 - except disnake.NoMoreItems: + except discord.NoMoreItems: log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") return None -async def is_empty(channel: disnake.TextChannel) -> bool: +async def is_empty(channel: discord.TextChannel) -> bool: """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" log.trace(f"Checking if #{channel} ({channel.id}) is empty.") @@ -92,13 +92,13 @@ async def is_empty(channel: disnake.TextChannel) -> bool: return False -async def dm_on_open(message: disnake.Message) -> None: +async def dm_on_open(message: discord.Message) -> None: """ DM claimant with a link to the claimed channel's first message, with a 100 letter preview of the message. Does nothing if the user has DMs disabled. """ - embed = disnake.Embed( + embed = discord.Embed( title="Help channel opened", description=f"You claimed {message.channel.mention}.", colour=bot.constants.Colours.bright_green, @@ -118,7 +118,7 @@ async def dm_on_open(message: disnake.Message) -> None: try: await message.author.send(embed=embed) log.trace(f"Sent DM to {message.author.id} after claiming help channel.") - except disnake.errors.Forbidden: + except discord.errors.Forbidden: log.trace( f"Ignoring to send DM to {message.author.id} after claiming help channel: DMs disabled." ) @@ -146,7 +146,7 @@ async def notify_none_remaining(last_notification: Arrow) -> t.Optional[Arrow]: log.trace("Notifying about lack of channels.") mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_none_remaining_roles) - allowed_roles = [disnake.Object(id_) for id_ in constants.HelpChannels.notify_none_remaining_roles] + allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_none_remaining_roles] channel = bot.instance.get_channel(constants.HelpChannels.notify_channel) if channel is None: @@ -157,7 +157,7 @@ async def notify_none_remaining(last_notification: Arrow) -> t.Optional[Arrow]: f"{mentions} A new available help channel is needed but there " "are no more dormant ones. Consider freeing up some in-use channels manually by " f"using the `{constants.Bot.prefix}dormant` command within the channels.", - allowed_mentions=disnake.AllowedMentions(everyone=False, roles=allowed_roles) + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) ) except Exception: # Handle it here cause this feature isn't critical for the functionality of the system. @@ -213,18 +213,18 @@ async def notify_running_low(number_of_channels_left: int, last_notification: Ar return arrow.utcnow() -async def pin(message: disnake.Message) -> None: +async def pin(message: discord.Message) -> None: """Pin an initial question `message` and store it in a cache.""" if await pin_wrapper(message.id, message.channel, pin=True): await _caches.question_messages.set(message.channel.id, message.id) -async def send_available_message(channel: disnake.TextChannel) -> None: +async def send_available_message(channel: discord.TextChannel) -> None: """Send the available message by editing a dormant message or sending a new message.""" channel_info = f"#{channel} ({channel.id})" log.trace(f"Sending available message in {channel_info}.") - embed = disnake.Embed( + embed = discord.Embed( color=constants.Colours.bright_green, description=AVAILABLE_MSG, ) @@ -240,7 +240,7 @@ async def send_available_message(channel: disnake.TextChannel) -> None: await channel.send(embed=embed) -async def unpin(channel: disnake.TextChannel) -> None: +async def unpin(channel: discord.TextChannel) -> None: """Unpin the initial question message sent in `channel`.""" msg_id = await _caches.question_messages.pop(channel.id) if msg_id is None: @@ -249,19 +249,19 @@ async def unpin(channel: disnake.TextChannel) -> None: await pin_wrapper(msg_id, channel, pin=False) -def _match_bot_embed(message: t.Optional[disnake.Message], description: str) -> bool: +def _match_bot_embed(message: t.Optional[discord.Message], description: str) -> bool: """Return `True` if the bot's `message`'s embed description matches `description`.""" if not message or not message.embeds: return False bot_msg_desc = message.embeds[0].description - if bot_msg_desc is disnake.Embed.Empty: + if bot_msg_desc is discord.Embed.Empty: log.trace("Last message was a bot embed but it was empty.") return False return message.author == bot.instance.user and bot_msg_desc.strip() == description.strip() -async def pin_wrapper(msg_id: int, channel: disnake.TextChannel, *, pin: bool) -> bool: +async def pin_wrapper(msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: """ Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. @@ -277,7 +277,7 @@ async def pin_wrapper(msg_id: int, channel: disnake.TextChannel, *, pin: bool) - try: await func(channel.id, msg_id) - except disnake.HTTPException as e: + except discord.HTTPException as e: if e.code == 10008: log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") else: diff --git a/bot/exts/help_channels/_name.py b/bot/exts/help_channels/_name.py index 50b250cb5..a9d9b2df1 100644 --- a/bot/exts/help_channels/_name.py +++ b/bot/exts/help_channels/_name.py @@ -3,7 +3,7 @@ import typing as t from collections import deque from pathlib import Path -import disnake +import discord from bot import constants from bot.exts.help_channels._channel import MAX_CHANNELS_PER_CATEGORY, get_category_channels @@ -12,7 +12,7 @@ from bot.log import get_logger log = get_logger(__name__) -def create_name_queue(*categories: disnake.CategoryChannel) -> deque: +def create_name_queue(*categories: discord.CategoryChannel) -> deque: """ Return a queue of food names to use for creating new channels. @@ -50,7 +50,7 @@ def _get_names() -> t.List[str]: return all_names[:count] -def _get_used_names(*categories: disnake.CategoryChannel) -> t.Set[str]: +def _get_used_names(*categories: discord.CategoryChannel) -> t.Set[str]: """Return names which are already being used by channels in `categories`.""" log.trace("Getting channel names which are already being used.") diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py index 68eb52a59..f2f29020f 100644 --- a/bot/exts/info/code_snippets.py +++ b/bot/exts/info/code_snippets.py @@ -4,9 +4,9 @@ import textwrap from typing import Any from urllib.parse import quote_plus -import disnake +import discord from aiohttp import ClientResponseError -from disnake.ext.commands import Cog +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels @@ -241,7 +241,7 @@ class CodeSnippets(Cog): return '\n'.join(map(lambda x: x[1], sorted(all_snippets))) @Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" if message.author.bot: return @@ -255,7 +255,7 @@ class CodeSnippets(Cog): if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: try: await message.edit(suppress=True) - except disnake.NotFound: + except discord.NotFound: # Don't send snippets if the original message was deleted. return diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index cf8c7d0be..a859d8cef 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -1,9 +1,9 @@ import time from typing import Optional -import disnake -from disnake import Message, RawMessageUpdateEvent -from disnake.ext.commands import Cog +import discord +from discord import Message, RawMessageUpdateEvent +from discord.ext.commands import Cog from bot import constants from bot.bot import Bot @@ -62,9 +62,9 @@ class CodeBlockCog(Cog, name="Code Block"): self.codeblock_message_ids = {} @staticmethod - def create_embed(instructions: str) -> disnake.Embed: + def create_embed(instructions: str) -> discord.Embed: """Return an embed which displays code block formatting `instructions`.""" - return disnake.Embed(description=instructions) + return discord.Embed(description=instructions) async def get_sent_instructions(self, payload: RawMessageUpdateEvent) -> Optional[Message]: """ @@ -78,11 +78,11 @@ class CodeBlockCog(Cog, name="Code Block"): try: return await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - except disnake.NotFound: + except discord.NotFound: log.debug("Could not find instructions message; it was probably deleted.") return None - def is_on_cooldown(self, channel: disnake.TextChannel) -> bool: + def is_on_cooldown(self, channel: discord.TextChannel) -> bool: """ Return True if an embed was sent too recently for `channel`. @@ -93,7 +93,7 @@ class CodeBlockCog(Cog, name="Code Block"): cooldown = constants.CodeBlock.cooldown_seconds return (time.time() - self.channel_cooldowns.get(channel.id, 0)) < cooldown - def is_valid_channel(self, channel: disnake.TextChannel) -> bool: + def is_valid_channel(self, channel: discord.TextChannel) -> bool: """Return True if `channel` is a help channel, may be on a cooldown, or is whitelisted.""" log.trace(f"Checking if #{channel} qualifies for code block detection.") return ( @@ -102,7 +102,7 @@ class CodeBlockCog(Cog, name="Code Block"): or channel.id in constants.CodeBlock.channel_whitelist ) - async def send_instructions(self, message: disnake.Message, instructions: str) -> None: + async def send_instructions(self, message: discord.Message, instructions: str) -> None: """ Send an embed with `instructions` on fixing an incorrect code block in a `message`. @@ -119,7 +119,7 @@ class CodeBlockCog(Cog, name="Code Block"): # Increase amount of codeblock correction in stats self.bot.stats.incr("codeblock_corrections") - def should_parse(self, message: disnake.Message) -> bool: + def should_parse(self, message: discord.Message) -> bool: """ Return True if `message` should be parsed. @@ -185,5 +185,5 @@ class CodeBlockCog(Cog, name="Code Block"): else: log.info("Message edited but still has invalid code blocks; editing instructions.") await bot_message.edit(embed=self.create_embed(instructions)) - except disnake.NotFound: + except discord.NotFound: log.debug("Could not find instructions message; it was probably deleted.") diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py index 487a0fd21..c27f28eac 100644 --- a/bot/exts/info/doc/_batch_parser.py +++ b/bot/exts/info/doc/_batch_parser.py @@ -7,7 +7,7 @@ from contextlib import suppress from operator import attrgetter from typing import Deque, Dict, List, NamedTuple, Optional, Union -import disnake +import discord from bs4 import BeautifulSoup import bot @@ -48,7 +48,7 @@ class StaleInventoryNotifier: if await self.symbol_counter.increment_for(doc_item) < 3: self._warned_urls.add(doc_item.url) await self._init_task - embed = disnake.Embed( + embed = discord.Embed( description=f"Doc item `{doc_item.symbol_id=}` present in loaded documentation inventories " f"not found on [site]({doc_item.url}), inventories may need to be refreshed." ) diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 77fc61389..4dc5276d9 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -9,8 +9,8 @@ from types import SimpleNamespace from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp -import disnake -from disnake.ext import commands +import discord +from discord.ext import commands from bot.api import ResponseCodeError from bot.bot import Bot @@ -275,7 +275,7 @@ class DocCog(commands.Cog): return "Unable to parse the requested symbol." return markdown - async def create_symbol_embed(self, symbol_name: str) -> Optional[disnake.Embed]: + async def create_symbol_embed(self, symbol_name: str) -> Optional[discord.Embed]: """ Attempt to scrape and fetch the data for the given `symbol_name`, and build an embed from its contents. @@ -304,8 +304,8 @@ class DocCog(commands.Cog): else: footer_text = "" - embed = disnake.Embed( - title=disnake.utils.escape_markdown(symbol_name), + embed = discord.Embed( + title=discord.utils.escape_markdown(symbol_name), url=f"{doc_item.url}#{doc_item.symbol_id}", description=await self.get_symbol_markdown(doc_item) ) @@ -331,9 +331,9 @@ class DocCog(commands.Cog): !docs getdoc aiohttp.ClientSession """ if not symbol_name: - inventory_embed = disnake.Embed( + inventory_embed = discord.Embed( title=f"All inventories (`{len(self.base_urls)}` total)", - colour=disnake.Colour.blue() + colour=discord.Colour.blue() ) lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) @@ -355,7 +355,7 @@ class DocCog(commands.Cog): # Make sure that we won't cause a ghost-ping by deleting the message if not (ctx.message.mentions or ctx.message.role_mentions): - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await ctx.message.delete() await error_message.delete() @@ -449,7 +449,7 @@ class DocCog(commands.Cog): if removed := ", ".join(old_inventories - new_inventories): removed = "- " + removed - embed = disnake.Embed( + embed = discord.Embed( title="Inventories refreshed", description=f"```diff\n{added}\n{removed}```" if added or removed else "" ) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py index 29d73c564..864e7edd2 100644 --- a/bot/exts/info/help.py +++ b/bot/exts/info/help.py @@ -6,8 +6,8 @@ from collections import namedtuple from contextlib import suppress from typing import List, Optional, Union -from disnake import ButtonStyle, Colour, Embed, Emoji, Interaction, PartialEmoji, ui -from disnake.ext.commands import Bot, Cog, Command, CommandError, Context, DisabledCommand, Group, HelpCommand +from discord import ButtonStyle, Colour, Embed, Emoji, Interaction, PartialEmoji, ui +from discord.ext.commands import Bot, Cog, Command, CommandError, Context, DisabledCommand, Group, HelpCommand from rapidfuzz import fuzz, process from rapidfuzz.utils import default_process diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 44a9b8f1a..e616b9208 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,9 +6,9 @@ from textwrap import shorten from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union import rapidfuzz -from disnake import AllowedMentions, Colour, Embed, Guild, Message, Role -from disnake.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role -from disnake.utils import escape_markdown +from discord import AllowedMentions, Colour, Embed, Guild, Message, Role +from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role +from discord.utils import escape_markdown from bot import constants from bot.api import ResponseCodeError @@ -466,7 +466,7 @@ class Information(Cog): async def send_raw_content(self, ctx: Context, message: Message, json: bool = False) -> None: """ - Send information about the raw API response for a `disnake.Message`. + Send information about the raw API response for a `discord.Message`. If `json` is True, send the information in a copy-pasteable Python format. """ diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 08c693581..67866620b 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -3,8 +3,8 @@ from email.parser import HeaderParser from io import StringIO from typing import Dict, Optional, Tuple -from disnake import Colour, Embed -from disnake.ext.commands import Cog, Context, command +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, command from bot.bot import Bot from bot.constants import Keys diff --git a/bot/exts/info/pypi.py b/bot/exts/info/pypi.py index 0a7705eb0..dacf7bc12 100644 --- a/bot/exts/info/pypi.py +++ b/bot/exts/info/pypi.py @@ -3,9 +3,9 @@ import random import re from contextlib import suppress -from disnake import Embed, NotFound -from disnake.ext.commands import Cog, Context, command -from disnake.utils import escape_markdown +from discord import Embed, NotFound +from discord.ext.commands import Cog, Context, command +from discord.utils import escape_markdown from bot.bot import Bot from bot.constants import Colours, NEGATIVE_REPLIES, RedirectOutput diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py index 7603b402b..2fad9d2ab 100644 --- a/bot/exts/info/python_news.py +++ b/bot/exts/info/python_news.py @@ -2,11 +2,11 @@ import re import typing as t from datetime import date, datetime -import disnake +import discord import feedparser from bs4 import BeautifulSoup -from disnake.ext.commands import Cog -from disnake.ext.tasks import loop +from discord.ext.commands import Cog +from discord.ext.tasks import loop from bot import constants from bot.bot import Bot @@ -40,7 +40,7 @@ class PythonNews(Cog): def __init__(self, bot: Bot): self.bot = bot self.webhook_names = {} - self.webhook: t.Optional[disnake.Webhook] = None + self.webhook: t.Optional[discord.Webhook] = None scheduling.create_task(self.get_webhook_names(), event_loop=self.bot.loop) scheduling.create_task(self.get_webhook_and_channel(), event_loop=self.bot.loop) @@ -119,7 +119,7 @@ class PythonNews(Cog): continue # Build an embed and send a webhook - embed = disnake.Embed( + embed = discord.Embed( title=self.escape_markdown(new["title"]), description=self.escape_markdown(new["summary"]), timestamp=new_datetime, @@ -189,7 +189,7 @@ class PythonNews(Cog): link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) # Build an embed and send a message to the webhook - embed = disnake.Embed( + embed = discord.Embed( title=self.escape_markdown(thread_information["subject"]), description=content[:1000] + f"... [continue reading]({link})" if len(content) > 1000 else content, timestamp=new_date, diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 6305a9842..e3e7029ca 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -2,8 +2,8 @@ import inspect from pathlib import Path from typing import Optional, Tuple, Union -from disnake import Embed -from disnake.ext import commands +from discord import Embed +from discord.ext import commands from bot.bot import Bot from bot.constants import URLs diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py index 08422b38e..4d8bb645e 100644 --- a/bot/exts/info/stats.py +++ b/bot/exts/info/stats.py @@ -1,8 +1,8 @@ import string -from disnake import Member, Message -from disnake.ext.commands import Cog, Context -from disnake.ext.tasks import loop +from discord import Member, Message +from discord.ext.commands import Cog, Context +from discord.ext.tasks import loop from bot.bot import Bot from bot.constants import Categories, Channels, Guild diff --git a/bot/exts/info/subscribe.py b/bot/exts/info/subscribe.py index 0f285e0cb..eff0c13b8 100644 --- a/bot/exts/info/subscribe.py +++ b/bot/exts/info/subscribe.py @@ -4,9 +4,9 @@ import typing as t from dataclasses import dataclass import arrow -import disnake -from disnake.ext import commands -from disnake.interactions import Interaction +import discord +from discord.ext import commands +from discord.interactions import Interaction from bot import constants from bot.bot import Bot @@ -58,10 +58,10 @@ DELETE_MESSAGE_AFTER = 300 # Seconds log = get_logger(__name__) -class RoleButtonView(disnake.ui.View): +class RoleButtonView(discord.ui.View): """A list of SingleRoleButtons to show to the member.""" - def __init__(self, member: disnake.Member): + def __init__(self, member: discord.Member): super().__init__() self.interaction_owner = member @@ -76,12 +76,12 @@ class RoleButtonView(disnake.ui.View): return True -class SingleRoleButton(disnake.ui.Button): +class SingleRoleButton(discord.ui.Button): """A button that adds or removes a role from the member depending on it's current state.""" - ADD_STYLE = disnake.ButtonStyle.success - REMOVE_STYLE = disnake.ButtonStyle.red - UNAVAILABLE_STYLE = disnake.ButtonStyle.secondary + ADD_STYLE = discord.ButtonStyle.success + REMOVE_STYLE = discord.ButtonStyle.red + UNAVAILABLE_STYLE = discord.ButtonStyle.secondary LABEL_FORMAT = "{action} role {role_name}." CUSTOM_ID_FORMAT = "subscribe-{role_id}" @@ -104,7 +104,7 @@ class SingleRoleButton(disnake.ui.Button): async def callback(self, interaction: Interaction) -> None: """Update the member's role and change button text to reflect current text.""" - if isinstance(interaction.user, disnake.User): + if isinstance(interaction.user, discord.User): log.trace("User %s is not a member", interaction.user) await interaction.message.delete() self.view.stop() @@ -117,7 +117,7 @@ class SingleRoleButton(disnake.ui.Button): await members.handle_role_change( interaction.user, interaction.user.remove_roles if self.assigned else interaction.user.add_roles, - disnake.Object(self.role.role_id), + discord.Object(self.role.role_id), ) self.assigned = not self.assigned @@ -133,7 +133,7 @@ class SingleRoleButton(disnake.ui.Button): self.label = self.LABEL_FORMAT.format(action="Remove" if self.assigned else "Add", role_name=self.role.name) try: await interaction.message.edit(view=self.view) - except disnake.NotFound: + except discord.NotFound: log.debug("Subscribe message for %s removed before buttons could be updated", interaction.user) self.view.stop() @@ -145,7 +145,7 @@ class Subscribe(commands.Cog): self.bot = bot self.init_task = scheduling.create_task(self.init_cog(), event_loop=self.bot.loop) self.assignable_roles: list[AssignableRole] = [] - self.guild: disnake.Guild = None + self.guild: discord.Guild = None async def init_cog(self) -> None: """Initialise the cog by resolving the role IDs in ASSIGNABLE_ROLES to role names.""" diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index baeb21adb..f66237c8e 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -6,10 +6,10 @@ import time from pathlib import Path from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union -import disnake +import discord import frontmatter -from disnake import Embed, Member -from disnake.ext.commands import Cog, Context, group +from discord import Embed, Member +from discord.ext.commands import Cog, Context, group from bot import constants from bot.bot import Bot @@ -81,7 +81,7 @@ class Tag: self.content = post.content self.metadata = post.metadata self._restricted_to: set[int] = set(self.metadata.get("restricted_to", ())) - self._cooldowns: dict[disnake.TextChannel, float] = {} + self._cooldowns: dict[discord.TextChannel, float] = {} @property def embed(self) -> Embed: @@ -90,18 +90,18 @@ class Tag: embed.description = self.content return embed - def accessible_by(self, member: disnake.Member) -> bool: + def accessible_by(self, member: discord.Member) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to or self._restricted_to & {role.id for role in member.roles} ) - def on_cooldown_in(self, channel: disnake.TextChannel) -> bool: + def on_cooldown_in(self, channel: discord.TextChannel) -> bool: """Check whether the tag is on cooldown in `channel`.""" return self._cooldowns.get(channel, float("-inf")) > time.time() - def set_cooldown_for(self, channel: disnake.TextChannel) -> None: + def set_cooldown_for(self, channel: discord.TextChannel) -> None: """Set the tag to be on cooldown in `channel` for `constants.Cooldowns.tags` seconds.""" self._cooldowns[channel] = time.time() + constants.Cooldowns.tags @@ -344,7 +344,7 @@ class Tags(Cog): return result_lines - def accessible_tags_in_group(self, group: str, user: disnake.Member) -> list[str]: + def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: """Return a formatted list of tags in `group`, that are accessible by `user`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index 2e274b23b..cb6836258 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -7,10 +7,10 @@ from datetime import datetime from itertools import takewhile from typing import Callable, Iterable, Literal, Optional, TYPE_CHECKING, Union -from disnake import Colour, Message, NotFound, TextChannel, User, errors -from disnake.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role -from disnake.ext.commands.converter import TextChannelConverter -from disnake.ext.commands.errors import BadArgument +from discord import Colour, Message, NotFound, TextChannel, User, errors +from discord.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role +from discord.ext.commands.converter import TextChannelConverter +from discord.ext.commands.errors import BadArgument from bot.bot import Bot from bot.constants import Channels, CleanMessages, Colours, Emojis, Event, Icons, MODERATION_ROLES @@ -459,7 +459,7 @@ class Clean(Cog): regex: Optional[Regex] = None, bots_only: Optional[bool] = False, *, - channels: CleanChannels = None # "Optional" with disnake silently ignores incorrect input. + channels: CleanChannels = None # "Optional" with discord.py silently ignores incorrect input. ) -> None: """ Commands for cleaning messages in channels. diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 58e049d4f..178be734d 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -8,9 +8,9 @@ import arrow from aioredis import RedisError from async_rediscache import RedisCache from dateutil.relativedelta import relativedelta -from disnake import Colour, Embed, Forbidden, Member, TextChannel, User -from disnake.ext import tasks -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord import Colour, Embed, Forbidden, Member, TextChannel, User +from discord.ext import tasks +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_ROLES, Roles diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 28e131eb4..566422e29 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -1,5 +1,5 @@ -import disnake -from disnake.ext.commands import Cog, Context, command, has_any_role +import discord +from discord.ext.commands import Cog, Context, command, has_any_role from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES @@ -17,7 +17,7 @@ class DMRelay(Cog): self.bot = bot @command(aliases=("relay", "dr")) - async def dmrelay(self, ctx: Context, user: disnake.User, limit: int = 100) -> None: + async def dmrelay(self, ctx: Context, user: discord.User, limit: int = 100) -> None: """Relays the direct message history between the bot and given user.""" log.trace(f"Relaying DMs with {user.name} ({user.id})") diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index c4c03e546..b579416a6 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -4,9 +4,9 @@ from datetime import datetime from enum import Enum from typing import Optional -import disnake +import discord from async_rediscache import RedisCache -from disnake.ext.commands import Cog, Context, MessageConverter, MessageNotFound +from discord.ext.commands import Cog, Context, MessageConverter, MessageNotFound from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks @@ -52,10 +52,10 @@ ALL_SIGNALS: set[str] = {signal.value for signal in Signal} # An embed coupled with an optional file to be dispatched # If the file is not None, the embed attempts to show it in its body -FileEmbed = tuple[disnake.Embed, Optional[disnake.File]] +FileEmbed = tuple[discord.Embed, Optional[discord.File]] -async def download_file(attachment: disnake.Attachment) -> Optional[disnake.File]: +async def download_file(attachment: discord.Attachment) -> Optional[discord.File]: """ Download & return `attachment` file. @@ -65,13 +65,13 @@ async def download_file(attachment: disnake.Attachment) -> Optional[disnake.File log.debug(f"Attempting to download attachment: {attachment.filename}") try: return await attachment.to_file() - except (disnake.NotFound, disnake.Forbidden) as exc: + except (discord.NotFound, discord.Forbidden) as exc: log.debug(f"Failed to download attachment: {exc}") except Exception: log.exception("Failed to download attachment") -async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: disnake.Member) -> FileEmbed: +async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: """ Create an embed representation of `incident` for the #incidents-archive channel. @@ -97,7 +97,7 @@ async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: di colour = Colours.soft_red footer = f"Rejected by {actioned_by}" - embed = disnake.Embed( + embed = discord.Embed( description=incident.content, timestamp=datetime.utcnow(), colour=colour, @@ -113,12 +113,12 @@ async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: di else: embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file else: - file = disnake.utils.MISSING + file = discord.utils.MISSING return embed, file -def is_incident(message: disnake.Message) -> bool: +def is_incident(message: discord.Message) -> bool: """True if `message` qualifies as an incident, False otherwise.""" conditions = ( message.channel.id == Channels.incidents, # Message sent in #incidents @@ -129,12 +129,12 @@ def is_incident(message: disnake.Message) -> bool: return all(conditions) -def own_reactions(message: disnake.Message) -> set[str]: +def own_reactions(message: discord.Message) -> set[str]: """Get the set of reactions placed on `message` by the bot itself.""" return {str(reaction.emoji) for reaction in message.reactions if reaction.me} -def has_signals(message: disnake.Message) -> bool: +def has_signals(message: discord.Message) -> bool: """True if `message` already has all `Signal` reactions, False otherwise.""" return ALL_SIGNALS.issubset(own_reactions(message)) @@ -167,9 +167,9 @@ def shorten_text(text: str) -> str: return text -async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[disnake.Embed]: +async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[discord.Embed]: """ - Create an embedded representation of the Discord message link contained in the incident report. + Create an embedded representation of the discord message link contained in the incident report. The Embed would contain the following information --> Author: @Jason Terror ♦ (736234578745884682) @@ -179,23 +179,23 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d embed = None try: - message: disnake.Message = await MessageConverter().convert(ctx, message_link) + message: discord.Message = await MessageConverter().convert(ctx, message_link) except MessageNotFound: mod_logs_channel = ctx.bot.get_channel(Channels.mod_log) - last_100_logs: list[disnake.Message] = await mod_logs_channel.history(limit=100).flatten() + last_100_logs: list[discord.Message] = await mod_logs_channel.history(limit=100).flatten() for log_entry in last_100_logs: if not log_entry.embeds: continue - log_embed: disnake.Embed = log_entry.embeds[0] + log_embed: discord.Embed = log_entry.embeds[0] if ( log_embed.author.name == "Message deleted" and f"[Jump to message]({message_link})" in log_embed.description ): - embed = disnake.Embed( - colour=disnake.Colour.dark_gold(), + embed = discord.Embed( + colour=discord.Colour.dark_gold(), title="Deleted Message Link", description=( f"Found <#{Channels.mod_log}> entry for deleted message: " @@ -203,12 +203,12 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d ) ) if not embed: - embed = disnake.Embed( - colour=disnake.Colour.red(), + embed = discord.Embed( + colour=discord.Colour.red(), title="Bad Message Link", description=f"Message {message_link} not found." ) - except disnake.DiscordException as e: + except discord.DiscordException as e: log.exception(f"Failed to make message link embed for '{message_link}', raised exception: {e}") else: channel = message.channel @@ -219,12 +219,12 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d ) return - embed = disnake.Embed( - colour=disnake.Colour.gold(), + embed = discord.Embed( + colour=discord.Colour.gold(), description=( f"**Author:** {format_user(message.author)}\n" f"**Channel:** {channel.mention} ({channel.category}" - f"{f'/#{channel.parent.name} - ' if isinstance(channel, disnake.Thread) else '/#'}" + f"{f'/#{channel.parent.name} - ' if isinstance(channel, discord.Thread) else '/#'}" f"{channel.name})\n" ), timestamp=message.created_at @@ -242,7 +242,7 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d return embed -async def add_signals(incident: disnake.Message) -> None: +async def add_signals(incident: discord.Message) -> None: """ Add `Signal` member emoji to `incident` as reactions. @@ -257,7 +257,7 @@ async def add_signals(incident: disnake.Message) -> None: log.trace(f"Adding reaction: {signal_emoji}") try: await incident.add_reaction(signal_emoji.value) - except disnake.NotFound as e: + except discord.NotFound as e: if e.code != 10008: raise @@ -300,7 +300,7 @@ class Incidents(Cog): """ # This dictionary maps an incident report message to the message link embed's ID - # RedisCache[disnake.Message.id, disnake.Message.id] + # RedisCache[discord.Message.id, discord.Message.id] message_link_embeds_cache = RedisCache() def __init__(self, bot: Bot) -> None: @@ -319,7 +319,7 @@ class Incidents(Cog): try: self.incidents_webhook = await self.bot.fetch_webhook(Webhooks.incidents) - except disnake.HTTPException: + except discord.HTTPException: log.error(f"Failed to fetch incidents webhook with id `{Webhooks.incidents}`.") async def crawl_incidents(self) -> None: @@ -335,7 +335,7 @@ class Incidents(Cog): Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. """ await self.bot.wait_until_guild_available() - incidents: disnake.TextChannel = self.bot.get_channel(Channels.incidents) + incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") async for message in incidents.history(limit=CRAWL_LIMIT): @@ -353,7 +353,7 @@ class Incidents(Cog): log.debug("Crawl task finished!") - async def archive(self, incident: disnake.Message, outcome: Signal, actioned_by: disnake.Member) -> bool: + async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: """ Relay an embed representation of `incident` to the #incidents-archive channel. @@ -392,7 +392,7 @@ class Incidents(Cog): log.trace("Message archived successfully!") return True - def make_confirmation_task(self, incident: disnake.Message, timeout: int = 5) -> asyncio.Task: + def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: """ Create a task to wait `timeout` seconds for `incident` to be deleted. @@ -401,13 +401,13 @@ class Incidents(Cog): """ log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") - def check(payload: disnake.RawReactionActionEvent) -> bool: + def check(payload: discord.RawReactionActionEvent) -> bool: return payload.message_id == incident.id coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) return scheduling.create_task(coroutine, event_loop=self.bot.loop) - async def process_event(self, reaction: str, incident: disnake.Message, member: disnake.Member) -> None: + async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: """ Process a `reaction_add` event in #incidents. @@ -430,7 +430,7 @@ class Incidents(Cog): log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") try: await incident.remove_reaction(reaction, member) - except disnake.NotFound: + except discord.NotFound: log.trace("Couldn't remove reaction because the reaction or its message was deleted") return @@ -440,7 +440,7 @@ class Incidents(Cog): log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") try: await incident.remove_reaction(reaction, member) - except disnake.NotFound: + except discord.NotFound: log.trace("Couldn't remove reaction because the reaction or its message was deleted") return @@ -461,7 +461,7 @@ class Incidents(Cog): log.trace("Deleting original message") try: await incident.delete() - except disnake.NotFound: + except discord.NotFound: log.trace("Couldn't delete message because it was already deleted") log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") @@ -476,9 +476,9 @@ class Incidents(Cog): # Deletes the message link embeds found in cache from the channel and cache. await self.delete_msg_link_embed(incident.id) - async def resolve_message(self, message_id: int) -> Optional[disnake.Message]: + async def resolve_message(self, message_id: int) -> Optional[discord.Message]: """ - Get `disnake.Message` for `message_id` from cache, or API. + Get `discord.Message` for `message_id` from cache, or API. We first look into the local cache to see if the message is present. @@ -491,7 +491,7 @@ class Incidents(Cog): """ await self.bot.wait_until_guild_available() # First make sure that the cache is ready log.trace(f"Resolving message for: {message_id=}") - message: Optional[disnake.Message] = self.bot._connection._get_message(message_id) + message: Optional[discord.Message] = self.bot._connection._get_message(message_id) if message is not None: log.trace("Message was found in cache") @@ -500,7 +500,7 @@ class Incidents(Cog): log.trace("Message not found, attempting to fetch") try: message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) - except disnake.NotFound: + except discord.NotFound: log.trace("Message doesn't exist, it was likely already relayed") except Exception: log.exception(f"Failed to fetch message {message_id}!") @@ -509,7 +509,7 @@ class Incidents(Cog): return message @Cog.listener() - async def on_raw_reaction_add(self, payload: disnake.RawReactionActionEvent) -> None: + async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: """ Pre-process `payload` and pass it to `process_event` if appropriate. @@ -521,11 +521,11 @@ class Incidents(Cog): Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. - Once we have the lock, the `disnake.Message` object for this event must be resolved. + Once we have the lock, the `discord.Message` object for this event must be resolved. If the lock was previously held by an event which successfully relayed the incident, this will fail and we abort the current event. - Finally, with both the lock and the `disnake.Message` instance in our hands, we delegate + Finally, with both the lock and the `discord.Message` instance in our hands, we delegate to `process_event` to handle the event. The justification for using a raw listener is the need to receive events for messages @@ -554,7 +554,7 @@ class Incidents(Cog): log.trace("Releasing event lock") @Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """ Pass `message` to `add_signals` and `extract_message_links` if it satisfies `is_incident`. @@ -575,7 +575,7 @@ class Incidents(Cog): await self.send_message_link_embeds(embed_list, message, self.incidents_webhook) @Cog.listener() - async def on_raw_message_delete(self, payload: disnake.RawMessageDeleteEvent) -> None: + async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent) -> None: """ Delete message link embeds for `payload.message_id`. @@ -584,7 +584,7 @@ class Incidents(Cog): if self.incidents_webhook: await self.delete_msg_link_embed(payload.message_id) - async def extract_message_links(self, message: disnake.Message) -> Optional[list[disnake.Embed]]: + async def extract_message_links(self, message: discord.Message) -> Optional[list[discord.Embed]]: """ Check if there's any message links in the text content. @@ -615,8 +615,8 @@ class Incidents(Cog): async def send_message_link_embeds( self, webhook_embed_list: list, - message: disnake.Message, - webhook: disnake.Webhook, + message: discord.Message, + webhook: discord.Webhook, ) -> Optional[int]: """ Send message link embeds to #incidents channel. @@ -634,7 +634,7 @@ class Incidents(Cog): avatar_url=message.author.display_avatar.url, wait=True, ) - except disnake.DiscordException: + except discord.DiscordException: log.exception( f"Failed to send message link embed {message.id} to #incidents." ) @@ -651,7 +651,7 @@ class Incidents(Cog): if webhook_msg_id: try: await self.incidents_webhook.delete_message(webhook_msg_id) - except disnake.errors.NotFound: + except discord.errors.NotFound: log.trace(f"Incidents message link embed (`{webhook_msg_id}`) has already been deleted, skipping.") await self.message_link_embeds_cache.delete(message_id) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 8107b502a..2fc54856f 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -5,8 +5,8 @@ from gettext import ngettext import arrow import dateutil.parser -import disnake -from disnake.ext.commands import Context +import discord +from discord.ext.commands import Context from bot import constants from bot.api import ResponseCodeError @@ -101,7 +101,7 @@ class InfractionScheduler: # Allowing mod log since this is a passive action that should be logged. try: await apply_coro - except disnake.HTTPException as e: + except discord.HTTPException as e: # When user joined and then right after this left again before action completed, this can't apply roles if e.code == 10007 or e.status == 404: log.info( @@ -200,7 +200,7 @@ class InfractionScheduler: if expiry: # Schedule the expiration of the infraction. self.schedule_expiration(infraction) - except disnake.HTTPException as e: + except discord.HTTPException as e: # Accordingly display that applying the infraction failed. # Don't use ctx.message.author; antispam only patches ctx.author. confirm_msg = ":x: failed to apply" @@ -209,7 +209,7 @@ class InfractionScheduler: log_title = "failed to apply" log_msg = f"Failed to apply {' '.join(infr_type.split('_'))} infraction #{id_} to {user}" - if isinstance(e, disnake.Forbidden): + if isinstance(e, discord.Forbidden): log.warning(f"{log_msg}: bot lacks permissions.") elif e.code == 10007 or e.status == 404: log.info( @@ -396,11 +396,11 @@ class InfractionScheduler: raise ValueError( f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) - except disnake.Forbidden: + except discord.Forbidden: log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention - except disnake.HTTPException as e: + except discord.HTTPException as e: if e.code == 10007 or e.status == 404: log.info( f"Can't pardon {infraction['type']} for user {infraction['user']} because user left the guild." diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 36e818ec6..c1be18362 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -2,8 +2,8 @@ import typing as t from datetime import datetime import arrow -import disnake -from disnake.ext.commands import Context +import discord +from discord.ext.commands import Context import bot from bot.api import ResponseCodeError @@ -86,7 +86,7 @@ async def post_infraction( dm_sent: bool = False, ) -> t.Optional[dict]: """Posts an infraction to the API.""" - if isinstance(user, (disnake.Member, disnake.User)) and user.bot: + if isinstance(user, (discord.Member, discord.User)) and user.bot: log.trace(f"Posting of {infr_type} infraction for {user} to the API aborted. User is a bot.") raise InvalidInfractedUserError(user) @@ -209,7 +209,7 @@ async def notify_infraction( text += INFRACTION_APPEAL_SERVER_FOOTER if infraction["type"] == 'ban' else INFRACTION_APPEAL_MODMAIL_FOOTER - embed = disnake.Embed( + embed = discord.Embed( description=text, colour=Colours.soft_red ) @@ -238,7 +238,7 @@ async def notify_pardon( """DM a user about their pardoned infraction and return True if the DM is successful.""" log.trace(f"Sending {user} a DM about their pardoned infraction.") - embed = disnake.Embed( + embed = discord.Embed( description=content, colour=Colours.soft_green ) @@ -248,7 +248,7 @@ async def notify_pardon( return await send_private_embed(user, embed) -async def send_private_embed(user: MemberOrUser, embed: disnake.Embed) -> bool: +async def send_private_embed(user: MemberOrUser, embed: discord.Embed) -> bool: """ A helper method for sending an embed to a user's DMs. @@ -257,7 +257,7 @@ async def send_private_embed(user: MemberOrUser, embed: disnake.Embed) -> bool: try: await user.send(embed=embed) return True - except (disnake.HTTPException, disnake.Forbidden, disnake.NotFound): + except (discord.HTTPException, discord.Forbidden, discord.NotFound): log.debug( f"Infraction-related information could not be sent to user {user} ({user.id}). " "The user either could not be retrieved or probably disabled their DMs." diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 5ff56abde..af42ab1b8 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -1,10 +1,10 @@ import textwrap import typing as t -import disnake -from disnake import Member -from disnake.ext import commands -from disnake.ext.commands import Context, command +import discord +from discord import Member +from discord.ext import commands +from discord.ext.commands import Context, command from bot import constants from bot.bot import Bot @@ -35,8 +35,8 @@ class Infractions(InfractionScheduler, commands.Cog): super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning", "voice_mute"}) self.category = "Moderation" - self._muted_role = disnake.Object(constants.Roles.muted) - self._voice_verified_role = disnake.Object(constants.Roles.voice_verified) + self._muted_role = discord.Object(constants.Roles.muted) + self._voice_verified_role = discord.Object(constants.Roles.voice_verified) @commands.Cog.listener() async def on_member_join(self, member: Member) -> None: @@ -123,7 +123,7 @@ class Infractions(InfractionScheduler, commands.Cog): log.error("Failed to apply ban to user %d", user.id) return - # Calling commands directly skips disnake's convertors, so we need to convert args manually. + # Calling commands directly skips Discord.py's convertors, so we need to convert args manually. clean_time = await Age().convert(ctx, "1h") log_url = await clean_cog._clean_messages( @@ -494,7 +494,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def pardon_mute( self, user_id: int, - guild: disnake.Guild, + guild: discord.Guild, reason: t.Optional[str], *, notify: bool = True @@ -525,16 +525,16 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def pardon_ban(self, user_id: int, guild: disnake.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: """Remove a user's ban on the Discord guild and return a log dict.""" - user = disnake.Object(user_id) + user = discord.Object(user_id) log_text = {} self.mod_log.ignore(Event.member_unban, user_id) try: await guild.unban(user, reason=reason) - except disnake.NotFound: + except discord.NotFound: log.info(f"Failed to unban user {user_id}: no active ban found on Discord") log_text["Note"] = "No active ban found on Discord." @@ -543,7 +543,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def pardon_voice_mute( self, user_id: int, - guild: disnake.Guild, + guild: discord.Guild, *, notify: bool = True ) -> t.Dict[str, str]: @@ -597,7 +597,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def cog_command_error(self, ctx: Context, error: Exception) -> None: """Send a notification to the invoking context on a Union failure.""" if isinstance(error, commands.BadUnionArgument): - if disnake.User in error.converters or Member in error.converters: + if discord.User in error.converters or Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index 25420cd7a..c12dff928 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -1,10 +1,10 @@ import textwrap import typing as t -import disnake -from disnake.ext import commands -from disnake.ext.commands import Context -from disnake.utils import escape_markdown +import discord +from discord.ext import commands +from discord.ext.commands import Context +from discord.utils import escape_markdown from bot import constants from bot.bot import Bot @@ -52,9 +52,9 @@ class ModManagement(commands.Cog): await ctx.send_help(ctx.command) return - embed = disnake.Embed( + embed = discord.Embed( title=f"Infraction #{infraction['id']}", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, [infraction]) @@ -222,7 +222,7 @@ class ModManagement(commands.Cog): await self.mod_log.send_log_message( icon_url=constants.Icons.pencil, - colour=disnake.Colour.og_blurple(), + colour=discord.Colour.og_blurple(), title="Infraction edited", thumbnail=thumbnail, text=textwrap.dedent(f""" @@ -240,21 +240,21 @@ class ModManagement(commands.Cog): async def infraction_search_group(self, ctx: Context, query: t.Union[UnambiguousUser, Snowflake, str]) -> None: """Searches for infractions in the database.""" if isinstance(query, int): - await self.search_user(ctx, disnake.Object(query)) + await self.search_user(ctx, discord.Object(query)) elif isinstance(query, str): await self.search_reason(ctx, query) else: await self.search_user(ctx, query) @infraction_search_group.command(name="user", aliases=("member", "userid")) - async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, disnake.Object]) -> None: + async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, discord.Object]) -> None: """Search for infractions by member.""" infraction_list = await self.bot.api_client.get( 'bot/infractions/expanded', params={'user__id': str(user.id)} ) - if isinstance(user, (disnake.Member, disnake.User)): + if isinstance(user, (discord.Member, discord.User)): user_str = escape_markdown(str(user)) else: if infraction_list: @@ -264,9 +264,9 @@ class ModManagement(commands.Cog): user_str = str(user.id) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = disnake.Embed( + embed = discord.Embed( title=f"Infractions for {user_str} ({formatted_infraction_count} total)", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -279,9 +279,9 @@ class ModManagement(commands.Cog): ) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = disnake.Embed( + embed = discord.Embed( title=f"Infractions matching `{reason}` ({formatted_infraction_count} total)", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -319,9 +319,9 @@ class ModManagement(commands.Cog): ) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = disnake.Embed( + embed = discord.Embed( title=f"Infractions by {actor} ({formatted_infraction_count} total)", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -344,7 +344,7 @@ class ModManagement(commands.Cog): async def send_infraction_list( self, ctx: Context, - embed: disnake.Embed, + embed: discord.Embed, infractions: t.Iterable[t.Dict[str, t.Any]] ) -> None: """Send a paginated embed of infractions for the specified user.""" @@ -433,7 +433,7 @@ class ModManagement(commands.Cog): async def cog_command_error(self, ctx: Context, error: commands.CommandError) -> None: """Handles errors for commands within this cog.""" if isinstance(error, commands.BadUnionArgument): - if disnake.User in error.converters: + if discord.User in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 41ba52580..b91a5edba 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -4,9 +4,9 @@ import textwrap import typing as t from pathlib import Path -from disnake import Embed, Member -from disnake.ext.commands import Cog, Context, command, has_any_role -from disnake.utils import escape_markdown +from discord import Embed, Member +from discord.ext.commands import Cog, Context, command, has_any_role +from discord.utils import escape_markdown from bot import constants from bot.bot import Bot diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index 482d49b83..ce9c220b3 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -8,7 +8,7 @@ import arrow from aiohttp.client_exceptions import ClientResponseError from arrow import Arrow from async_rediscache import RedisCache -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Metabase as MetabaseConfig, Roles diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index a96638e53..32ea0dc6a 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -5,13 +5,13 @@ import typing as t from datetime import datetime, timezone from itertools import zip_longest -import disnake +import discord from dateutil.relativedelta import relativedelta from deepdiff import DeepDiff -from disnake import Colour, Message, Thread -from disnake.abc import GuildChannel -from disnake.ext.commands import Cog, Context -from disnake.utils import escape_markdown +from discord import Colour, Message, Thread +from discord.abc import GuildChannel +from discord.ext.commands import Cog, Context +from discord.utils import escape_markdown from bot.bot import Bot from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs @@ -21,7 +21,7 @@ from bot.utils.messages import format_user log = get_logger(__name__) -GUILD_CHANNEL = t.Union[disnake.CategoryChannel, disnake.TextChannel, disnake.VoiceChannel] +GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") @@ -45,7 +45,7 @@ class ModLog(Cog, name="ModLog"): async def upload_log( self, - messages: t.Iterable[disnake.Message], + messages: t.Iterable[discord.Message], actor_id: int, attachments: t.Iterable[t.List[str]] = None ) -> str: @@ -83,22 +83,22 @@ class ModLog(Cog, name="ModLog"): async def send_log_message( self, icon_url: t.Optional[str], - colour: t.Union[disnake.Colour, int], + colour: t.Union[discord.Colour, int], title: t.Optional[str], text: str, - thumbnail: t.Optional[t.Union[str, disnake.Asset]] = None, + thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, channel_id: int = Channels.mod_log, ping_everyone: bool = False, - files: t.Optional[t.List[disnake.File]] = None, + files: t.Optional[t.List[discord.File]] = None, content: t.Optional[str] = None, - additional_embeds: t.Optional[t.List[disnake.Embed]] = None, + additional_embeds: t.Optional[t.List[discord.Embed]] = None, timestamp_override: t.Optional[datetime] = None, footer: t.Optional[str] = None, ) -> Context: """Generate log embed and send to logging channel.""" await self.bot.wait_until_guild_available() # Truncate string directly here to avoid removing newlines - embed = disnake.Embed( + embed = discord.Embed( description=text[:4093] + "..." if len(text) > 4096 else text ) @@ -143,10 +143,10 @@ class ModLog(Cog, name="ModLog"): if channel.guild.id != GuildConstant.id: return - if isinstance(channel, disnake.CategoryChannel): + if isinstance(channel, discord.CategoryChannel): title = "Category created" message = f"{channel.name} (`{channel.id}`)" - elif isinstance(channel, disnake.VoiceChannel): + elif isinstance(channel, discord.VoiceChannel): title = "Voice channel created" if channel.category: @@ -169,14 +169,14 @@ class ModLog(Cog, name="ModLog"): if channel.guild.id != GuildConstant.id: return - if isinstance(channel, disnake.CategoryChannel): + if isinstance(channel, discord.CategoryChannel): title = "Category deleted" - elif isinstance(channel, disnake.VoiceChannel): + elif isinstance(channel, discord.VoiceChannel): title = "Voice channel deleted" else: title = "Text channel deleted" - if channel.category and not isinstance(channel, disnake.CategoryChannel): + if channel.category and not isinstance(channel, discord.CategoryChannel): message = f"{channel.category}/{channel.name} (`{channel.id}`)" else: message = f"{channel.name} (`{channel.id}`)" @@ -256,7 +256,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_create(self, role: disnake.Role) -> None: + async def on_guild_role_create(self, role: discord.Role) -> None: """Log role create event to mod log.""" if role.guild.id != GuildConstant.id: return @@ -267,7 +267,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_delete(self, role: disnake.Role) -> None: + async def on_guild_role_delete(self, role: discord.Role) -> None: """Log role delete event to mod log.""" if role.guild.id != GuildConstant.id: return @@ -278,7 +278,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_update(self, before: disnake.Role, after: disnake.Role) -> None: + async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: """Log role update event to mod log.""" if before.guild.id != GuildConstant.id: return @@ -331,7 +331,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_update(self, before: disnake.Guild, after: disnake.Guild) -> None: + async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: """Log guild update event to mod log.""" if before.id != GuildConstant.id: return @@ -382,7 +382,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_ban(self, guild: disnake.Guild, member: disnake.Member) -> None: + async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: """Log ban event to user log.""" if guild.id != GuildConstant.id: return @@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_join(self, member: disnake.Member) -> None: + async def on_member_join(self, member: discord.Member) -> None: """Log member join event to user log.""" if member.guild.id != GuildConstant.id: return @@ -420,7 +420,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_remove(self, member: disnake.Member) -> None: + async def on_member_remove(self, member: discord.Member) -> None: """Log member leave event to user log.""" if member.guild.id != GuildConstant.id: return @@ -437,7 +437,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_unban(self, guild: disnake.Guild, member: disnake.User) -> None: + async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: """Log member unban event to mod log.""" if guild.id != GuildConstant.id: return @@ -454,7 +454,7 @@ class ModLog(Cog, name="ModLog"): ) @staticmethod - def get_role_diff(before: t.List[disnake.Role], after: t.List[disnake.Role]) -> t.List[str]: + def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: """Return a list of strings describing the roles added and removed.""" changes = [] before_roles = set(before) @@ -469,7 +469,7 @@ class ModLog(Cog, name="ModLog"): return changes @Cog.listener() - async def on_member_update(self, before: disnake.Member, after: disnake.Member) -> None: + async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: """Log member update event to user log.""" if before.guild.id != GuildConstant.id: return @@ -552,7 +552,7 @@ class ModLog(Cog, name="ModLog"): return channel.id in GuildConstant.modlog_blacklist - async def log_cached_deleted_message(self, message: disnake.Message) -> None: + async def log_cached_deleted_message(self, message: discord.Message) -> None: """ Log the message's details to message change log. @@ -608,7 +608,7 @@ class ModLog(Cog, name="ModLog"): channel_id=Channels.message_log ) - async def log_uncached_deleted_message(self, event: disnake.RawMessageDeleteEvent) -> None: + async def log_uncached_deleted_message(self, event: discord.RawMessageDeleteEvent) -> None: """ Log the message's details to message change log. @@ -648,7 +648,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_raw_message_delete(self, event: disnake.RawMessageDeleteEvent) -> None: + async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: """Log message deletions to message change log.""" if event.cached_message is not None: await self.log_cached_deleted_message(event.cached_message) @@ -656,7 +656,7 @@ class ModLog(Cog, name="ModLog"): await self.log_uncached_deleted_message(event) @Cog.listener() - async def on_message_edit(self, msg_before: disnake.Message, msg_after: disnake.Message) -> None: + async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: """Log message edit event to message change log.""" if self.is_message_blacklisted(msg_before): return @@ -727,7 +727,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_raw_message_edit(self, event: disnake.RawMessageUpdateEvent) -> None: + async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: """Log raw message edit event to message change log.""" if event.guild_id is None: return # ignore DM edits @@ -736,7 +736,7 @@ class ModLog(Cog, name="ModLog"): try: channel = self.bot.get_channel(int(event.data["channel_id"])) message = await channel.fetch_message(event.message_id) - except disnake.NotFound: # Was deleted before we got the event + except discord.NotFound: # Was deleted before we got the event return if self.is_message_blacklisted(message): @@ -860,9 +860,9 @@ class ModLog(Cog, name="ModLog"): @Cog.listener() async def on_voice_state_update( self, - member: disnake.Member, - before: disnake.VoiceState, - after: disnake.VoiceState + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState ) -> None: """Log member voice state changes to the voice log channel.""" if ( diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index 51d161d84..b5cd29b12 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -4,8 +4,8 @@ import datetime import arrow from async_rediscache import RedisCache from dateutil.parser import isoparse, parse as dateutil_parse -from disnake import Embed, Member -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord import Embed, Member +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles @@ -22,12 +22,12 @@ MAXIMUM_WORK_LIMIT = 16 class ModPings(Cog): """Commands for a moderator to turn moderator pings on and off.""" - # RedisCache[disnake.Member.id, 'Naïve ISO 8601 string'] + # RedisCache[discord.Member.id, 'Naïve ISO 8601 string'] # The cache's keys are mods who have pings off. # The cache's values are the times when the role should be re-applied to them, stored in ISO format. pings_off_mods = RedisCache() - # RedisCache[disnake.Member.id, 'start timestamp|total worktime in seconds'] + # RedisCache[discord.Member.id, 'start timestamp|total worktime in seconds'] # The cache's keys are mod's ID # The cache's values are their pings on schedule timestamp and the total seconds (work time) until pings off modpings_schedule = RedisCache() diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 0b677dddb..511520252 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -5,10 +5,10 @@ from datetime import datetime, timedelta, timezone from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache -from disnake import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel -from disnake.ext import commands, tasks -from disnake.ext.commands import Context -from disnake.utils import MISSING +from discord import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel +from discord.ext import commands, tasks +from discord.ext.commands import Context +from discord.utils import MISSING from bot import constants from bot.bot import Bot diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py index 7fcafc01c..b6a771441 100644 --- a/bot/exts/moderation/slowmode.py +++ b/bot/exts/moderation/slowmode.py @@ -1,8 +1,8 @@ from typing import Optional from dateutil.relativedelta import relativedelta -from disnake import TextChannel -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord import TextChannel +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, Emojis, MODERATION_ROLES diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 7afd9f71d..985cc6eb1 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -2,10 +2,10 @@ from datetime import timedelta, timezone from operator import itemgetter import arrow -import disnake +import discord from arrow import Arrow from async_rediscache import RedisCache -from disnake.ext import commands +from discord.ext import commands from bot.bot import Bot from bot.constants import ( @@ -24,7 +24,7 @@ class Stream(commands.Cog): """Grant and revoke streaming permissions from members.""" # Stores tasks to remove streaming permission - # RedisCache[disnake.Member.id, UtcPosixTimestamp] + # RedisCache[discord.Member.id, UtcPosixTimestamp] task_cache = RedisCache() def __init__(self, bot: Bot): @@ -37,10 +37,10 @@ class Stream(commands.Cog): self.reload_task.cancel() self.reload_task.add_done_callback(lambda _: self.scheduler.cancel_all()) - async def _revoke_streaming_permission(self, member: disnake.Member) -> None: + async def _revoke_streaming_permission(self, member: discord.Member) -> None: """Remove the streaming permission from the given Member.""" await self.task_cache.delete(member.id) - await member.remove_roles(disnake.Object(Roles.video), reason="Streaming access revoked") + await member.remove_roles(discord.Object(Roles.video), reason="Streaming access revoked") async def _reload_tasks_from_redis(self) -> None: """Reload outstanding tasks from redis on startup, delete the task if the member has since left the server.""" @@ -66,7 +66,7 @@ class Stream(commands.Cog): self._revoke_streaming_permission(member) ) - async def _suspend_stream(self, ctx: commands.Context, member: disnake.Member) -> None: + async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None: """Suspend a member's stream.""" await self.bot.wait_until_guild_available() voice_state = member.voice @@ -90,7 +90,7 @@ class Stream(commands.Cog): @commands.command(aliases=("streaming",)) @commands.has_any_role(*MODERATION_ROLES) - async def stream(self, ctx: commands.Context, member: disnake.Member, duration: Expiry = None) -> None: + async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None: """ Temporarily grant streaming permissions to a member for a given duration. @@ -128,7 +128,7 @@ class Stream(commands.Cog): self.scheduler.schedule_at(duration, member.id, self._revoke_streaming_permission(member)) await self.task_cache.set(member.id, duration.timestamp()) - await member.add_roles(disnake.Object(Roles.video), reason="Temporary streaming access granted") + await member.add_roles(discord.Object(Roles.video), reason="Temporary streaming access granted") await ctx.send(f"{Emojis.check_mark} {member.mention} can now stream until {time.discord_timestamp(duration)}.") @@ -142,7 +142,7 @@ class Stream(commands.Cog): @commands.command(aliases=("pstream",)) @commands.has_any_role(*MODERATION_ROLES) - async def permanentstream(self, ctx: commands.Context, member: disnake.Member) -> None: + async def permanentstream(self, ctx: commands.Context, member: discord.Member) -> None: """Permanently grants the given member the permission to stream.""" log.trace(f"Attempting to give permanent streaming permission to {member} ({member.id}).") @@ -163,13 +163,13 @@ class Stream(commands.Cog): log.debug(f"{member} ({member.id}) already had permanent streaming permission.") return - await member.add_roles(disnake.Object(Roles.video), reason="Permanent streaming access granted") + await member.add_roles(discord.Object(Roles.video), reason="Permanent streaming access granted") await ctx.send(f"{Emojis.check_mark} Permanently granted {member.mention} the permission to stream.") log.debug(f"Successfully gave {member} ({member.id}) permanent streaming permission.") @commands.command(aliases=("unstream", "rstream")) @commands.has_any_role(*MODERATION_ROLES) - async def revokestream(self, ctx: commands.Context, member: disnake.Member) -> None: + async def revokestream(self, ctx: commands.Context, member: discord.Member) -> None: """Revoke the permission to stream from the given member.""" log.trace(f"Attempting to remove streaming permission from {member} ({member.id}).") @@ -222,7 +222,7 @@ class Stream(commands.Cog): # Only output the message in the pagination lines = [line[1] for line in streamer_info] - embed = disnake.Embed( + embed = discord.Embed( title=f"Members with streaming permission (`{len(lines)}` total)", colour=Colours.soft_green ) diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index c958aa160..37338d19c 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -1,7 +1,7 @@ import typing as t -import disnake -from disnake.ext.commands import Cog, Context, command, has_any_role +import discord +from discord.ext.commands import Cog, Context, command, has_any_role from bot import constants from bot.bot import Bot @@ -51,7 +51,7 @@ async def safe_dm(coro: t.Coroutine) -> None: """ try: await coro - except disnake.HTTPException as discord_exc: + except discord.HTTPException as discord_exc: log.trace(f"DM dispatch failed on status {discord_exc.status} with code: {discord_exc.code}") if discord_exc.code != 50_007: # If any reason other than disabled DMs raise @@ -72,7 +72,7 @@ class Verification(Cog): # region: listeners @Cog.listener() - async def on_member_join(self, member: disnake.Member) -> None: + async def on_member_join(self, member: discord.Member) -> None: """Attempt to send initial direct message to each new member.""" if member.guild.id != constants.Guild.id: return # Only listen for PyDis events @@ -87,11 +87,11 @@ class Verification(Cog): log.trace(f"Sending on join message to new member: {member.id}") try: await safe_dm(member.send(ON_JOIN_MESSAGE)) - except disnake.HTTPException: + except discord.HTTPException: log.exception("DM dispatch failed on unexpected error code") @Cog.listener() - async def on_member_update(self, before: disnake.Member, after: disnake.Member) -> None: + async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: """Check if we need to send a verification DM to a gated user.""" if before.pending is True and after.pending is False: try: @@ -100,7 +100,7 @@ class Verification(Cog): # our alternate welcome DM which includes info such as our welcome # video. await safe_dm(after.send(VERIFIED_MESSAGE)) - except disnake.HTTPException: + except discord.HTTPException: log.exception("DM dispatch failed on unexpected error code") # endregion @@ -108,7 +108,7 @@ class Verification(Cog): @command(name='verify') @has_any_role(*constants.MODERATION_ROLES) - async def perform_manual_verification(self, ctx: Context, user: disnake.Member) -> None: + async def perform_manual_verification(self, ctx: Context, user: discord.Member) -> None: """Command for moderators to verify any user.""" log.trace(f'verify command called by {ctx.author} for {user.id}.') diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index 24ae86bdd..fa66b00dd 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -3,10 +3,10 @@ from contextlib import suppress from datetime import timedelta import arrow -import disnake +import discord from async_rediscache import RedisCache -from disnake import Colour, Member, VoiceState -from disnake.ext.commands import Cog, Context, command +from discord import Colour, Member, VoiceState +from discord.ext.commands import Cog, Context, command from bot.api import ResponseCodeError from bot.bot import Bot @@ -51,7 +51,7 @@ VOICE_PING_DM = ( class VoiceGate(Cog): """Voice channels verification management.""" - # RedisCache[t.Union[disnake.User.id, disnake.Member.id], t.Union[disnake.Message.id, int]] + # RedisCache[t.Union[discord.User.id, discord.Member.id], t.Union[discord.Message.id, int]] # The cache's keys are the IDs of members who are verified or have joined a voice channel # The cache's values are either the message ID of the ping message or 0 (NO_MSG) if no message is present redis_cache = RedisCache() @@ -75,14 +75,14 @@ class VoiceGate(Cog): """ if message_id := await self.redis_cache.get(member_id): log.trace(f"Removing voice gate reminder message for user: {member_id}") - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await self.bot.http.delete_message(Channels.voice_gate, message_id) await self.redis_cache.set(member_id, NO_MSG) else: log.trace(f"Voice gate reminder message for user {member_id} was already removed") @redis_cache.atomic_transaction - async def _ping_newcomer(self, member: disnake.Member) -> tuple: + async def _ping_newcomer(self, member: discord.Member) -> tuple: """ See if `member` should be sent a voice verification notification, and send it if so. @@ -91,7 +91,7 @@ class VoiceGate(Cog): * The `member` is already voice-verified Otherwise, the notification message ID is stored in `redis_cache` and return (True, channel). - channel is either [disnake.TextChannel, disnake.DMChannel]. + channel is either [discord.TextChannel, discord.DMChannel]. """ if await self.redis_cache.contains(member.id): log.trace("User already in cache. Ignore.") @@ -111,7 +111,7 @@ class VoiceGate(Cog): try: message = await member.send(VOICE_PING_DM.format(channel_mention=voice_verification_channel.mention)) - except disnake.Forbidden: + except discord.Forbidden: log.trace("DM failed for Voice ping message. Sending in channel.") message = await voice_verification_channel.send(f"Hello, {member.mention}! {VOICE_PING}") @@ -137,7 +137,7 @@ class VoiceGate(Cog): data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data") except ResponseCodeError as e: if e.status == 404: - embed = disnake.Embed( + embed = discord.Embed( title="Not found", description=( "We were unable to find user data for you. " @@ -148,7 +148,7 @@ class VoiceGate(Cog): ) log.info(f"Unable to find Metricity data about {ctx.author} ({ctx.author.id})") else: - embed = disnake.Embed( + embed = discord.Embed( title="Unexpected response", description=( "We encountered an error while attempting to find data for your user. " @@ -159,7 +159,7 @@ class VoiceGate(Cog): log.warning(f"Got response code {e.status} while trying to get {ctx.author.id} Metricity data.") try: await ctx.author.send(embed=embed) - except disnake.Forbidden: + except discord.Forbidden: log.info("Could not send user DM. Sending in voice-verify channel and scheduling delete.") await ctx.send(embed=embed) @@ -179,7 +179,7 @@ class VoiceGate(Cog): [self.bot.stats.incr(f"voice_gate.failed.{key}") for key, value in checks.items() if value is True] if failed: - embed = disnake.Embed( + embed = discord.Embed( title="Voice Gate failed", description=FAILED_MESSAGE.format(reasons="\n".join(f'• You {reason}.' for reason in failed_reasons)), color=Colour.red() @@ -187,12 +187,12 @@ class VoiceGate(Cog): try: await ctx.author.send(embed=embed) await ctx.send(f"{ctx.author}, please check your DMs.") - except disnake.Forbidden: + except discord.Forbidden: await ctx.channel.send(ctx.author.mention, embed=embed) return self.mod_log.ignore(Event.member_update, ctx.author.id) - embed = disnake.Embed( + embed = discord.Embed( title="Voice gate passed", description="You have been granted permission to use voice channels in Python Discord.", color=Colour.green() @@ -204,17 +204,17 @@ class VoiceGate(Cog): try: await ctx.author.send(embed=embed) await ctx.send(f"{ctx.author}, please check your DMs.") - except disnake.Forbidden: + except discord.Forbidden: await ctx.channel.send(ctx.author.mention, embed=embed) # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it. await asyncio.sleep(3) - await ctx.author.add_roles(disnake.Object(Roles.voice_verified), reason="Voice Gate passed") + await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed") self.bot.stats.incr("voice_gate.passed") @Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """Delete all non-staff messages from voice gate channel that don't invoke voice verify command.""" # Check is channel voice gate if message.channel.id != Channels.voice_gate: @@ -229,7 +229,7 @@ class VoiceGate(Cog): if message.content.endswith(VOICE_PING): log.trace("Message is the voice verification ping. Ignore.") return - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await message.delete(delay=GateConf.bot_message_delete_delay) return @@ -242,7 +242,7 @@ class VoiceGate(Cog): if ctx.command is not None and ctx.command.name == "voice_verify": self.mod_log.ignore(Event.message_delete, message.id) - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await message.delete() @Cog.listener() @@ -257,7 +257,7 @@ class VoiceGate(Cog): log.trace("User not in a voice channel. Ignore.") return - if isinstance(after.channel, disnake.StageChannel): + if isinstance(after.channel, discord.StageChannel): log.trace("User joined a stage channel. Ignore.") return @@ -267,7 +267,7 @@ class VoiceGate(Cog): # Schedule the channel ping notification to be deleted after the configured delay, which is # again delegated to an atomic helper - if notification_sent and isinstance(message_channel, disnake.TextChannel): + if notification_sent and isinstance(message_channel, discord.TextChannel): await asyncio.sleep(GateConf.voice_ping_delete_delay) await self._delete_ping(member.id) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index 88669ccaa..ee9b6ba45 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -6,9 +6,9 @@ from collections import defaultdict, deque from dataclasses import dataclass from typing import Any, Dict, Optional -import disnake -from disnake import Color, DMChannel, Embed, HTTPException, Message, errors -from disnake.ext.commands import Cog, Context +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError from bot.bot import Bot @@ -104,7 +104,7 @@ class WatchChannel(metaclass=CogABCMeta): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except disnake.HTTPException: + except discord.HTTPException: self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") if self.channel is None or self.webhook is None: @@ -217,7 +217,7 @@ class WatchChannel(metaclass=CogABCMeta): username = messages.sub_clyde(username) try: await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except disnake.HTTPException as exc: + except discord.HTTPException as exc: self.log.exception( "Failed to send a message to the webhook", exc_info=exc @@ -265,7 +265,7 @@ class WatchChannel(metaclass=CogABCMeta): username=msg.author.display_name, avatar_url=msg.author.display_avatar.url ) - except disnake.HTTPException as exc: + except discord.HTTPException as exc: self.log.exception( "Failed to send an attachment to the webhook", exc_info=exc diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py index f19e3d103..31b106a20 100644 --- a/bot/exts/moderation/watchchannels/bigbrother.py +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -1,7 +1,7 @@ import textwrap from collections import ChainMap -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Webhooks @@ -94,7 +94,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): await ctx.send(f":x: {user.mention} is already being watched.") return - # disnake.User instances don't have a roles attribute + # discord.User instances don't have a roles attribute if hasattr(user, "roles") and any(role.id in MODERATION_ROLES for role in user.roles): await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I must be kind to my masters.") return diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 3d784ef77..0554bf37a 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -3,10 +3,10 @@ from collections import ChainMap, defaultdict from io import StringIO from typing import Optional, Union -import disnake +import discord from async_rediscache import RedisCache -from disnake import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User -from disnake.ext.commands import BadArgument, Cog, Context, group, has_any_role +from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User +from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role from bot.api import ResponseCodeError from bot.bot import Bot @@ -483,7 +483,7 @@ class TalentPool(Cog, name="Talentpool"): async def get_review(self, ctx: Context, user_id: int) -> None: """Get the user's review as a markdown file.""" review, _, _ = await self.reviewer.make_review(user_id) - file = disnake.File(StringIO(review), f"{user_id}_review.md") + file = discord.File(StringIO(review), f"{user_id}_review.md") await ctx.send(file=file) @nomination_group.command(aliases=('review',)) diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index d496d0eb2..b4d177622 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -10,8 +10,8 @@ from typing import List, Optional, Union import arrow from dateutil.parser import isoparse -from disnake import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel -from disnake.ext.commands import Context +from discord import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel +from discord.ext.commands import Context from bot.api import ResponseCodeError from bot.bot import Bot diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py index 7d18c0ed3..8f0094bc9 100644 --- a/bot/exts/utils/bot.py +++ b/bot/exts/utils/bot.py @@ -1,7 +1,7 @@ from typing import Optional -from disnake import Embed, TextChannel -from disnake.ext.commands import Cog, Context, command, group, has_any_role +from discord import Embed, TextChannel +from discord.ext.commands import Cog, Context, command, group, has_any_role from bot.bot import Bot from bot.constants import Guild, MODERATION_ROLES, URLs diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 3d12ae848..fda1e49e2 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -2,9 +2,9 @@ import functools import typing as t from enum import Enum -from disnake import Colour, Embed -from disnake.ext import commands -from disnake.ext.commands import Context, group +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group from bot import exts from bot.bot import Bot diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index 28c1867ad..e7113c09c 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -9,8 +9,8 @@ from io import StringIO from typing import Any, Optional, Tuple import arrow -import disnake -from disnake.ext.commands import Cog, Context, group, has_any_role, is_owner +import discord +from discord.ext.commands import Cog, Context, group, has_any_role, is_owner from bot.bot import Bot from bot.constants import DEBUG_MODE, Roles @@ -42,7 +42,7 @@ class Internal(Cog): self.socket_event_total += 1 self.socket_events[event_type] += 1 - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[disnake.Embed]]: + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: """Format the eval output into a string & attempt to format it into an Embed.""" self._ = out @@ -103,7 +103,7 @@ class Internal(Cog): res += f"Out[{self.ln}]: " - if isinstance(out, disnake.Embed): + if isinstance(out, discord.Embed): # We made an embed? Send that as embed res += "" res = (res, out) @@ -136,7 +136,7 @@ class Internal(Cog): return res # Return (text, embed) - async def _eval(self, ctx: Context, code: str) -> Optional[disnake.Message]: + async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: """Eval the input code string & send an embed to the invoking context.""" self.ln += 1 @@ -154,8 +154,7 @@ class Internal(Cog): "self": self, "bot": self.bot, "inspect": inspect, - "discord": disnake, - "disnake": disnake, + "discord": discord, "contextlib": contextlib } @@ -241,10 +240,10 @@ async def func(): # (None,) -> Any per_s = self.socket_event_total / running_s - stats_embed = disnake.Embed( + stats_embed = discord.Embed( title="WebSocket statistics", description=f"Receiving {per_s:0.2f} events per second.", - color=disnake.Color.og_blurple() + color=discord.Color.og_blurple() ) for event_type, count in self.socket_events.most_common(25): diff --git a/bot/exts/utils/ping.py b/bot/exts/utils/ping.py index eeb1d5ff5..9fb5b7b8f 100644 --- a/bot/exts/utils/ping.py +++ b/bot/exts/utils/ping.py @@ -1,7 +1,7 @@ import arrow from aiohttp import client_exceptions -from disnake import Embed -from disnake.ext import commands +from discord import Embed +from discord.ext import commands from bot.bot import Bot from bot.constants import Channels, STAFF_PARTNERS_COMMUNITY_ROLES, URLs diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index bf0e9d2ac..ad82d49c9 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -4,9 +4,9 @@ import typing as t from datetime import datetime, timezone from operator import itemgetter -import disnake +import discord from dateutil.parser import isoparse -from disnake.ext.commands import Cog, Context, Greedy, group +from discord.ext.commands import Cog, Context, Greedy, group from bot.bot import Bot from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES @@ -26,8 +26,8 @@ LOCK_NAMESPACE = "reminder" WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 -Mentionable = t.Union[disnake.Member, disnake.Role] -ReminderMention = t.Union[UnambiguousUser, disnake.Role] +Mentionable = t.Union[discord.Member, discord.Role] +ReminderMention = t.Union[UnambiguousUser, discord.Role] class Reminders(Cog): @@ -66,7 +66,7 @@ class Reminders(Cog): else: self.schedule_reminder(reminder) - def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, disnake.TextChannel]: + def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.TextChannel]: """Ensure reminder channel can be fetched otherwise delete the reminder.""" channel = self.bot.get_channel(reminder['channel_id']) is_valid = True @@ -87,9 +87,9 @@ class Reminders(Cog): reminder_id: t.Union[str, int] ) -> None: """Send an embed confirming the reminder change was made successfully.""" - embed = disnake.Embed( + embed = discord.Embed( description=on_success, - colour=disnake.Colour.green(), + colour=discord.Colour.green(), title=random.choice(POSITIVE_REPLIES) ) @@ -113,7 +113,7 @@ class Reminders(Cog): if await has_no_roles_check(ctx, *STAFF_PARTNERS_COMMUNITY_ROLES): return False, "members/roles" elif await has_no_roles_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, (disnake.User, disnake.Member)) for mention in mentions), "roles" + return all(isinstance(mention, (discord.User, discord.Member)) for mention in mentions), "roles" else: return True, "" @@ -173,15 +173,15 @@ class Reminders(Cog): if not is_valid: # No need to cancel the task too; it'll simply be done once this coroutine returns. return - embed = disnake.Embed() + embed = discord.Embed() if expected_time: - embed.colour = disnake.Colour.red() + embed.colour = discord.Colour.red() embed.set_author( icon_url=Icons.remind_red, name="Sorry, your reminder should have arrived earlier!" ) else: - embed.colour = disnake.Colour.og_blurple() + embed.colour = discord.Colour.og_blurple() embed.set_author( icon_url=Icons.remind_blurple, name="It has arrived!" @@ -200,7 +200,7 @@ class Reminders(Cog): partial_message = channel.get_partial_message(int(jump_url.split("/")[-1])) try: await partial_message.reply(content=f"{additional_mentions}", embed=embed) - except disnake.HTTPException as e: + except discord.HTTPException as e: log.info( f"There was an error when trying to reply to a reminder invocation message, {e}, " "fall back to using jump_url" @@ -284,7 +284,7 @@ class Reminders(Cog): # If `content` isn't provided then we try to get message content of a replied message if not content: if reference := ctx.message.reference: - if isinstance((resolved_message := reference.resolved), disnake.Message): + if isinstance((resolved_message := reference.resolved), discord.Message): content = resolved_message.content # If we weren't able to get the content of a replied message if content is None: @@ -361,8 +361,8 @@ class Reminders(Cog): lines.append(text) - embed = disnake.Embed() - embed.colour = disnake.Colour.og_blurple() + embed = discord.Embed() + embed.colour = discord.Colour.og_blurple() embed.title = f"Reminders for {ctx.author}" # Remind the user that they have no reminders :^) @@ -372,7 +372,7 @@ class Reminders(Cog): return # Construct the embed and paginate it. - embed.colour = disnake.Colour.og_blurple() + embed.colour = discord.Colour.og_blurple() await LinePaginator.paginate( lines, diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 07d824f87..cc3a2e1d7 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -8,8 +8,8 @@ from signal import Signals from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX -from disnake import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from disnake.ext.commands import Cog, Context, command, guild_only +from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User +from discord.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs diff --git a/bot/exts/utils/thread_bumper.py b/bot/exts/utils/thread_bumper.py index d37b3b51c..35057f1fe 100644 --- a/bot/exts/utils/thread_bumper.py +++ b/bot/exts/utils/thread_bumper.py @@ -1,8 +1,8 @@ import typing as t -import disnake +import discord from async_rediscache import RedisCache -from disnake.ext import commands +from discord.ext import commands from bot import constants from bot.bot import Bot @@ -16,14 +16,14 @@ log = get_logger(__name__) class ThreadBumper(commands.Cog): """Cog that allow users to add the current thread to a list that get reopened on archive.""" - # RedisCache[disnake.Thread.id, "sentinel"] + # RedisCache[discord.Thread.id, "sentinel"] threads_to_bump = RedisCache() def __init__(self, bot: Bot): self.bot = bot self.init_task = scheduling.create_task(self.ensure_bumped_threads_are_active(), event_loop=self.bot.loop) - async def unarchive_threads_not_manually_archived(self, threads: list[disnake.Thread]) -> None: + async def unarchive_threads_not_manually_archived(self, threads: list[discord.Thread]) -> None: """ Iterate through and unarchive any threads that weren't manually archived recently. @@ -35,7 +35,7 @@ class ThreadBumper(commands.Cog): guild = self.bot.get_guild(constants.Guild.id) recent_manually_archived_thread_ids = [] - async for thread_update in guild.audit_logs(limit=200, action=disnake.AuditLogAction.thread_update): + async for thread_update in guild.audit_logs(limit=200, action=discord.AuditLogAction.thread_update): if getattr(thread_update.after, "archived", False): recent_manually_archived_thread_ids.append(thread_update.target.id) @@ -58,7 +58,7 @@ class ThreadBumper(commands.Cog): for thread_id, _ in await self.threads_to_bump.items(): try: thread = await channel.get_or_fetch_channel(thread_id) - except disnake.NotFound: + except discord.NotFound: log.info("Thread %d has been deleted, removing from bumped threads.", thread_id) await self.threads_to_bump.delete(thread_id) continue @@ -75,12 +75,12 @@ class ThreadBumper(commands.Cog): await ctx.send_help(ctx.command) @thread_bump_group.command(name="add", aliases=("a",)) - async def add_thread_to_bump_list(self, ctx: commands.Context, thread: t.Optional[disnake.Thread]) -> None: + async def add_thread_to_bump_list(self, ctx: commands.Context, thread: t.Optional[discord.Thread]) -> None: """Add a thread to the bump list.""" await self.init_task if not thread: - if isinstance(ctx.channel, disnake.Thread): + if isinstance(ctx.channel, discord.Thread): thread = ctx.channel else: raise commands.BadArgument("You must provide a thread, or run this command within a thread.") @@ -92,12 +92,12 @@ class ThreadBumper(commands.Cog): await ctx.send(f":ok_hand:{thread.mention} has been added to the bump list.") @thread_bump_group.command(name="remove", aliases=("r", "rem", "d", "del", "delete")) - async def remove_thread_from_bump_list(self, ctx: commands.Context, thread: t.Optional[disnake.Thread]) -> None: + async def remove_thread_from_bump_list(self, ctx: commands.Context, thread: t.Optional[discord.Thread]) -> None: """Remove a thread from the bump list.""" await self.init_task if not thread: - if isinstance(ctx.channel, disnake.Thread): + if isinstance(ctx.channel, discord.Thread): thread = ctx.channel else: raise commands.BadArgument("You must provide a thread, or run this command within a thread.") @@ -114,14 +114,14 @@ class ThreadBumper(commands.Cog): await self.init_task lines = [f"<#{k}>" for k, _ in await self.threads_to_bump.items()] - embed = disnake.Embed( + embed = discord.Embed( title="Threads in the bump list", colour=constants.Colours.blue ) await LinePaginator.paginate(lines, ctx, embed) @commands.Cog.listener() - async def on_thread_update(self, _: disnake.Thread, after: disnake.Thread) -> None: + async def on_thread_update(self, _: discord.Thread, after: discord.Thread) -> None: """ Listen for thread updates and check if the thread has been archived. diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index 77be3315c..2a074788e 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -3,9 +3,9 @@ import re import unicodedata from typing import Tuple, Union -from disnake import Colour, Embed, utils -from disnake.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role -from disnake.utils import snowflake_time +from discord import Colour, Embed, utils +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role +from discord.utils import snowflake_time from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES diff --git a/bot/log.py b/bot/log.py index 0b1d1aca6..100cd06f6 100644 --- a/bot/log.py +++ b/bot/log.py @@ -74,7 +74,7 @@ def setup() -> None: coloredlogs.install(level=TRACE_LEVEL, logger=root_log, stream=sys.stdout) root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO) - get_logger("disnake").setLevel(logging.WARNING) + get_logger("discord").setLevel(logging.WARNING) get_logger("websockets").setLevel(logging.WARNING) get_logger("chardet").setLevel(logging.WARNING) get_logger("async_rediscache").setLevel(logging.WARNING) diff --git a/bot/monkey_patches.py b/bot/monkey_patches.py index 590be22a2..4840fa454 100644 --- a/bot/monkey_patches.py +++ b/bot/monkey_patches.py @@ -2,8 +2,8 @@ import re from datetime import timedelta import arrow -from disnake import Forbidden, http -from disnake.ext import commands +from discord import Forbidden, http +from discord.ext import commands from bot.log import get_logger @@ -13,7 +13,7 @@ MESSAGE_ID_RE = re.compile(r'(?P[0-9]{15,20})$') class Command(commands.Command): """ - A `disnake.ext.commands.Command` subclass which supports root aliases. + A `discord.ext.commands.Command` subclass which supports root aliases. A `root_aliases` keyword argument is added, which is a sequence of alias names that will act as top-level commands rather than being aliases of the command's group. It's stored as an attribute diff --git a/bot/pagination.py b/bot/pagination.py index 1a014daa1..8f4353eb1 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -3,9 +3,9 @@ import typing as t from contextlib import suppress from functools import partial -import disnake -from disnake.abc import User -from disnake.ext.commands import Context, Paginator +import discord +from discord.abc import User +from discord.ext.commands import Context, Paginator from bot import constants from bot.log import get_logger @@ -55,7 +55,7 @@ class LinePaginator(Paginator): linesep: str = "\n" ) -> None: """ - This function overrides the Paginator.__init__ from inside disnake.ext.commands. + This function overrides the Paginator.__init__ from inside discord.ext.commands. It overrides in order to allow us to configure the maximum number of lines per page. """ @@ -99,7 +99,7 @@ class LinePaginator(Paginator): effort to avoid breaking up single lines across pages, while keeping the total length of the page at a reasonable size. - This function overrides the `Paginator.add_line` from inside `disnake.ext.commands`. + This function overrides the `Paginator.add_line` from inside `discord.ext.commands`. It overrides in order to allow us to configure the maximum number of lines per page. """ @@ -192,7 +192,7 @@ class LinePaginator(Paginator): cls, lines: t.List[str], ctx: Context, - embed: disnake.Embed, + embed: discord.Embed, prefix: str = "", suffix: str = "", max_lines: t.Optional[int] = None, @@ -204,7 +204,7 @@ class LinePaginator(Paginator): footer_text: str = None, url: str = None, exception_on_empty_embed: bool = False, - ) -> t.Optional[disnake.Message]: + ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -219,7 +219,7 @@ class LinePaginator(Paginator): to any user with a moderation role. Example: - >>> embed = disnake.Embed() + >>> embed = discord.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) >>> await LinePaginator.paginate([line for line in lines], ctx, embed) """ @@ -367,5 +367,5 @@ class LinePaginator(Paginator): await message.edit(embed=embed) log.debug("Ending pagination and clearing reactions.") - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await message.clear_reactions() diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 9c890e569..8903c385c 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/burst.py b/bot/rules/burst.py index a943cfdeb..25c5a2f33 100644 --- a/bot/rules/burst.py +++ b/bot/rules/burst.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/burst_shared.py b/bot/rules/burst_shared.py index dee857e18..bbe9271b3 100644 --- a/bot/rules/burst_shared.py +++ b/bot/rules/burst_shared.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/chars.py b/bot/rules/chars.py index 6d2f6eb83..1f587422c 100644 --- a/bot/rules/chars.py +++ b/bot/rules/chars.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/discord_emojis.py b/bot/rules/discord_emojis.py index 4fe4e88f9..d979ac5e7 100644 --- a/bot/rules/discord_emojis.py +++ b/bot/rules/discord_emojis.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message from emoji import demojize DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:") diff --git a/bot/rules/duplicates.py b/bot/rules/duplicates.py index 77e393db0..8e4fbc12d 100644 --- a/bot/rules/duplicates.py +++ b/bot/rules/duplicates.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/links.py b/bot/rules/links.py index 92c13b3f4..c46b783c5 100644 --- a/bot/rules/links.py +++ b/bot/rules/links.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message LINK_RE = re.compile(r"(https?://[^\s]+)") diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 7ee66be31..6f5addad1 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/newlines.py b/bot/rules/newlines.py index 45266648e..4e66e1359 100644 --- a/bot/rules/newlines.py +++ b/bot/rules/newlines.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/role_mentions.py b/bot/rules/role_mentions.py index 1f7a6a74d..0649540b6 100644 --- a/bot/rules/role_mentions.py +++ b/bot/rules/role_mentions.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/utils/channel.py b/bot/utils/channel.py index ee0c87311..954a10e56 100644 --- a/bot/utils/channel.py +++ b/bot/utils/channel.py @@ -1,6 +1,6 @@ from typing import Union -import disnake +import discord import bot from bot import constants @@ -10,7 +10,7 @@ from bot.log import get_logger log = get_logger(__name__) -def is_help_channel(channel: disnake.TextChannel) -> bool: +def is_help_channel(channel: discord.TextChannel) -> bool: """Return True if `channel` is in one of the help categories (excluding dormant).""" log.trace(f"Checking if #{channel} is a help channel.") categories = (Categories.help_available, Categories.help_in_use) @@ -18,9 +18,9 @@ def is_help_channel(channel: disnake.TextChannel) -> bool: return any(is_in_category(channel, category) for category in categories) -def is_mod_channel(channel: Union[disnake.TextChannel, disnake.Thread]) -> bool: +def is_mod_channel(channel: Union[discord.TextChannel, discord.Thread]) -> bool: """True if channel, or channel.parent for threads, is considered a mod channel.""" - if isinstance(channel, disnake.Thread): + if isinstance(channel, discord.Thread): channel = channel.parent if channel.id in constants.MODERATION_CHANNELS: @@ -36,11 +36,11 @@ def is_mod_channel(channel: Union[disnake.TextChannel, disnake.Thread]) -> bool: return False -def is_staff_channel(channel: disnake.TextChannel) -> bool: +def is_staff_channel(channel: discord.TextChannel) -> bool: """True if `channel` is considered a staff channel.""" guild = bot.instance.get_guild(constants.Guild.id) - if channel.type is disnake.ChannelType.category: + if channel.type is discord.ChannelType.category: return False # Channel is staff-only if staff have explicit read allow perms @@ -52,12 +52,12 @@ def is_staff_channel(channel: disnake.TextChannel) -> bool: ) -def is_in_category(channel: disnake.TextChannel, category_id: int) -> bool: +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: """Return True if `channel` is within a category with `category_id`.""" return getattr(channel, "category_id", None) == category_id -async def get_or_fetch_channel(channel_id: int) -> disnake.abc.GuildChannel: +async def get_or_fetch_channel(channel_id: int) -> discord.abc.GuildChannel: """Attempt to get or fetch a channel and return it.""" log.trace(f"Getting the channel {channel_id}.") diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 9aa9bdc14..188285684 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,6 +1,6 @@ from typing import Callable, Container, Iterable, Optional, Union -from disnake.ext.commands import ( +from discord.ext.commands import ( BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping, NoPrivateMessage, has_any_role ) @@ -135,7 +135,7 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy if any(role.id in bypass for role in ctx.author.roles): return - # cooldown logic, taken from disnake's internals + # cooldown logic, taken from discord.py internals current = ctx.message.created_at.timestamp() bucket = buckets.get_bucket(ctx.message) retry_after = bucket.update_rate_limit(current) diff --git a/bot/utils/function.py b/bot/utils/function.py index bb6d8afe3..55115d7d3 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -94,7 +94,7 @@ def update_wrapper_globals( """ Update globals of `wrapper` with the globals from `wrapped`. - For forwardrefs in command annotations disnake uses the __global__ attribute of the function + For forwardrefs in command annotations discordpy uses the __global__ attribute of the function to resolve their values, with decorators that replace the function this breaks because they have their own globals. @@ -103,7 +103,7 @@ def update_wrapper_globals( An exception will be raised in case `wrapper` and `wrapped` share a global name that is used by `wrapped`'s typehints and is not in `ignored_conflict_names`, - as this can cause incorrect objects being used by disnake's converters. + as this can cause incorrect objects being used by discordpy's converters. """ annotation_global_names = ( ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) @@ -136,7 +136,7 @@ def command_wraps( *, ignored_conflict_names: t.Set[str] = frozenset(), ) -> t.Callable[[types.FunctionType], types.FunctionType]: - """Update the decorated function to look like `wrapped` and update globals for disnake forwardref evaluation.""" + """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation.""" def decorator(wrapper: types.FunctionType) -> types.FunctionType: return functools.update_wrapper( update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py index 859f53fdb..3501a3933 100644 --- a/bot/utils/helpers.py +++ b/bot/utils/helpers.py @@ -1,7 +1,7 @@ from abc import ABCMeta from typing import Optional -from disnake.ext.commands import CogMeta +from discord.ext.commands import CogMeta class CogABCMeta(CogMeta, ABCMeta): diff --git a/bot/utils/members.py b/bot/utils/members.py index d46baae5b..693286045 100644 --- a/bot/utils/members.py +++ b/bot/utils/members.py @@ -1,13 +1,13 @@ import typing as t -import disnake +import discord from bot.log import get_logger log = get_logger(__name__) -async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optional[disnake.Member]: +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optional[discord.Member]: """ Attempt to get a member from cache; on failure fetch from the API. @@ -18,7 +18,7 @@ async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optiona else: try: member = await guild.fetch_member(member_id) - except disnake.errors.NotFound: + except discord.errors.NotFound: log.trace("Failed to fetch %d from API.", member_id) return None log.trace("%s fetched from API.", member) @@ -26,23 +26,23 @@ async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optiona async def handle_role_change( - member: disnake.Member, + member: discord.Member, coro: t.Callable[..., t.Coroutine], - role: disnake.Role + role: discord.Role ) -> None: """ Change `member`'s cooldown role via awaiting `coro` and handle errors. - `coro` is intended to be `disnake.Member.add_roles` or `disnake.Member.remove_roles`. + `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. """ try: await coro(role) - except disnake.NotFound: + except discord.NotFound: log.debug(f"Failed to change role for {member} ({member.id}): member not found") - except disnake.Forbidden: + except discord.Forbidden: log.debug( f"Forbidden to change role for {member} ({member.id}); " f"possibly due to role hierarchy" ) - except disnake.HTTPException as e: + except discord.HTTPException as e: log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index edf2111e9..f68d280c9 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -1,7 +1,7 @@ import typing as t from math import ceil -from disnake import Message +from discord import Message class MessageCache: diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 0bdb00a29..e55c07062 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -5,8 +5,8 @@ from functools import partial from io import BytesIO from typing import Callable, List, Optional, Sequence, Union -import disnake -from disnake.ext.commands import Context +import discord +from discord.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES @@ -17,8 +17,8 @@ log = get_logger(__name__) def reaction_check( - reaction: disnake.Reaction, - user: disnake.abc.User, + reaction: discord.Reaction, + user: discord.abc.User, *, message_id: int, allowed_emoji: Sequence[str], @@ -51,14 +51,14 @@ def reaction_check( log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.") scheduling.create_task( reaction.message.remove_reaction(reaction.emoji, user), - suppressed_exceptions=(disnake.HTTPException,), + suppressed_exceptions=(discord.HTTPException,), name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}" ) return False async def wait_for_deletion( - message: disnake.Message, + message: discord.Message, user_ids: Sequence[int], deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, @@ -82,7 +82,7 @@ async def wait_for_deletion( for emoji in deletion_emojis: try: await message.add_reaction(emoji) - except disnake.NotFound: + except discord.NotFound: log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.") return @@ -101,13 +101,13 @@ async def wait_for_deletion( await message.clear_reactions() else: await message.delete() - except disnake.NotFound: + except discord.NotFound: log.trace(f"wait_for_deletion: message {message.id} deleted prematurely.") async def send_attachments( - message: disnake.Message, - destination: Union[disnake.TextChannel, disnake.Webhook], + message: discord.Message, + destination: Union[discord.TextChannel, discord.Webhook], link_large: bool = True, use_cached: bool = False, **kwargs @@ -140,9 +140,9 @@ async def send_attachments( if attachment.size <= destination.guild.filesize_limit - 512: with BytesIO() as file: await attachment.save(file, use_cached=use_cached) - attachment_file = disnake.File(file, filename=attachment.filename) + attachment_file = discord.File(file, filename=attachment.filename) - if isinstance(destination, disnake.TextChannel): + if isinstance(destination, discord.TextChannel): msg = await destination.send(file=attachment_file, **kwargs) urls.append(msg.attachments[0].url) else: @@ -151,7 +151,7 @@ async def send_attachments( large.append(attachment) else: log.info(f"{failure_msg} because it's too large.") - except disnake.HTTPException as e: + except discord.HTTPException as e: if link_large and e.status == 413: large.append(attachment) else: @@ -159,10 +159,10 @@ async def send_attachments( if link_large and large: desc = "\n".join(f"[{attachment.filename}]({attachment.url})" for attachment in large) - embed = disnake.Embed(description=desc) + embed = discord.Embed(description=desc) embed.set_footer(text="Attachments exceed upload size limit.") - if isinstance(destination, disnake.TextChannel): + if isinstance(destination, discord.TextChannel): await destination.send(embed=embed, **kwargs) else: await destination.send(embed=embed, **webhook_send_kwargs) @@ -171,9 +171,9 @@ async def send_attachments( async def count_unique_users_reaction( - message: disnake.Message, - reaction_predicate: Callable[[disnake.Reaction], bool] = lambda _: True, - user_predicate: Callable[[disnake.User], bool] = lambda _: True, + message: discord.Message, + reaction_predicate: Callable[[discord.Reaction], bool] = lambda _: True, + user_predicate: Callable[[discord.User], bool] = lambda _: True, count_bots: bool = True ) -> int: """ @@ -193,7 +193,7 @@ async def count_unique_users_reaction( return len(unique_users) -async def pin_no_system_message(message: disnake.Message) -> bool: +async def pin_no_system_message(message: discord.Message) -> bool: """Pin the given message, wait a couple of seconds and try to delete the system message.""" await message.pin() @@ -201,7 +201,7 @@ async def pin_no_system_message(message: disnake.Message) -> bool: await asyncio.sleep(2) # Search for the system message in the last 10 messages async for historical_message in message.channel.history(limit=10): - if historical_message.type == disnake.MessageType.pins_add: + if historical_message.type == discord.MessageType.pins_add: await historical_message.delete() return True @@ -225,16 +225,16 @@ def sub_clyde(username: Optional[str]) -> Optional[str]: return username # Empty string or None -async def send_denial(ctx: Context, reason: str) -> disnake.Message: +async def send_denial(ctx: Context, reason: str) -> discord.Message: """Send an embed denying the user with the given reason.""" - embed = disnake.Embed() - embed.colour = disnake.Colour.red() + embed = discord.Embed() + embed.colour = discord.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = reason return await ctx.send(embed=embed) -def format_user(user: disnake.abc.User) -> str: +def format_user(user: discord.abc.User) -> str: """Return a string for `user` which has their mention and ID.""" return f"{user.mention} (`{user.id}`)" diff --git a/bot/utils/webhooks.py b/bot/utils/webhooks.py index 8ef929b79..9c916b63a 100644 --- a/bot/utils/webhooks.py +++ b/bot/utils/webhooks.py @@ -1,7 +1,7 @@ from typing import Optional -import disnake -from disnake import Embed +import discord +from discord import Embed from bot.log import get_logger from bot.utils.messages import sub_clyde @@ -10,13 +10,13 @@ log = get_logger(__name__) async def send_webhook( - webhook: disnake.Webhook, + webhook: discord.Webhook, content: Optional[str] = None, username: Optional[str] = None, avatar_url: Optional[str] = None, embed: Optional[Embed] = None, wait: Optional[bool] = False -) -> disnake.Message: +) -> discord.Message: """ Send a message using the provided webhook. @@ -30,5 +30,5 @@ async def send_webhook( embed=embed, wait=wait, ) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Failed to send a message to the webhook!") diff --git a/poetry.lock b/poetry.lock index 087fd739c..a8ee6ef5c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -143,7 +143,7 @@ lxml = ["lxml"] [[package]] name = "bot-core" -version = "2.1.0" +version = "1.2.0" description = "Bot-Core provides the core functionality and utilities for the bots of the Python Discord community." category = "main" optional = false @@ -154,7 +154,7 @@ python-versions = "3.9.*" [package.source] type = "url" -url = "https://github.com/python-discord/bot-core/archive/refs/tags/v2.1.0.zip" +url = "https://github.com/python-discord/bot-core/archive/511bcba1b0196cd498c707a525ea56921bd971db.zip" [[package]] name = "certifi" version = "2021.10.8" @@ -1074,7 +1074,7 @@ toml = ">=0.10.0,<0.11.0" [[package]] name = "testfixtures" -version = "6.18.5" +version = "6.18.3" description = "A collection of helpers and mock objects for unit tests and doc tests." category = "dev" optional = false @@ -1169,7 +1169,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "3.9.*" -content-hash = "b8b28311c13f7a66f028041bae889131d3916ca7f667c9a7539871d21bbcd077" +content-hash = "538a4809b9fc6fa93ee1baccf4016515ae311a886f1b7ec9b3d544bb87c830a3" [metadata.files] aio-pika = [ @@ -1990,8 +1990,8 @@ taskipy = [ {file = "taskipy-1.7.0.tar.gz", hash = "sha256:960e480b1004971e76454ecd1a0484e640744a30073a1069894a311467f85ed8"}, ] testfixtures = [ - {file = "testfixtures-6.18.5-py2.py3-none-any.whl", hash = "sha256:7de200e24f50a4a5d6da7019fb1197aaf5abd475efb2ec2422fdcf2f2eb98c1d"}, - {file = "testfixtures-6.18.5.tar.gz", hash = "sha256:02dae883f567f5b70fd3ad3c9eefb95912e78ac90be6c7444b5e2f46bf572c84"}, + {file = "testfixtures-6.18.3-py2.py3-none-any.whl", hash = "sha256:6ddb7f56a123e1a9339f130a200359092bd0a6455e31838d6c477e8729bb7763"}, + {file = "testfixtures-6.18.3.tar.gz", hash = "sha256:2600100ae96ffd082334b378e355550fef8b4a529a6fa4c34f47130905c7426d"}, ] tldextract = [ {file = "tldextract-3.1.2-py2.py3-none-any.whl", hash = "sha256:f55e05f6bf4cc952a87d13594386d32ad2dd265630a8bdfc3df03bd60425c6b0"}, diff --git a/pyproject.toml b/pyproject.toml index 1f02818ee..90b38ce66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" python = "3.9.*" disnake = "~=2.4" # See https://bot-core.pythondiscord.com/ for docs. -bot-core = {url = "https://github.com/python-discord/bot-core/archive/refs/tags/v2.1.0.zip"} +bot-core = {url = "https://github.com/python-discord/bot-core/archive/511bcba1b0196cd498c707a525ea56921bd971db.zip"} aio-pika = "~=6.1" aiodns = "~=2.0" aiohttp = "~=3.7" diff --git a/tests/README.md b/tests/README.md index fc03b3d43..b7fddfaa2 100644 --- a/tests/README.md +++ b/tests/README.md @@ -121,9 +121,9 @@ As we are trying to test our "units" of code independently, we want to make sure However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". -To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `disnake` types (see the section on the below.). +To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). -An example of mocking is when we provide a command with a mocked version of `disnake.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): +An example of mocking is when we provide a command with a mocked version of `discord.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): ```py import asyncio @@ -152,15 +152,15 @@ class BotCogTests(unittest.TestCase): By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected. -### Special mocks for some `disnake` types +### Special mocks for some `discord.py` types To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. -In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual disnake types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `disnake` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. -**Note:** These mock types only "know" the attributes that are set by default when these `disnake` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: +**Note:** These mock types only "know" the attributes that are set by default when these `discord.py` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: ```py import unittest.mock @@ -245,7 +245,7 @@ All in all, it's not only important to consider if all statements or branches we ### Unit Testing vs Integration Testing -Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `disnake` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. +Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written. diff --git a/tests/base.py b/tests/base.py index dea7dd678..5e304ea9d 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,8 +3,8 @@ import unittest from contextlib import contextmanager from typing import Dict -import disnake -from disnake.ext import commands +import discord +from discord.ext import commands from bot.log import get_logger from tests import helpers @@ -80,7 +80,7 @@ class LoggingTestsMixin: class CommandTestCase(unittest.IsolatedAsyncioTestCase): - """TestCase with additional assertions that are useful for testing disnake commands.""" + """TestCase with additional assertions that are useful for testing Discord commands.""" async def assertHasPermissionsCheck( # noqa: N802 self, @@ -98,7 +98,7 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): permissions = {k: not v for k, v in permissions.items()} ctx = helpers.MockContext() - ctx.channel.permissions_for.return_value = disnake.Permissions(**permissions) + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) with self.assertRaises(commands.MissingPermissions) as cm: await cmd.can_run(ctx) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4ed7de64d..fdd0ab74a 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import disnake +import discord from bot import constants from bot.api import ResponseCodeError @@ -257,9 +257,9 @@ class SyncCogListenerTests(SyncCogTestCase): self.assertTrue(self.cog.on_member_update.__cog_listener__) subtests = ( - ("activities", disnake.Game("Pong"), disnake.Game("Frogger")), + ("activities", discord.Game("Pong"), discord.Game("Frogger")), ("nick", "old nick", "new nick"), - ("status", disnake.Status.online, disnake.Status.offline), + ("status", discord.Status.online, discord.Status.offline), ) for attribute, old_value, new_value in subtests: diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 9ecb8fae0..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import disnake +import discord from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers @@ -34,8 +34,8 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): for role in roles: mock_role = helpers.MockRole(**role) - mock_role.colour = disnake.Colour(role["colour"]) - mock_role.permissions = disnake.Permissions(role["permissions"]) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) guild.roles.append(mock_role) return guild diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index f55f5360f..2fc97af2d 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from disnake.errors import NotFound +from discord.errors import NotFound from bot.exts.backend.sync._syncers import UserSyncer, _Diff from tests import helpers diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 83b5f2749..35fa0ee59 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch -from disnake.ext.commands import errors +from discord.ext.commands import errors from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index fdff36b61..0856546af 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import AsyncMock, MagicMock, create_autospec, patch -from disnake import CategoryChannel -from disnake.ext.commands import BadArgument +from discord import CategoryChannel +from discord.ext.commands import BadArgument from bot.constants import Roles from bot.exts.events import code_jams diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 0cab405d0..06d78de9d 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, Mock -from disnake import NotFound +from discord import NotFound from bot.constants import Channels, STAFF_ROLES from bot.exts.filters import antimalware diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index 46fa82fd7..c0c3baa42 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from disnake.ext.commands import NoPrivateMessage +from discord.ext.commands import NoPrivateMessage from bot.exts.filters import security from tests.helpers import MockBot, MockContext diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index dd56c10dd..4db27269a 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -3,7 +3,7 @@ from re import Match from unittest import mock from unittest.mock import MagicMock -from disnake import Colour, NotFound +from discord import Colour, NotFound from bot import constants from bot.exts.filters import token_remover diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 9a35de7a9..d896b7652 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -3,7 +3,7 @@ import unittest import unittest.mock from datetime import datetime -import disnake +import discord from bot import constants from bot.exts.info import information @@ -43,7 +43,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): embed = kwargs.pop('embed') self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(embed.colour, discord.Colour.og_blurple()) self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") async def test_role_info_command(self): @@ -51,19 +51,19 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): dummy_role = helpers.MockRole( name="Dummy", id=112233445566778899, - colour=disnake.Colour.og_blurple(), + colour=discord.Colour.og_blurple(), position=10, members=[self.ctx.author], - permissions=disnake.Permissions(0) + permissions=discord.Permissions(0) ) admin_role = helpers.MockRole( name="Admins", id=998877665544332211, - colour=disnake.Colour.red(), + colour=discord.Colour.red(), position=3, members=[self.ctx.author], - permissions=disnake.Permissions(0), + permissions=discord.Permissions(0), ) self.ctx.guild.roles.extend([dummy_role, admin_role]) @@ -81,7 +81,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): admin_embed = admin_kwargs["embed"] self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(dummy_embed.colour, discord.Colour.og_blurple()) self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") @@ -91,7 +91,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(dummy_embed.fields[5].value, "0") self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, disnake.Colour.red()) + self.assertEqual(admin_embed.colour, discord.Colour.red()) class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase): @@ -449,7 +449,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, disnake.Colour(100)) + self.assertEqual(embed.colour, discord.Colour(100)) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", @@ -463,11 +463,11 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """The embed should be created with the og blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() - user = helpers.MockMember(id=217, colour=disnake.Colour.default()) + user = helpers.MockMember(id=217, colour=discord.Colour.default()) user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(embed.colour, discord.Colour.og_blurple()) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b85d086c9..052048053 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -3,7 +3,7 @@ import textwrap import unittest from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch -from disnake.errors import NotFound +from discord.errors import NotFound from bot.constants import Event from bot.exts.moderation.clean import Clean diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index eaa0e701e..ff81ddd65 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,7 +3,7 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch -from disnake import Embed, Forbidden, HTTPException, NotFound +from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError from bot.constants import Colours, Icons diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 725455bbe..cfe0c4b03 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -7,7 +7,7 @@ from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp -import disnake +import discord from async_rediscache import RedisSession from bot.constants import Colours @@ -24,7 +24,7 @@ class MockAsyncIterable: Helper for mocking asynchronous for loops. It does not appear that the `unittest` library currently provides anything that would - allow us to simply mock an async iterator, such as `disnake.TextChannel.history`. + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. We therefore write our own helper to wrap a regular synchronous iterable, and feed its values via `__anext__` rather than `__next__`. @@ -60,7 +60,7 @@ class MockSignal(enum.Enum): B = "B" -mock_404 = disnake.NotFound( +mock_404 = discord.NotFound( response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response message="Not found", ) @@ -70,8 +70,8 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): """Collection of tests for the `download_file` helper function.""" async def test_download_file_success(self): - """If `to_file` succeeds, function returns the acquired `disnake.File`.""" - file = MagicMock(disnake.File, filename="bigbadlemon.jpg") + """If `to_file` succeeds, function returns the acquired `discord.File`.""" + file = MagicMock(discord.File, filename="bigbadlemon.jpg") attachment = MockAttachment(to_file=AsyncMock(return_value=file)) acquired_file = await incidents.download_file(attachment) @@ -86,7 +86,7 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): async def test_download_file_fail(self): """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" - arbitrary_error = disnake.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) with self.assertLogs(logger=incidents.log, level=logging.ERROR): @@ -121,7 +121,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" - file = MagicMock(disnake.File, filename="bigbadjoe.jpg") + file = MagicMock(discord.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") incident = MockMessage(content="this is an incident", attachments=[attachment]) @@ -394,7 +394,7 @@ class TestArchive(TestIncidents): author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) - built_embed = MagicMock(disnake.Embed, id=123) # We patch `make_embed` to return this + built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) @@ -616,7 +616,7 @@ class TestResolveMessage(TestIncidents): """ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - arbitrary_error = disnake.HTTPException( + arbitrary_error = discord.HTTPException( response=MagicMock(aiohttp.ClientResponse), message="Arbitrary error", ) @@ -649,7 +649,7 @@ class TestOnRawReactionAdd(TestIncidents): super().setUp() # Ensure `cog_instance` is assigned self.payload = MagicMock( - disnake.RawReactionActionEvent, + discord.RawReactionActionEvent, channel_id=123, # Patched at class level message_id=456, member=MockMember(bot=False), diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index 6c9ebed95..79e04837d 100644 --- a/tests/bot/exts/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -1,6 +1,6 @@ import unittest -import disnake +import discord from bot.exts.moderation.modlog import ModLog from tests.helpers import MockBot, MockTextChannel @@ -19,7 +19,7 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): self.bot.get_channel.return_value = self.channel await self.cog.send_log_message( icon_url="foo", - colour=disnake.Colour.blue(), + colour=discord.Colour.blue(), title="bar", text="foo bar" * 3000 ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 539651d6c..92ce3418a 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -7,7 +7,7 @@ from unittest import mock from unittest.mock import AsyncMock, Mock from async_rediscache import RedisSession -from disnake import PermissionOverwrite +from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence @@ -152,7 +152,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. self.assertTrue(self.cog._init_task.cancelled()) - @autospec("disnake.ext.commands", "has_any_role") + @autospec("discord.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py index 5cb071d58..f8e120262 100644 --- a/tests/bot/exts/test_cogs.py +++ b/tests/bot/exts/test_cogs.py @@ -8,7 +8,7 @@ from collections import defaultdict from types import ModuleType from unittest import mock -from disnake.ext import commands +from discord.ext import commands from bot import exts @@ -34,7 +34,7 @@ class CommandNameTests(unittest.TestCase): raise ImportError(name=name) # pragma: no cover # The mock prevents asyncio.get_event_loop() from being called. - with mock.patch("disnake.ext.tasks.loop"): + with mock.patch("discord.ext.tasks.loop"): prefix = f"{exts.__name__}." for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): if not module.ispkg: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index bec7574fb..8bdeedd27 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -2,8 +2,8 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch -from disnake import AllowedMentions -from disnake.ext import commands +from discord import AllowedMentions +from discord.ext import commands from bot import constants from bot.exts.utils import snekbox diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index afb8a973d..1bb678db2 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -4,7 +4,7 @@ from datetime import MAXYEAR, datetime, timezone from unittest.mock import MagicMock, patch from dateutil.relativedelta import relativedelta -from disnake.ext.commands import BadArgument +from discord.ext.commands import BadArgument from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 5675e10ec..4ae11d5d3 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from disnake import DMChannel +from discord import DMChannel from bot.utils import checks from bot.utils.checks import InWhitelistCheckFailure diff --git a/tests/helpers.py b/tests/helpers.py index bd1418ab9..9d4988d23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,9 +7,9 @@ import unittest.mock from asyncio import AbstractEventLoop from typing import Iterable, Optional -import disnake +import discord from aiohttp import ClientSession -from disnake.ext.commands import Context +from discord.ext.commands import Context from bot.api import APIClient from bot.async_stats import AsyncStatsClient @@ -26,11 +26,11 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -class HashableMixin(disnake.mixins.EqualityComparable): +class HashableMixin(discord.mixins.EqualityComparable): """ - Mixin that provides similar hashing and equality functionality as disnake's `Hashable` mixin. + Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. - Note: disnake`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions + Note: discord.py`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions for the relative small `id` integers we generally use in tests, this bit-shift is omitted. """ @@ -39,22 +39,22 @@ class HashableMixin(disnake.mixins.EqualityComparable): class ColourMixin: - """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like disnake does.""" + """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like discord.py does.""" @property - def color(self) -> disnake.Colour: + def color(self) -> discord.Colour: return self.colour @color.setter - def color(self, color: disnake.Colour) -> None: + def color(self, color: discord.Colour) -> None: self.colour = color @property - def accent_color(self) -> disnake.Colour: + def accent_color(self) -> discord.Colour: return self.accent_colour @accent_color.setter - def accent_color(self, color: disnake.Colour) -> None: + def accent_color(self, color: discord.Colour) -> None: self.accent_colour = color @@ -63,7 +63,7 @@ class CustomMockMixin: Provides common functionality for our custom Mock types. The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock - object. As disnake also uses synchronous methods that nonetheless return coroutine objects, the + object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The class method `spec_set` can be overwritten with the object that should be uses as the specification @@ -119,7 +119,7 @@ class CustomMockMixin: return klass(**kw) -# Create a guild instance to get a realistic Mock of `disnake.Guild` +# Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { 'id': 1, 'name': 'guild', @@ -139,20 +139,20 @@ guild_data = { 'owner_id': 1, 'afk_channel_id': 464033278631084042, } -guild_instance = disnake.Guild(data=guild_data, state=unittest.mock.MagicMock()) +guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock()) class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ - A `Mock` subclass to mock `disnake.Guild` objects. + A `Mock` subclass to mock `discord.Guild` objects. - A MockGuild instance will follow the specifications of a `disnake.Guild` instance. This means + A MockGuild instance will follow the specifications of a `discord.Guild` instance. This means that if the code you're testing tries to access an attribute or method that normally does not - exist for a `disnake.Guild` object this will raise an `AttributeError`. This is to make sure our - tests fail if the code we're testing uses a `disnake.Guild` object in the wrong way. + exist for a `discord.Guild` object this will raise an `AttributeError`. This is to make sure our + tests fail if the code we're testing uses a `discord.Guild` object in the wrong way. One restriction of that is that if the code tries to access an attribute that normally does not - exist for `disnake.Guild` instance but was added dynamically, this will raise an exception with + exist for `discord.Guild` instance but was added dynamically, this will raise an exception with the mocked object. To get around that, you can set the non-standard attribute explicitly for the instance of `MockGuild`: @@ -160,10 +160,10 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): >>> guild.attribute_that_normally_does_not_exist = unittest.mock.MagicMock() In addition to attribute simulation, mocked guild object will pass an `isinstance` check against - `disnake.Guild`: + `discord.Guild`: >>> guild = MockGuild() - >>> isinstance(guild, disnake.Guild) + >>> isinstance(guild, discord.Guild) True For more info, see the `Mocking` section in `tests/README.md`. @@ -179,16 +179,16 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): self.roles.extend(roles) -# Create a Role instance to get a realistic Mock of `disnake.Role` +# Create a Role instance to get a realistic Mock of `discord.Role` role_data = {'name': 'role', 'id': 1} -role_instance = disnake.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) +role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ - A Mock subclass to mock `disnake.Role` objects. + A Mock subclass to mock `discord.Role` objects. - Instances of this class will follow the specifications of `disnake.Role` instances. For more + Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ spec_set = role_instance @@ -198,40 +198,40 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): 'id': next(self.discord_id), 'name': 'role', 'position': 1, - 'colour': disnake.Colour(0xdeadbf), - 'permissions': disnake.Permissions(), + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), } super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if isinstance(self.colour, int): - self.colour = disnake.Colour(self.colour) + self.colour = discord.Colour(self.colour) if isinstance(self.permissions, int): - self.permissions = disnake.Permissions(self.permissions) + self.permissions = discord.Permissions(self.permissions) if 'mention' not in kwargs: self.mention = f'&{self.name}' def __lt__(self, other): - """Simplified position-based comparisons similar to those of `disnake.Role`.""" + """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position < other.position def __ge__(self, other): - """Simplified position-based comparisons similar to those of `disnake.Role`.""" + """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position >= other.position -# Create a Member instance to get a realistic Mock of `disnake.Member` +# Create a Member instance to get a realistic Mock of `discord.Member` member_data = {'user': 'lemon', 'roles': [1]} state_mock = unittest.mock.MagicMock() -member_instance = disnake.Member(data=member_data, guild=guild_instance, state=state_mock) +member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock Member objects. - Instances of this class will follow the specifications of `disnake.Member` instances. For more + Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ spec_set = member_instance @@ -249,11 +249,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" -# Create a User instance to get a realistic Mock of `disnake.User` +# Create a User instance to get a realistic Mock of `discord.User` _user_data_mock = collections.defaultdict(unittest.mock.MagicMock, { "accent_color": 0 }) -user_instance = disnake.User( +user_instance = discord.User( data=unittest.mock.MagicMock(get=unittest.mock.Mock(side_effect=_user_data_mock.get)), state=unittest.mock.MagicMock() ) @@ -263,7 +263,7 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock User objects. - Instances of this class will follow the specifications of `disnake.User` instances. For more + Instances of this class will follow the specifications of `discord.User` instances. For more information, see the `MockGuild` docstring. """ spec_set = user_instance @@ -305,7 +305,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Bot objects. - Instances of this class will follow the specifications of `disnake.ext.commands.Bot` instances. + Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ spec_set = Bot( @@ -324,7 +324,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) -# Create a TextChannel instance to get a realistic MagicMock of `disnake.TextChannel` +# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` channel_data = { 'id': 1, 'type': 'TextChannel', @@ -337,17 +337,17 @@ channel_data = { } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -text_channel_instance = disnake.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) channel_data["type"] = "VoiceChannel" -voice_channel_instance = disnake.VoiceChannel(state=state, guild=guild, data=channel_data) +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `disnake.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = text_channel_instance @@ -364,7 +364,7 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock VoiceChannel objects. - Instances of this class will follow the specifications of `disnake.VoiceChannel` instances. For + Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = voice_channel_instance @@ -381,14 +381,14 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): state = unittest.mock.MagicMock() me = unittest.mock.MagicMock() dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} -dm_channel_instance = disnake.DMChannel(me=me, state=state, data=dm_channel_data) +dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `disnake.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = dm_channel_instance @@ -398,17 +398,17 @@ class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(kwargs, default_kwargs)) -# Create CategoryChannel instance to get a realistic MagicMock of `disnake.CategoryChannel` +# Create CategoryChannel instance to get a realistic MagicMock of `discord.CategoryChannel` category_channel_data = { 'id': 1, - 'type': disnake.ChannelType.category, + 'type': discord.ChannelType.category, 'name': 'category', 'position': 1, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -category_channel_instance = disnake.CategoryChannel( +category_channel_instance = discord.CategoryChannel( state=state, guild=guild, data=category_channel_data ) @@ -419,7 +419,7 @@ class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(default_kwargs, kwargs)) -# Create a Message instance to get a realistic MagicMock of `disnake.Message` +# Create a Message instance to get a realistic MagicMock of `discord.Message` message_data = { 'id': 1, 'webhook_id': 431341013479718912, @@ -438,10 +438,10 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() -message_instance = disnake.Message(state=state, channel=channel, data=message_data) +message_instance = discord.Message(state=state, channel=channel, data=message_data) -# Create a Context instance to get a realistic MagicMock of `disnake.ext.commands.Context` +# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` context_instance = Context( message=unittest.mock.MagicMock(), prefix="$", @@ -455,7 +455,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Context objects. - Instances of this class will follow the specifications of `disnake.ext.commands.Context` + Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ spec_set = context_instance @@ -471,14 +471,14 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) -attachment_instance = disnake.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Attachment objects. - Instances of this class will follow the specifications of `disnake.Attachment` instances. For + Instances of this class will follow the specifications of `discord.Attachment` instances. For more information, see the `MockGuild` docstring. """ spec_set = attachment_instance @@ -488,7 +488,7 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. - Instances of this class will follow the specifications of `disnake.Message` instances. For more + Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ spec_set = message_instance @@ -501,14 +501,14 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} -emoji_instance = disnake.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) +emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Emoji objects. - Instances of this class will follow the specifications of `disnake.Emoji` instances. For more + Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = emoji_instance @@ -518,27 +518,27 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): self.guild = kwargs.get('guild', MockGuild()) -partial_emoji_instance = disnake.PartialEmoji(animated=False, name='guido') +partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock PartialEmoji objects. - Instances of this class will follow the specifications of `disnake.PartialEmoji` instances. For + Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = partial_emoji_instance -reaction_instance = disnake.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) +reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) class MockReaction(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Reaction objects. - Instances of this class will follow the specifications of `disnake.Reaction` instances. For + Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ spec_set = reaction_instance @@ -556,14 +556,14 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.__str__.return_value = str(self.emoji) -webhook_instance = disnake.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. - Instances of this class will follow the specifications of `disnake.Webhook` instances. For + Instances of this class will follow the specifications of `discord.Webhook` instances. For more information, see the `MockGuild` docstring. """ spec_set = webhook_instance diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c5e799a85..81285e009 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,20 +2,20 @@ import asyncio import unittest import unittest.mock -import disnake +import discord from tests import helpers class DiscordMocksTests(unittest.TestCase): - """Tests for our specialized disnake mocks.""" + """Tests for our specialized discord.py mocks.""" def test_mock_role_default_initialization(self): """Test if the default initialization of MockRole results in the correct object.""" role = helpers.MockRole() - # The `spec` argument makes sure `isistance` checks with `disnake.Role` pass - self.assertIsInstance(role, disnake.Role) + # The `spec` argument makes sure `isistance` checks with `discord.Role` pass + self.assertIsInstance(role, discord.Role) self.assertEqual(role.name, "role") self.assertEqual(role.position, 1) @@ -61,8 +61,8 @@ class DiscordMocksTests(unittest.TestCase): """Test if the default initialization of Mockmember results in the correct object.""" member = helpers.MockMember() - # The `spec` argument makes sure `isistance` checks with `disnake.Member` pass - self.assertIsInstance(member, disnake.Member) + # The `spec` argument makes sure `isistance` checks with `discord.Member` pass + self.assertIsInstance(member, discord.Member) self.assertEqual(member.name, "member") self.assertListEqual(member.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) @@ -86,18 +86,18 @@ class DiscordMocksTests(unittest.TestCase): """Test if MockMember accepts and sets abitrary keyword arguments.""" member = helpers.MockMember( nick="Dino Man", - colour=disnake.Colour.default(), + colour=discord.Colour.default(), ) self.assertEqual(member.nick, "Dino Man") - self.assertEqual(member.colour, disnake.Colour.default()) + self.assertEqual(member.colour, discord.Colour.default()) def test_mock_guild_default_initialization(self): """Test if the default initialization of Mockguild results in the correct object.""" guild = helpers.MockGuild() - # The `spec` argument makes sure `isistance` checks with `disnake.Guild` pass - self.assertIsInstance(guild, disnake.Guild) + # The `spec` argument makes sure `isistance` checks with `discord.Guild` pass + self.assertIsInstance(guild, discord.Guild) self.assertListEqual(guild.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) self.assertListEqual(guild.members, []) @@ -127,15 +127,15 @@ class DiscordMocksTests(unittest.TestCase): """Tests if MockBot initializes with the correct values.""" bot = helpers.MockBot() - # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Bot` pass - self.assertIsInstance(bot, disnake.ext.commands.Bot) + # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Bot` pass + self.assertIsInstance(bot, discord.ext.commands.Bot) def test_mock_context_default_initialization(self): """Tests if MockContext initializes with the correct values.""" context = helpers.MockContext() - # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Context` pass - self.assertIsInstance(context, disnake.ext.commands.Context) + # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Context` pass + self.assertIsInstance(context, discord.ext.commands.Context) self.assertIsInstance(context.bot, helpers.MockBot) self.assertIsInstance(context.guild, helpers.MockGuild) -- cgit v1.2.3 From 43b6fee9eba12a6836530029a642cba6e7e505f0 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 19 Mar 2022 16:34:01 +0000 Subject: Use bot-core scheduling and member util functions --- bot/__init__.py | 17 +- bot/async_stats.py | 3 +- bot/converters.py | 2 +- bot/decorators.py | 3 +- bot/exts/backend/logging.py | 2 +- bot/exts/filters/antispam.py | 3 +- bot/exts/filters/filtering.py | 4 +- bot/exts/filters/token_remover.py | 5 +- bot/exts/fun/off_topic_names.py | 2 +- bot/exts/help_channels/_cog.py | 3 +- bot/exts/info/codeblock/_cog.py | 3 +- bot/exts/info/doc/_batch_parser.py | 2 +- bot/exts/info/doc/_cog.py | 4 +- bot/exts/info/subscribe.py | 2 +- bot/exts/moderation/defcon.py | 5 +- bot/exts/moderation/incidents.py | 2 +- bot/exts/moderation/infraction/_scheduler.py | 3 +- bot/exts/moderation/metabase.py | 5 +- bot/exts/moderation/modpings.py | 5 +- bot/exts/moderation/silence.py | 4 +- bot/exts/moderation/stream.py | 3 +- bot/exts/moderation/watchchannels/_watchchannel.py | 3 +- bot/exts/recruitment/talentpool/_cog.py | 3 +- bot/exts/recruitment/talentpool/_review.py | 2 +- bot/exts/utils/reminders.py | 5 +- bot/exts/utils/snekbox.py | 5 +- bot/monkey_patches.py | 76 -------- bot/utils/messages.py | 2 +- bot/utils/scheduling.py | 194 --------------------- tests/bot/exts/backend/sync/test_cog.py | 2 +- tests/bot/exts/filters/test_filtering.py | 2 +- 31 files changed, 53 insertions(+), 323 deletions(-) delete mode 100644 bot/monkey_patches.py delete mode 100644 bot/utils/scheduling.py (limited to 'tests') diff --git a/bot/__init__.py b/bot/__init__.py index 17d99105a..c652897be 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,11 +1,10 @@ import asyncio import os -from functools import partial, partialmethod from typing import TYPE_CHECKING -from discord.ext import commands +from botcore.utils import apply_monkey_patches -from bot import log, monkey_patches +from bot import log if TYPE_CHECKING: from bot.bot import Bot @@ -16,16 +15,6 @@ log.setup() if os.name == "nt": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -monkey_patches.patch_typing() - -# This patches any convertors that use PartialMessage, but not the PartialMessageConverter itself -# as library objects are made by this mapping. -# https://github.com/Rapptz/discord.py/blob/1a4e73d59932cdbe7bf2c281f25e32529fc7ae1f/discord/ext/commands/converter.py#L984-L1004 -commands.converter.PartialMessageConverter = monkey_patches.FixedPartialMessageConverter - -# Monkey-patch discord.py decorators to use the Command subclass which supports root aliases. -# Must be patched before any cogs are added. -commands.command = partial(commands.command, cls=monkey_patches.Command) -commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=monkey_patches.Command) +apply_monkey_patches() instance: "Bot" = None # Global Bot instance. diff --git a/bot/async_stats.py b/bot/async_stats.py index 2af832e5b..0303de7a1 100644 --- a/bot/async_stats.py +++ b/bot/async_stats.py @@ -1,10 +1,9 @@ import asyncio import socket +from botcore.utils import scheduling from statsd.client.base import StatsClientBase -from bot.utils import scheduling - class AsyncStatsClient(StatsClientBase): """An async transport method for statsd communication.""" diff --git a/bot/converters.py b/bot/converters.py index 3522a32aa..e819e4713 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -8,7 +8,7 @@ from ssl import CertificateError import dateutil.parser import discord from aiohttp import ClientConnectorError -from botcore.regex import DISCORD_INVITE +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter from discord.utils import escape_markdown, snowflake_time diff --git a/bot/decorators.py b/bot/decorators.py index 8971898b3..466770c3a 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -5,13 +5,14 @@ import typing as t from contextlib import suppress import arrow +from botcore.utils import scheduling from discord import Member, NotFound from discord.ext import commands from discord.ext.commands import Cog, Context from bot.constants import Channels, DEBUG_MODE, RedirectOutput from bot.log import get_logger -from bot.utils import function, scheduling +from bot.utils import function from bot.utils.checks import ContextCheckFailure, in_whitelist_check from bot.utils.function import command_wraps diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py index 2d03cd580..469331ae5 100644 --- a/bot/exts/backend/logging.py +++ b/bot/exts/backend/logging.py @@ -1,10 +1,10 @@ +from botcore.utils import scheduling from discord import Embed from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE from bot.log import get_logger -from bot.utils import scheduling log = get_logger(__name__) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index bcd845a43..d9e23b25e 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -8,6 +8,7 @@ from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set import arrow +from botcore.utils import scheduling from discord import Colour, Member, Message, NotFound, Object, TextChannel from discord.ext.commands import Cog @@ -20,7 +21,7 @@ from bot.converters import Duration from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import lock, scheduling +from bot.utils import lock from bot.utils.message_cache import MessageCache from bot.utils.messages import format_user, send_attachments diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index b9f2a0e51..32efcc307 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -9,7 +9,8 @@ import dateutil.parser import regex import tldextract from async_rediscache import RedisCache -from botcore.regex import DISCORD_INVITE +from botcore.utils import scheduling +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member, Message, NotFound, TextChannel from discord.ext.commands import Cog @@ -21,7 +22,6 @@ from bot.constants import Channels, Colours, Filter, Guild, Icons, URLs from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import scheduling from bot.utils.messages import format_user log = get_logger(__name__) diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py index 520283ba3..436e6dc19 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filters/token_remover.py @@ -1,5 +1,4 @@ import base64 -import binascii import re import typing as t @@ -182,7 +181,7 @@ class TokenRemover(Cog): # that means it's not a valid user id. return None return int(string) - except (binascii.Error, ValueError): + except ValueError: return None @staticmethod @@ -198,7 +197,7 @@ class TokenRemover(Cog): try: decoded_bytes = base64.urlsafe_b64decode(b64_content) timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: + except ValueError as e: log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") return False diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index 7df1d172d..33f43f2a8 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,6 +2,7 @@ import difflib from datetime import timedelta import arrow +from botcore.utils import scheduling from discord import Colour, Embed from discord.ext.commands import Cog, Context, group, has_any_role from discord.utils import sleep_until @@ -12,7 +13,6 @@ from bot.constants import Channels, MODERATION_ROLES from bot.converters import OffTopicName from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) log = get_logger(__name__) diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index a93acffb6..fc80c968c 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -7,6 +7,7 @@ from operator import attrgetter import arrow import discord import discord.abc +from botcore.utils import members, scheduling from discord.ext import commands from bot import constants @@ -14,7 +15,7 @@ from bot.bot import Bot from bot.constants import Channels, RedirectOutput from bot.exts.help_channels import _caches, _channel, _message, _name, _stats from bot.log import get_logger -from bot.utils import channel as channel_utils, lock, members, scheduling +from bot.utils import channel as channel_utils, lock log = get_logger(__name__) diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index a859d8cef..9027105d9 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -2,6 +2,7 @@ import time from typing import Optional import discord +from botcore.utils import scheduling from discord import Message, RawMessageUpdateEvent from discord.ext.commands import Cog @@ -11,7 +12,7 @@ from bot.exts.filters.token_remover import TokenRemover from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE from bot.exts.info.codeblock._instructions import get_instructions from bot.log import get_logger -from bot.utils import has_lines, scheduling +from bot.utils import has_lines from bot.utils.channel import is_help_channel from bot.utils.messages import wait_for_deletion diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py index c27f28eac..41a15fb6e 100644 --- a/bot/exts/info/doc/_batch_parser.py +++ b/bot/exts/info/doc/_batch_parser.py @@ -8,12 +8,12 @@ from operator import attrgetter from typing import Deque, Dict, List, NamedTuple, Optional, Union import discord +from botcore.utils import scheduling from bs4 import BeautifulSoup import bot from bot.constants import Channels from bot.log import get_logger -from bot.utils import scheduling from . import _cog, doc_cache from ._parsing import get_symbol_markdown diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 4dc5276d9..3789fdbe3 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -10,6 +10,8 @@ from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp import discord +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord.ext import commands from bot.api import ResponseCodeError @@ -18,10 +20,8 @@ from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import Inventory, PackageName, ValidURL, allowed_strings from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling from bot.utils.lock import SharedEvent, lock from bot.utils.messages import send_denial, wait_for_deletion -from bot.utils.scheduling import Scheduler from . import NAMESPACE, PRIORITY_PACKAGES, _batch_parser, doc_cache from ._inventory_parser import InvalidHeaderError, InventoryDict, fetch_inventory diff --git a/bot/exts/info/subscribe.py b/bot/exts/info/subscribe.py index eff0c13b8..ed134ff78 100644 --- a/bot/exts/info/subscribe.py +++ b/bot/exts/info/subscribe.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import arrow import discord +from botcore.utils import members, scheduling from discord.ext import commands from discord.interactions import Interaction @@ -12,7 +13,6 @@ from bot import constants from bot.bot import Bot from bot.decorators import redirect_output from bot.log import get_logger -from bot.utils import members, scheduling @dataclass(frozen=True) diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 178be734d..a8640cb1b 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -7,6 +7,8 @@ from typing import Optional, Union import arrow from aioredis import RedisError from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Forbidden, Member, TextChannel, User from discord.ext import tasks @@ -17,9 +19,8 @@ from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_RO from bot.converters import DurationDelta, Expiry from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.messages import format_user -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index b579416a6..d34c1c7fa 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -6,12 +6,12 @@ from typing import Optional import discord from async_rediscache import RedisCache +from botcore.utils import scheduling from discord.ext.commands import Cog, Context, MessageConverter, MessageNotFound from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks from bot.log import get_logger -from bot.utils import scheduling from bot.utils.messages import format_user, sub_clyde log = get_logger(__name__) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 2fc54856f..9f5800e2a 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -6,6 +6,7 @@ from gettext import ngettext import arrow import dateutil.parser import discord +from botcore.utils import scheduling from discord.ext.commands import Context from bot import constants @@ -16,7 +17,7 @@ from bot.converters import MemberOrUser from bot.exts.moderation.infraction import _utils from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import messages, scheduling, time +from bot.utils import messages, time from bot.utils.channel import is_mod_channel log = get_logger(__name__) diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index ce9c220b3..d68726faf 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -8,15 +8,16 @@ import arrow from aiohttp.client_exceptions import ClientResponseError from arrow import Arrow from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Metabase as MetabaseConfig, Roles from bot.converters import allowed_strings from bot.log import get_logger -from bot.utils import scheduling, send_to_paste_service +from bot.utils import send_to_paste_service from bot.utils.channel import is_mod_channel -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index b5cd29b12..cb1e4fd05 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -3,6 +3,8 @@ import datetime import arrow from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse, parse as dateutil_parse from discord import Embed, Member from discord.ext.commands import Cog, Context, group, has_any_role @@ -11,8 +13,7 @@ from bot.bot import Bot from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles from bot.converters import Expiry from bot.log import get_logger -from bot.utils import scheduling, time -from bot.utils.scheduling import Scheduler +from bot.utils import time log = get_logger(__name__) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 511520252..307729181 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -5,6 +5,8 @@ from datetime import datetime, timedelta, timezone from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel from discord.ext import commands, tasks from discord.ext.commands import Context @@ -14,9 +16,7 @@ from bot import constants from bot.bot import Bot from bot.converters import HushDurationConverter from bot.log import get_logger -from bot.utils import scheduling from bot.utils.lock import LockedResourceError, lock, lock_arg -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 985cc6eb1..17d24eb89 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -5,6 +5,7 @@ import arrow import discord from arrow import Arrow from async_rediscache import RedisCache +from botcore.utils import scheduling from discord.ext import commands from bot.bot import Bot @@ -14,7 +15,7 @@ from bot.constants import ( from bot.converters import Expiry from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.members import get_or_fetch_member log = get_logger(__name__) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index ee9b6ba45..bae7ecd02 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional import discord +from botcore.utils import scheduling from discord import Color, DMChannel, Embed, HTTPException, Message, errors from discord.ext.commands import Cog, Context @@ -18,7 +19,7 @@ from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE from bot.exts.moderation.modlog import ModLog from bot.log import CustomLogger, get_logger from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages, scheduling, time +from bot.utils import CogABCMeta, messages, time from bot.utils.members import get_or_fetch_member log = get_logger(__name__) diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 0554bf37a..0d51af2ca 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -5,6 +5,7 @@ from typing import Optional, Union import discord from async_rediscache import RedisCache +from botcore.utils import scheduling from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role @@ -15,7 +16,7 @@ from bot.converters import MemberOrUser, UnambiguousMemberOrUser from bot.exts.recruitment.talentpool._review import Reviewer from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.members import get_or_fetch_member AUTOREVIEW_ENABLED_KEY = "autoreview_enabled" diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index b4d177622..214d85851 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from typing import List, Optional, Union import arrow +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel from discord.ext.commands import Context @@ -20,7 +21,6 @@ from bot.log import get_logger from bot.utils import time from bot.utils.members import get_or_fetch_member from bot.utils.messages import count_unique_users_reaction, pin_no_system_message -from bot.utils.scheduling import Scheduler if typing.TYPE_CHECKING: from bot.exts.recruitment.talentpool._cog import TalentPool diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index ad82d49c9..62603697c 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone from operator import itemgetter import discord +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord.ext.commands import Cog, Context, Greedy, group @@ -13,12 +15,11 @@ from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Role from bot.converters import Duration, UnambiguousUser from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.checks import has_any_role_check, has_no_roles_check from bot.utils.lock import lock_arg from bot.utils.members import get_or_fetch_member from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 3c1009d2a..2b073ed72 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -7,7 +7,8 @@ from signal import Signals from textwrap import dedent from typing import Optional, Tuple -from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX +from botcore.utils import scheduling +from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only @@ -15,7 +16,7 @@ from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs from bot.decorators import redirect_output from bot.log import get_logger -from bot.utils import scheduling, send_to_paste_service +from bot.utils import send_to_paste_service from bot.utils.messages import wait_for_deletion log = get_logger(__name__) diff --git a/bot/monkey_patches.py b/bot/monkey_patches.py deleted file mode 100644 index 4840fa454..000000000 --- a/bot/monkey_patches.py +++ /dev/null @@ -1,76 +0,0 @@ -import re -from datetime import timedelta - -import arrow -from discord import Forbidden, http -from discord.ext import commands - -from bot.log import get_logger - -log = get_logger(__name__) -MESSAGE_ID_RE = re.compile(r'(?P[0-9]{15,20})$') - - -class Command(commands.Command): - """ - A `discord.ext.commands.Command` subclass which supports root aliases. - - A `root_aliases` keyword argument is added, which is a sequence of alias names that will act as - top-level commands rather than being aliases of the command's group. It's stored as an attribute - also named `root_aliases`. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.root_aliases = kwargs.get("root_aliases", []) - - if not isinstance(self.root_aliases, (list, tuple)): - raise TypeError("Root aliases of a command must be a list or a tuple of strings.") - - -def patch_typing() -> None: - """ - Sometimes discord turns off typing events by throwing 403's. - - Handle those issues by patching the trigger_typing method so it ignores 403's in general. - """ - log.debug("Patching send_typing, which should fix things breaking when discord disables typing events. Stay safe!") - - original = http.HTTPClient.send_typing - last_403 = None - - async def honeybadger_type(self, channel_id: int) -> None: # noqa: ANN001 - nonlocal last_403 - if last_403 and (arrow.utcnow() - last_403) < timedelta(minutes=5): - log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") - return - try: - await original(self, channel_id) - except Forbidden: - last_403 = arrow.utcnow() - log.warning("Got a 403 from typing event!") - pass - - http.HTTPClient.send_typing = honeybadger_type - - -class FixedPartialMessageConverter(commands.PartialMessageConverter): - """ - Make the Message converter infer channelID from the given context if only a messageID is given. - - Discord.py's Message converter is supposed to infer channelID based - on ctx.channel if only a messageID is given. A refactor commit, linked below, - a few weeks before d.py's archival broke this defined behaviour of the converter. - Currently, if only a messageID is given to the converter, it will only find that message - if it's in the bot's cache. - - https://github.com/Rapptz/discord.py/commit/1a4e73d59932cdbe7bf2c281f25e32529fc7ae1f - """ - - @staticmethod - def _get_id_matches(ctx: commands.Context, argument: str) -> tuple[int, int, int]: - """Inserts ctx.channel.id before calling super method if argument is just a messageID.""" - match = MESSAGE_ID_RE.match(argument) - if match: - argument = f"{ctx.channel.id}-{match.group('message_id')}" - return commands.PartialMessageConverter._get_id_matches(ctx, argument) diff --git a/bot/utils/messages.py b/bot/utils/messages.py index e55c07062..a5ed84351 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -6,12 +6,12 @@ from io import BytesIO from typing import Callable, List, Optional, Sequence, Union import discord +from botcore.utils import scheduling from discord.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES from bot.log import get_logger -from bot.utils import scheduling log = get_logger(__name__) diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py deleted file mode 100644 index 23acacf74..000000000 --- a/bot/utils/scheduling.py +++ /dev/null @@ -1,194 +0,0 @@ -import asyncio -import contextlib -import inspect -import typing as t -from datetime import datetime -from functools import partial - -from arrow import Arrow - -from bot.log import get_logger - - -class Scheduler: - """ - Schedule the execution of coroutines and keep track of them. - - When instantiating a Scheduler, a name must be provided. This name is used to distinguish the - instance's log messages from other instances. Using the name of the class or module containing - the instance is suggested. - - Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` - or `schedule_later`. A unique ID is required to be given in order to keep track of the - resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing - the same ID used to schedule it. The `in` operator is supported for checking if a task with a - given ID is currently scheduled. - - Any exception raised in a scheduled task is logged when the task is done. - """ - - def __init__(self, name: str): - self.name = name - - self._log = get_logger(f"{__name__}.{name}") - self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} - - def __contains__(self, task_id: t.Hashable) -> bool: - """Return True if a task with the given `task_id` is currently scheduled.""" - return task_id in self._scheduled_tasks - - def schedule(self, task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule the execution of a `coroutine`. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - self._log.trace(f"Scheduling task #{task_id}...") - - msg = f"Cannot schedule an already started coroutine for #{task_id}" - assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg - - if task_id in self._scheduled_tasks: - self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") - coroutine.close() - return - - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") - task.add_done_callback(partial(self._task_done_callback, task_id)) - - self._scheduled_tasks[task_id] = task - self._log.debug(f"Scheduled task #{task_id} {id(task)}.") - - def schedule_at(self, time: t.Union[datetime, Arrow], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule `coroutine` to be executed at the given `time`. - - If `time` is timezone aware, then use that timezone to calculate now() when subtracting. - If `time` is naïve, then use UTC. - - If `time` is in the past, schedule `coroutine` immediately. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() - delay = (time - now_datetime).total_seconds() - if delay > 0: - coroutine = self._await_later(delay, task_id, coroutine) - - self.schedule(task_id, coroutine) - - def schedule_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule `coroutine` to be executed after the given `delay` number of seconds. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - - def cancel(self, task_id: t.Hashable) -> None: - """Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.""" - self._log.trace(f"Cancelling task #{task_id}...") - - try: - task = self._scheduled_tasks.pop(task_id) - except KeyError: - self._log.warning(f"Failed to unschedule {task_id} (no task found).") - else: - task.cancel() - - self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") - - def cancel_all(self) -> None: - """Unschedule all known tasks.""" - self._log.debug("Unscheduling all tasks") - - for task_id in self._scheduled_tasks.copy(): - self.cancel(task_id) - - async def _await_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """Await `coroutine` after the given `delay` number of seconds.""" - try: - self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") - await asyncio.sleep(delay) - - # Use asyncio.shield to prevent the coroutine from cancelling itself. - self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") - await asyncio.shield(coroutine) - finally: - # Close it to prevent unawaited coroutine warnings, - # which would happen if the task was cancelled during the sleep. - # Only close it if it's not been awaited yet. This check is important because the - # coroutine may cancel this task, which would also trigger the finally block. - state = inspect.getcoroutinestate(coroutine) - if state == "CORO_CREATED": - self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") - coroutine.close() - else: - self._log.debug(f"Finally block reached for #{task_id}; {state=}") - - def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None: - """ - Delete the task and raise its exception if one exists. - - If `done_task` and the task associated with `task_id` are different, then the latter - will not be deleted. In this case, a new task was likely rescheduled with the same ID. - """ - self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") - - scheduled_task = self._scheduled_tasks.get(task_id) - - if scheduled_task and done_task is scheduled_task: - # A task for the ID exists and is the same as the done task. - # Since this is the done callback, the task is already done so no need to cancel it. - self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") - del self._scheduled_tasks[task_id] - elif scheduled_task: - # A new task was likely rescheduled with the same ID. - self._log.debug( - f"The scheduled task #{task_id} {id(scheduled_task)} " - f"and the done task {id(done_task)} differ." - ) - elif not done_task.cancelled(): - self._log.warning( - f"Task #{task_id} not found while handling task {id(done_task)}! " - f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." - ) - - with contextlib.suppress(asyncio.CancelledError): - exception = done_task.exception() - # Log the exception if one exists. - if exception: - self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) - - -def create_task( - coro: t.Awaitable, - *, - suppressed_exceptions: tuple[t.Type[Exception]] = (), - event_loop: t.Optional[asyncio.AbstractEventLoop] = None, - **kwargs, -) -> asyncio.Task: - """ - Wrapper for creating asyncio `Task`s which logs exceptions raised in the task. - - If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used. - """ - if event_loop is not None: - task = event_loop.create_task(coro, **kwargs) - else: - task = asyncio.create_task(coro, **kwargs) - task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) - return task - - -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None: - """Retrieve and log the exception raised in `task` if one exists.""" - with contextlib.suppress(asyncio.CancelledError): - exception = task.exception() - # Log the exception if one exists. - if exception and not isinstance(exception, suppressed_exceptions): - log = get_logger(__name__) - log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index fdd0ab74a..7dff38f96 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("bot.utils.scheduling.create_task") + @mock.patch("botcore.utils.scheduling.create_task") @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild, create_task): """Should instantiate syncers and run a sync for the guild.""" diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py index 8ae59c1f1..bd26532f1 100644 --- a/tests/bot/exts/filters/test_filtering.py +++ b/tests/bot/exts/filters/test_filtering.py @@ -11,7 +11,7 @@ class FilteringCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Instantiate the bot and cog.""" self.bot = MockBot() - with patch("bot.utils.scheduling.create_task", new=lambda task, **_: task.close()): + with patch("botcore.utils.scheduling.create_task", new=lambda task, **_: task.close()): self.cog = filtering.Filtering(self.bot) @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) -- cgit v1.2.3 From 047705ac91c2997ccb509ea4e1fb3fad38840412 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Thu, 31 Mar 2022 20:46:45 +0100 Subject: Remove async stats and site api wrapper We now source them from bot-core, so no need to have them here too. --- bot/api.py | 102 --------------------- bot/async_stats.py | 40 -------- bot/converters.py | 2 +- bot/exts/backend/error_handler.py | 2 +- bot/exts/backend/sync/_cog.py | 2 +- bot/exts/backend/sync/_syncers.py | 2 +- bot/exts/filters/filter_lists.py | 2 +- bot/exts/filters/filtering.py | 2 +- bot/exts/fun/off_topic_names.py | 2 +- bot/exts/info/doc/_cog.py | 2 +- bot/exts/info/information.py | 2 +- bot/exts/moderation/infraction/_scheduler.py | 2 +- bot/exts/moderation/infraction/_utils.py | 2 +- bot/exts/moderation/voice_gate.py | 2 +- bot/exts/moderation/watchchannels/_watchchannel.py | 2 +- bot/exts/recruitment/talentpool/_cog.py | 2 +- bot/exts/recruitment/talentpool/_review.py | 2 +- tests/bot/exts/backend/sync/test_base.py | 3 +- tests/bot/exts/backend/sync/test_cog.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 2 +- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- tests/bot/test_api.py | 66 ------------- tests/helpers.py | 4 +- 23 files changed, 22 insertions(+), 229 deletions(-) delete mode 100644 bot/api.py delete mode 100644 bot/async_stats.py delete mode 100644 tests/bot/test_api.py (limited to 'tests') diff --git a/bot/api.py b/bot/api.py deleted file mode 100644 index 856f7c865..000000000 --- a/bot/api.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -from typing import Optional -from urllib.parse import quote as quote_url - -import aiohttp - -from bot.log import get_logger - -from .constants import Keys, URLs - -log = get_logger(__name__) - - -class ResponseCodeError(ValueError): - """Raised when a non-OK HTTP response is received.""" - - def __init__( - self, - response: aiohttp.ClientResponse, - response_json: Optional[dict] = None, - response_text: str = "" - ): - self.status = response.status - self.response_json = response_json or {} - self.response_text = response_text - self.response = response - - def __str__(self): - response = self.response_json if self.response_json else self.response_text - return f"Status: {self.status} Response: {response}" - - -class APIClient: - """Django Site API wrapper.""" - - # These are class attributes so they can be seen when being mocked for tests. - # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details. - session: Optional[aiohttp.ClientSession] = None - loop: asyncio.AbstractEventLoop = None - - def __init__(self, **session_kwargs): - auth_headers = { - 'Authorization': f"Token {Keys.site_api}" - } - - if 'headers' in session_kwargs: - session_kwargs['headers'].update(auth_headers) - else: - session_kwargs['headers'] = auth_headers - - # aiohttp will complain if APIClient gets instantiated outside a coroutine. Thankfully, we - # don't and shouldn't need to do that, so we can avoid scheduling a task to create it. - self.session = aiohttp.ClientSession(**session_kwargs) - - @staticmethod - def _url_for(endpoint: str) -> str: - return f"{URLs.site_api_schema}{URLs.site_api}/{quote_url(endpoint)}" - - async def close(self) -> None: - """Close the aiohttp session.""" - await self.session.close() - - async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: - """Raise ResponseCodeError for non-OK response if an exception should be raised.""" - if should_raise and response.status >= 400: - try: - response_json = await response.json() - raise ResponseCodeError(response=response, response_json=response_json) - except aiohttp.ContentTypeError: - response_text = await response.text() - raise ResponseCodeError(response=response, response_text=response_text) - - async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Send an HTTP request to the site API and return the JSON response.""" - async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API GET.""" - return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API PATCH.""" - return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API POST.""" - return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API PUT.""" - return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]: - """Site API DELETE.""" - async with self.session.delete(self._url_for(endpoint), **kwargs) as resp: - if resp.status == 204: - return None - - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() diff --git a/bot/async_stats.py b/bot/async_stats.py deleted file mode 100644 index 0303de7a1..000000000 --- a/bot/async_stats.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import socket - -from botcore.utils import scheduling -from statsd.client.base import StatsClientBase - - -class AsyncStatsClient(StatsClientBase): - """An async transport method for statsd communication.""" - - def __init__( - self, - loop: asyncio.AbstractEventLoop, - host: str = 'localhost', - port: int = 8125, - prefix: str = None - ): - """Create a new client.""" - family, _, _, _, addr = socket.getaddrinfo( - host, port, socket.AF_INET, socket.SOCK_DGRAM)[0] - self._addr = addr - self._prefix = prefix - self._loop = loop - self._transport = None - - async def create_socket(self) -> None: - """Use the loop.create_datagram_endpoint method to create a socket.""" - self._transport, _ = await self._loop.create_datagram_endpoint( - asyncio.DatagramProtocol, - family=socket.AF_INET, - remote_addr=self._addr - ) - - def _send(self, data: str) -> None: - """Start an async task to send data to statsd.""" - scheduling.create_task(self._async_send(data), event_loop=self._loop) - - async def _async_send(self, data: str) -> None: - """Send data to the statsd server using the async transport.""" - self._transport.sendto(data.encode('ascii'), self._addr) diff --git a/bot/converters.py b/bot/converters.py index e819e4713..a3f4630a0 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -8,13 +8,13 @@ from ssl import CertificateError import dateutil.parser import discord from aiohttp import ClientConnectorError +from botcore.site_api import ResponseCodeError from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter from discord.utils import escape_markdown, snowflake_time from bot import exts -from bot.api import ResponseCodeError from bot.constants import URLs from bot.errors import InvalidInfraction from bot.exts.info.doc import _inventory_parser diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index fabb2dbb5..5391a7f15 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,10 +1,10 @@ import difflib +from botcore.site_api import ResponseCodeError from discord import Embed from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from sentry_sdk import push_scope -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Colours, Icons, MODERATION_ROLES from bot.errors import InvalidInfractedUserError, LockedResourceError diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index 58aabc141..a5bf82397 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,11 +1,11 @@ from typing import Any, Dict +from botcore.site_api import ResponseCodeError from discord import Member, Role, User from discord.ext import commands from discord.ext.commands import Cog, Context from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.exts.backend.sync import _syncers from bot.log import get_logger diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 45301b098..e1c4541ef 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,12 +2,12 @@ import abc import typing as t from collections import namedtuple +from botcore.site_api import ResponseCodeError from discord import Guild from discord.ext.commands import Context from more_itertools import chunked import bot -from bot.api import ResponseCodeError from bot.log import get_logger from bot.utils.members import get_or_fetch_member diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py index 3e3f5c562..fc9cfbeca 100644 --- a/bot/exts/filters/filter_lists.py +++ b/bot/exts/filters/filter_lists.py @@ -1,11 +1,11 @@ import re from typing import Optional +from botcore.site_api import ResponseCodeError from discord import Colour, Embed from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels from bot.converters import ValidDiscordServerInvite, ValidFilterListType diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index cabb7f0b6..6982f5948 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -9,6 +9,7 @@ import dateutil.parser import regex import tldextract from async_rediscache import RedisCache +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta @@ -16,7 +17,6 @@ from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member from discord.ext.commands import Cog from discord.utils import escape_markdown -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Colours, Filter, Guild, Icons, URLs from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index ac172f2a8..d8111bdf5 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,12 +2,12 @@ import difflib from datetime import timedelta import arrow +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord import Colour, Embed from discord.ext.commands import Cog, Context, group, has_any_role from discord.utils import sleep_until -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES from bot.converters import OffTopicName diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 8c3038c5b..bbdc4e82a 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -10,11 +10,11 @@ from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp import discord +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from botcore.utils.scheduling import Scheduler from discord.ext import commands -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import Inventory, PackageName, ValidURL, allowed_strings diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index b56fd171a..e7d17c971 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,12 +6,12 @@ from textwrap import shorten from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union import rapidfuzz +from botcore.site_api import ResponseCodeError from discord import AllowedMentions, Colour, Embed, Guild, Message, Role from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role from discord.utils import escape_markdown from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.converters import MemberOrUser from bot.decorators import in_whitelist diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 137358ec3..9c73bde5f 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -6,11 +6,11 @@ from gettext import ngettext import arrow import dateutil.parser import discord +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord.ext.commands import Context from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Colours from bot.converters import MemberOrUser diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index c1be18362..3a2485ec2 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -3,10 +3,10 @@ from datetime import datetime import arrow import discord +from botcore.site_api import ResponseCodeError from discord.ext.commands import Context import bot -from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.converters import MemberOrUser from bot.errors import InvalidInfractedUserError diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index 33096e7e0..9b1621c01 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -5,10 +5,10 @@ from datetime import timedelta import arrow import discord from async_rediscache import RedisCache +from botcore.site_api import ResponseCodeError from discord import Colour, Member, VoiceState from discord.ext.commands import Cog, Context, command -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Roles, VoiceGate as GateConf from bot.decorators import has_no_roles, in_whitelist diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index ab5ce62f9..bc78b3934 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -7,11 +7,11 @@ from dataclasses import dataclass from typing import Any, Dict, Optional import discord +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord import Color, DMChannel, Embed, HTTPException, Message, errors from discord.ext.commands import Cog, Context -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.exts.filters.token_remover import TokenRemover diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 8aa124536..24496af54 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -5,11 +5,11 @@ from typing import Optional, Union import discord from async_rediscache import RedisCache +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles, STAFF_ROLES from bot.converters import MemberOrUser, UnambiguousMemberOrUser diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index d0edf5388..be181d005 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -9,12 +9,12 @@ from datetime import datetime, timedelta from typing import List, Optional, Union import arrow +from botcore.site_api import ResponseCodeError from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord import Embed, Emoji, Member, Message, NotFound, PartialMessage, TextChannel from discord.ext.commands import Context -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles from bot.log import get_logger diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 9dc46005b..a17c1fa10 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -1,7 +1,8 @@ import unittest from unittest import mock -from bot.api import ResponseCodeError +from botcore.site_api import ResponseCodeError + from bot.exts.backend.sync._syncers import Syncer from tests import helpers diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 7dff38f96..4ec36e39f 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -2,9 +2,9 @@ import unittest from unittest import mock import discord +from botcore.site_api import ResponseCodeError from bot import constants -from bot.api import ResponseCodeError from bot.exts.backend import sync from bot.exts.backend.sync._cog import Sync from bot.exts.backend.sync._syncers import Syncer diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 35fa0ee59..04a018289 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,9 +1,9 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch +from botcore.site_api import ResponseCodeError from discord.ext.commands import errors -from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError from bot.exts.backend.error_handler import ErrorHandler, setup from bot.exts.info.tags import Tags diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index ff81ddd65..5cf02033d 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,9 +3,9 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch +from botcore.site_api import ResponseCodeError from discord import Embed, Forbidden, HTTPException, NotFound -from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.exts.moderation.infraction import _utils as utils from tests.helpers import MockBot, MockContext, MockMember, MockUser diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py deleted file mode 100644 index 76bcb481d..000000000 --- a/tests/bot/test_api.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from bot import api - - -class APIClientTests(unittest.IsolatedAsyncioTestCase): - """Tests for the bot's API client.""" - - @classmethod - def setUpClass(cls): - """Sets up the shared fixtures for the tests.""" - cls.error_api_response = MagicMock() - cls.error_api_response.status = 999 - - def test_response_code_error_default_initialization(self): - """Test the default initialization of `ResponseCodeError` without `text` or `json`""" - error = api.ResponseCodeError(response=self.error_api_response) - - self.assertIs(error.status, self.error_api_response.status) - self.assertEqual(error.response_json, {}) - self.assertEqual(error.response_text, "") - self.assertIs(error.response, self.error_api_response) - - def test_response_code_error_string_representation_default_initialization(self): - """Test the string representation of `ResponseCodeError` initialized without text or json.""" - error = api.ResponseCodeError(response=self.error_api_response) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") - - def test_response_code_error_initialization_with_json(self): - """Test the initialization of `ResponseCodeError` with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data, - ) - self.assertEqual(error.response_json, json_data) - self.assertEqual(error.response_text, "") - - def test_response_code_error_string_representation_with_nonempty_response_json(self): - """Test the string representation of `ResponseCodeError` initialized with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") - - def test_response_code_error_initialization_with_text(self): - """Test the initialization of `ResponseCodeError` with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data, - ) - self.assertEqual(error.response_text, text_data) - self.assertEqual(error.response_json, {}) - - def test_response_code_error_string_representation_with_nonempty_response_text(self): - """Test the string representation of `ResponseCodeError` initialized with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") diff --git a/tests/helpers.py b/tests/helpers.py index 9d4988d23..3e6290e58 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,10 +9,10 @@ from typing import Iterable, Optional import discord from aiohttp import ClientSession +from botcore.async_stats import AsyncStatsClient +from botcore.site_api import APIClient from discord.ext.commands import Context -from bot.api import APIClient -from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module -- cgit v1.2.3 From 277bb011d8ae6b1c58c9de80d10b61791ee1fc49 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:42:58 +0100 Subject: Adding missing kwargs required by BotBase in test helper --- tests/helpers.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 3e6290e58..e6e95c20c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -312,6 +312,9 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop(), redis_session=unittest.mock.MagicMock(), + http_session=unittest.mock.MagicMock(), + allowed_roles=[1], + guild_id=1, ) additional_spec_asyncs = ("wait_for", "redis_ready") -- cgit v1.2.3 From 56e38eeb38de10611611f0f81f5cbc429d7f2bc8 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:44:53 +0100 Subject: Update test helpers with breaking d.py changes region was removed from the guild object, so this has been replaced with features add_cog is now async, so it is now an async_mock during tests Two new required voice_channel attrs were added channel.type is required to be set to ChannelType due to a new isinstance check in d.py --- tests/helpers.py | 4 ++++ tests/test_helpers.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index e6e95c20c..2f0c9b4ad 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -325,6 +325,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.api_client = MockAPIClient(loop=self.loop) self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) + self.add_cog = unittest.mock.AsyncMock() # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` @@ -337,6 +338,8 @@ channel_data = { 'position': 1, 'nsfw': False, 'last_message_id': 1, + 'bitrate': 1337, + 'user_limit': 25, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() @@ -441,6 +444,7 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() +channel.type = discord.ChannelType.text message_instance = discord.Message(state=state, channel=channel, data=message_data) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 81285e009..f3040b305 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -327,7 +327,7 @@ class MockObjectTests(unittest.TestCase): def test_spec_propagation_of_mock_subclasses(self): """Test if the `spec` does not propagate to attributes of the mock object.""" test_values = ( - (helpers.MockGuild, "region"), + (helpers.MockGuild, "features"), (helpers.MockRole, "mentionable"), (helpers.MockMember, "display_name"), (helpers.MockBot, "owner_id"), -- cgit v1.2.3 From 450c9ce5b9bb711681ce87508d5b33a0ad6aed52 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:47:13 +0100 Subject: Update tests to use new async cog setup function --- tests/bot/exts/backend/sync/test_cog.py | 6 +++--- tests/bot/exts/backend/test_error_handler.py | 8 ++++---- tests/bot/exts/events/test_code_jams.py | 8 ++++---- tests/bot/exts/filters/test_antimalware.py | 8 ++++---- tests/bot/exts/filters/test_security.py | 11 +++++------ tests/bot/exts/filters/test_token_remover.py | 8 ++++---- tests/bot/exts/utils/test_snekbox.py | 8 ++++---- 7 files changed, 28 insertions(+), 29 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4ec36e39f..ce620aa8d 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -16,11 +16,11 @@ class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @staticmethod - def test_extension_setup(): + async def test_extension_setup(): """The Sync cog should be added.""" bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() + await sync.setup(bot) + bot.add_cog.assert_awaited_once() class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 04a018289..193f1d822 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -544,11 +544,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): push_scope_mock.set_extra.has_calls(set_extra_calls) -class ErrorHandlerSetupTests(unittest.TestCase): +class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): """Tests for `ErrorHandler` `setup` function.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - setup(bot) - bot.add_cog.assert_called_once() + await setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index 0856546af..684f7abcd 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -160,11 +160,11 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase): member.add_roles.assert_not_awaited() -class CodeJamSetup(unittest.TestCase): +class CodeJamSetup(unittest.IsolatedAsyncioTestCase): """Test for `setup` function of `CodeJam` cog.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog`.""" bot = MockBot() - code_jams.setup(bot) - bot.add_cog.assert_called_once() + await code_jams.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 06d78de9d..7282334e2 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -192,11 +192,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) -class AntiMalwareSetupTests(unittest.TestCase): +class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `AntiMalware` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() + await antimalware.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index c0c3baa42..007b7b1eb 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,5 +1,4 @@ import unittest -from unittest.mock import MagicMock from discord.ext.commands import NoPrivateMessage @@ -44,11 +43,11 @@ class SecurityCogTests(unittest.TestCase): self.assertTrue(self.cog.check_on_guild(self.ctx)) -class SecurityCogLoadTests(unittest.TestCase): +class SecurityCogLoadTests(unittest.IsolatedAsyncioTestCase): """Tests loading the `Security` cog.""" - def test_security_cog_load(self): + async def test_security_cog_load(self): """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() + bot = MockBot() + await security.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index 4db27269a..c1f3762ac 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -395,15 +395,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.channel.send.assert_not_awaited() -class TokenRemoverExtensionTests(unittest.TestCase): +class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the token_remover extension.""" @autospec("bot.exts.filters.token_remover", "TokenRemover") - def test_extension_setup(self, cog): + async def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() - token_remover.setup(bot) + await token_remover.setup(bot) cog.assert_called_once_with(bot) - bot.add_cog.assert_called_once() + bot.add_cog.assert_awaited_once() self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index f68a20089..3c555c051 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -403,11 +403,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(actual, expected) -class SnekboxSetupTests(unittest.TestCase): +class SnekboxSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `Snekbox` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - snekbox.setup(bot) - bot.add_cog.assert_called_once() + await snekbox.setup(bot) + bot.add_cog.assert_awaited_once() -- cgit v1.2.3 From 4db508e7abdc9d5835940f15d5faecdd2f045d0a Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:49:36 +0100 Subject: Update tests to use new async cog_load function --- tests/bot/exts/backend/sync/test_cog.py | 2 +- tests/bot/exts/moderation/test_silence.py | 40 +++++++++++++------------------ 2 files changed, 17 insertions(+), 25 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index ce620aa8d..4e9941d88 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -87,7 +87,7 @@ class SyncCogTests(SyncCogTestCase): self.bot.get_guild = mock.MagicMock(return_value=guild) - await self.cog.sync_guild() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_called_once() self.bot.get_guild.assert_called_once_with(constants.Guild.id) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 92ce3418a..2ebb16978 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -114,44 +114,36 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.cog = silence.Silence(self.bot) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_guild(self): + async def testcog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog._async_init() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_channels(self): + async def testcog_load_got_channels(self): """Got channels from bot.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog._async_init() + await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") - async def test_async_init_got_notifier(self, notifier): + async def testcog_load_got_notifier(self, notifier): """Notifier was started with channel.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog._async_init() + await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_rescheduled(self): + async def testcog_load_rescheduled(self): """`_reschedule_` coroutine was awaited.""" self.cog._reschedule = mock.create_autospec(self.cog._reschedule) - await self.cog._async_init() + await self.cog.cog_load() self.cog._reschedule.assert_awaited_once_with() - def test_cog_unload_cancelled_tasks(self): - """The init task was cancelled.""" - self.cog._init_task = asyncio.Future() - self.cog.cog_unload() - - # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. - self.assertTrue(self.cog._init_task.cancelled()) - @autospec("discord.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): @@ -165,7 +157,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync(self): """Tests the _force_voice_sync helper function.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -187,7 +179,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync_no_channel(self): """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" - await self.cog._async_init() + await self.cog.cog_load() channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) new_channel = MockVoiceChannel(delete=AsyncMock()) @@ -206,7 +198,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_voice_kick(self): """Test to ensure kick function can remove all members from a voice channel.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -236,7 +228,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_kick_move_to_error(self): """Test to ensure move_to gets called on all members during kick, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() _, members = self.create_erroneous_members() await self.cog._kick_voice_members(MockVoiceChannel(members=members)) @@ -245,7 +237,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_sync_move_to_error(self): """Test to ensure move_to gets called on all members during sync, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() failing_member, members = self.create_erroneous_members() await self.cog._force_voice_sync(MockVoiceChannel(members=members)) @@ -339,7 +331,7 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase): self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -428,7 +420,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( @@ -701,7 +693,7 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. self.cog.scheduler.__contains__.return_value = True overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' -- cgit v1.2.3 From 0afe07d0734e50854d7abd8685086e79858fcadf Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:50:33 +0100 Subject: Remove sync cog init test Discord.py now implicitly calls the new async cog_load function from within it's internals on load. There is no longer a need to test that this happens. --- tests/bot/exts/backend/sync/test_cog.py | 17 ----------------- 1 file changed, 17 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4e9941d88..28afeebeb 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,23 +60,6 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("botcore.utils.scheduling.create_task") - @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild, create_task): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. - self.RoleSyncer.reset_mock() - self.UserSyncer.reset_mock() - - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - Sync(self.bot) - - sync_guild.assert_called_once_with() - create_task.assert_called_once() - self.assertEqual(create_task.call_args.args[0], mock_sync_guild_coro) - async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" for guild in (helpers.MockGuild(), None): -- cgit v1.2.3 From 4a9e2819929908182f8c6a148502671f281357ca Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:51:01 +0100 Subject: Don't try to overwrite a read-only attr in help command test --- tests/bot/exts/info/test_help.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py index 604c69671..21d124f3a 100644 --- a/tests/bot/exts/info/test_help.py +++ b/tests/bot/exts/info/test_help.py @@ -1,4 +1,5 @@ import unittest +import unittest.mock import rapidfuzz @@ -12,7 +13,6 @@ class HelpCogTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = help.Help(self.bot) self.ctx = MockContext(bot=self.bot) - self.bot.help_command.context = self.ctx @autospec(help.CustomHelpCommand, "get_all_help_choices", return_value={"help"}, pass_mocks=False) async def test_help_fuzzy_matching(self): -- cgit v1.2.3 From 87edeff7347f011c2317cd4b3681bea5bbe07185 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 18 Apr 2022 17:39:20 +0100 Subject: Test that sync cog syncers run when sync cog is loaded --- tests/bot/exts/backend/sync/test_cog.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 28afeebeb..87b76c6b4 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,6 +60,19 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" + async def test_sync_cog_sync_on_load(self): + """Roles and users should be synced on cog load.""" + guild = helpers.MockGuild() + self.bot.get_guild = mock.MagicMock(return_value=guild) + + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + + await self.cog.cog_load() + + self.RoleSyncer.sync.assert_called_once_with(guild) + self.UserSyncer.sync.assert_called_once_with(guild) + async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" for guild in (helpers.MockGuild(), None): -- cgit v1.2.3 From 267f6d94cb10a45f873ea1d0e22f812267be5e69 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 18 Apr 2022 17:40:41 +0100 Subject: Add missing underscores to test function names --- tests/bot/exts/moderation/test_silence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 2ebb16978..65aecad28 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -114,14 +114,14 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.cog = silence.Silence(self.bot) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def testcog_load_got_guild(self): + async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def testcog_load_got_channels(self): + async def test_cog_load_got_channels(self): """Got channels from bot.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) @@ -129,7 +129,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") - async def testcog_load_got_notifier(self, notifier): + async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -- cgit v1.2.3 From 4c1a076dd0b7c021ac1b352589bb353139f86a6f Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 19 Apr 2022 17:23:28 +0100 Subject: Pass the now required intents kwarg when creating MockBot --- tests/helpers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 2f0c9b4ad..a6e4bdd66 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -315,6 +315,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): http_session=unittest.mock.MagicMock(), allowed_roles=[1], guild_id=1, + intents=discord.Intents.all(), ) additional_spec_asyncs = ("wait_for", "redis_ready") -- cgit v1.2.3 From 03a8c8138c53b6582a0921b3dcf7a1b7d55877de Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Wed, 20 Apr 2022 21:06:13 +0100 Subject: remove unneeded import in tests --- tests/bot/exts/info/test_help.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py index 21d124f3a..2644ae40d 100644 --- a/tests/bot/exts/info/test_help.py +++ b/tests/bot/exts/info/test_help.py @@ -1,5 +1,4 @@ import unittest -import unittest.mock import rapidfuzz -- cgit v1.2.3 From 8f9f25a796f6cc07f01f8f7f56e825cb5ebf56c8 Mon Sep 17 00:00:00 2001 From: Hassan Abouelela Date: Sat, 23 Apr 2022 14:24:20 +0400 Subject: Speed Up Sync Cog Loading The user syncer was blocking the startup of the sync cog due to having to perform thousands of pointless member fetch requests. This speeds up that process by increasing the probability that the cache is up-to-date using `Guild.chunked`, and limiting the fetches to members who were in the guild during the previous sync only. Co-authored-by: ChrisJL Co-authored-by: wookie184 Signed-off-by: Hassan Abouelela --- bot/exts/backend/sync/_cog.py | 6 ++++++ bot/exts/backend/sync/_syncers.py | 13 +++++++++++-- tests/helpers.py | 2 +- 3 files changed, 18 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index a5bf82397..4ec822d3f 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict from botcore.site_api import ResponseCodeError @@ -27,6 +28,11 @@ class Sync(Cog): if guild is None: return + log.info("Waiting for guild to be chunked to start syncers.") + while not guild.chunked: + await asyncio.sleep(10) + log.info("Starting syncers.") + for syncer in (_syncers.RoleSyncer, _syncers.UserSyncer): await syncer.sync(guild) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index e1c4541ef..799137cb9 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,6 +2,7 @@ import abc import typing as t from collections import namedtuple +import discord.errors from botcore.site_api import ResponseCodeError from discord import Guild from discord.ext.commands import Context @@ -9,7 +10,6 @@ from more_itertools import chunked import bot from bot.log import get_logger -from bot.utils.members import get_or_fetch_member log = get_logger(__name__) @@ -157,7 +157,16 @@ class UserSyncer(Syncer): if db_user[db_field] != guild_value: updated_fields[db_field] = guild_value - if guild_user := await get_or_fetch_member(guild, db_user["id"]): + guild_user = guild.get_member(db_user["id"]) + if not guild_user and db_user["in_guild"]: + # The member was in the guild during the last sync. + # We try to fetch them to verify cache integrity. + try: + guild_user = await guild.fetch_member(db_user["id"]) + except discord.errors.NotFound: + guild_user = None + + if guild_user: seen_guild_users.add(guild_user.id) maybe_update("name", guild_user.name) diff --git a/tests/helpers.py b/tests/helpers.py index a6e4bdd66..5f3111616 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -171,7 +171,7 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): spec_set = guild_instance def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'members': []} + default_kwargs = {'id': next(self.discord_id), 'members': [], "chunked": True} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] -- cgit v1.2.3 From e4ece6e598803b26770eb46e39024ccb34367705 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 2 May 2022 16:16:05 +0100 Subject: Fix tests --- tests/bot/exts/utils/test_snekbox.py | 6 +++--- tests/bot/utils/test_services.py | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 3c555c051..b870a9945 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -35,15 +35,15 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): resp.json.assert_awaited_once() async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + """Reject output longer than MAX_PASTE_LENGTH.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) self.assertEqual(result, "too long to upload") @patch("bot.exts.utils.snekbox.send_to_paste_service") async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with("Test output.", extension="txt") + mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) async def test_codeblock_converter(self): ctx = MockContext() diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 3b71022db..de166c813 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch from aiohttp import ClientConnectorError -from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from bot.utils.services import ( + FAILED_REQUEST_ATTEMPTS, MAX_PASTE_LENGTH, PasteTooLongError, PasteUploadError, send_to_paste_service +) from tests.helpers import MockBot @@ -55,23 +57,28 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): for error_json in test_cases: with self.subTest(error_json=error_json): response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) self.bot.http_session.post.reset_mock() async def test_request_repeated_on_connection_errors(self): """Requests are repeated in the case of connection errors.""" self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) async def test_general_error_handled_and_request_repeated(self): """All `Exception`s are handled, logged and request repeated.""" self.bot.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertLogs("bot.utils", logging.ERROR) - self.assertIsNone(result) + + async def test_raises_error_on_too_long_input(self): + contents = "a" * (MAX_PASTE_LENGTH+1) + with self.assertRaises(PasteTooLongError): + await send_to_paste_service(contents) -- cgit v1.2.3 From 890d2578d3d3d65814edfb350e981ce2fbe2957e Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sat, 21 May 2022 17:40:01 +0100 Subject: Bump malformed API response from debug to error log (#2175) --- bot/exts/backend/error_handler.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 5391a7f15..35dddd8dc 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -285,7 +285,7 @@ class ErrorHandler(Cog): ctx.bot.stats.incr("errors.api_error_404") elif e.status == 400: content = await e.response.json() - log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + log.error(f"API responded with 400 for command {ctx.command}: %r.", content) await ctx.send("According to the API, your request is malformed.") ctx.bot.stats.incr("errors.api_error_400") elif 500 <= e.status < 600: diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 193f1d822..d02bd7c34 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -477,11 +477,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.backend.error_handler.log") async def test_handle_api_error(self, log_mock): - """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" + """Should `ctx.send` on HTTP error codes, and log at correct level.""" test_cases = ( { "error": ResponseCodeError(AsyncMock(status=400)), - "log_level": "debug" + "log_level": "error" }, { "error": ResponseCodeError(AsyncMock(status=404)), @@ -505,6 +505,8 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_awaited_once() if case["log_level"] == "warning": log_mock.warning.assert_called_once() + elif case["log_level"] == "error": + log_mock.error.assert_called_once() else: log_mock.debug.assert_called_once() -- cgit v1.2.3 From 96c7deab22f5018a17b97cd68e3914c37f926be5 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sat, 28 May 2022 16:46:42 +0100 Subject: Fix tests --- tests/bot/exts/backend/test_error_handler.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index d02bd7c34..0a58126e7 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -48,6 +48,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): cog = ErrorHandler(self.bot) cog.try_silence = AsyncMock() cog.try_get_tag = AsyncMock() + cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): @@ -76,6 +77,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): cog = ErrorHandler(self.bot) cog.try_silence = AsyncMock() cog.try_get_tag = AsyncMock() + cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() @@ -83,6 +85,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): cog.try_silence.assert_not_awaited() cog.try_get_tag.assert_not_awaited() + cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): -- cgit v1.2.3 From 381dfc1e7fd1849c3381971e9332f2fec20c7a7e Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sun, 29 May 2022 15:56:46 +0100 Subject: Raise ValueError if max_length greater than allowed by paste service --- bot/utils/services.py | 9 ++++++--- tests/bot/utils/test_services.py | 5 +++++ 2 files changed, 11 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/utils/services.py b/bot/utils/services.py index 3a6833e72..82c7d284c 100644 --- a/bot/utils/services.py +++ b/bot/utils/services.py @@ -25,17 +25,20 @@ async def send_to_paste_service(contents: str, *, extension: str = "", max_lengt `extension` is added to the output URL. `max_length` can be used to limit the allowed contents length to lower than the maximum allowed by the paste service. + Raises `ValueError` if `max_length` is greater than the maximum allowed by the paste service. Raises `PasteTooLongError` if contents is too long to upload, and `PasteUploadError` if uploading fails. Returns the generated URL with the extension. """ + if max_length > MAX_PASTE_LENGTH: + raise ValueError(f"`max_length` must not be greater than {MAX_PASTE_LENGTH}") + extension = extension and f".{extension}" - max_size = min(max_length, MAX_PASTE_LENGTH) contents_size = len(contents.encode()) - if contents_size > max_size: + if contents_size > max_length: log.info("Contents too large to send to paste service.") - raise PasteTooLongError(f"Contents of size {contents_size} greater than maximum size {max_size}") + raise PasteTooLongError(f"Contents of size {contents_size} greater than maximum size {max_length}") log.debug(f"Sending contents of size {contents_size} bytes to paste service.") paste_url = URLs.paste_service.format(key="documents") diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index de166c813..e6b95be8e 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -82,3 +82,8 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): contents = "a" * (MAX_PASTE_LENGTH+1) with self.assertRaises(PasteTooLongError): await send_to_paste_service(contents) + + async def test_raises_on_too_large_max_length(self): + """Ensure ValueError is raised if `max_length` passed is greater than `MAX_PASTE_LENGTH`.""" + with self.assertRaises(ValueError): + await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH+1) -- cgit v1.2.3 From 3fe10aa1e07fcdd8a2efb4a00118b7ec0fcdedce Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sun, 29 May 2022 16:08:45 +0100 Subject: Make small wording and style changes --- bot/utils/services.py | 10 +++++----- tests/bot/utils/test_services.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/bot/utils/services.py b/bot/utils/services.py index 82c7d284c..a752ac0ec 100644 --- a/bot/utils/services.py +++ b/bot/utils/services.py @@ -22,13 +22,13 @@ async def send_to_paste_service(contents: str, *, extension: str = "", max_lengt """ Upload `contents` to the paste service. - `extension` is added to the output URL. `max_length` can be used to limit the allowed contents length + Add `extension` to the output URL. Use `max_length` to limit the allowed contents length to lower than the maximum allowed by the paste service. - Raises `ValueError` if `max_length` is greater than the maximum allowed by the paste service. - Raises `PasteTooLongError` if contents is too long to upload, and `PasteUploadError` if uploading fails. + Raise `ValueError` if `max_length` is greater than the maximum allowed by the paste service. + Raise `PasteTooLongError` if `contents` is too long to upload, and `PasteUploadError` if uploading fails. - Returns the generated URL with the extension. + Return the generated URL with the extension. """ if max_length > MAX_PASTE_LENGTH: raise ValueError(f"`max_length` must not be greater than {MAX_PASTE_LENGTH}") @@ -80,4 +80,4 @@ async def send_to_paste_service(contents: str, *, extension: str = "", max_lengt f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})." ) - raise PasteUploadError("Failed to upload contents to pastebin") + raise PasteUploadError("Failed to upload contents to paste service") diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index e6b95be8e..d0e801299 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -79,11 +79,12 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): self.assertLogs("bot.utils", logging.ERROR) async def test_raises_error_on_too_long_input(self): - contents = "a" * (MAX_PASTE_LENGTH+1) + """Ensure PasteTooLongError is raised if `contents` is longer than `MAX_PASTE_LENGTH`.""" + contents = "a" * (MAX_PASTE_LENGTH + 1) with self.assertRaises(PasteTooLongError): await send_to_paste_service(contents) async def test_raises_on_too_large_max_length(self): """Ensure ValueError is raised if `max_length` passed is greater than `MAX_PASTE_LENGTH`.""" with self.assertRaises(ValueError): - await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH+1) + await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH + 1) -- cgit v1.2.3 From ac5bb147f5dc6b80c20c4c9abd5cc1fb7ac4abfa Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 28 Jun 2022 13:17:39 +0100 Subject: Use new application format for message data in test helper --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 5f3111616..17214553c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -432,7 +432,7 @@ message_data = { 'webhook_id': 431341013479718912, 'attachments': [], 'embeds': [], - 'application': 'Python Discord', + 'application': {"id": 4, "description": "A Python Bot", "name": "Python Discord", "icon": None}, 'activity': 'mocking', 'channel': unittest.mock.MagicMock(), 'edited_timestamp': '2019-10-14T15:33:48+00:00', -- cgit v1.2.3 From b7e03616ac3fc0b5e8a5a77a352df593983d187a Mon Sep 17 00:00:00 2001 From: Izan Date: Thu, 14 Jul 2022 22:21:34 +0100 Subject: Address Reviews - Use the more concise DATETIME timestamp instead of both a DATE and a TIME timestamp. - Remove underline from the "Reported ..." section at the bottom of the embed. - Re-add time of action/rejection timestamp to footer of embed. --- bot/exts/moderation/incidents.py | 7 ++++--- tests/bot/exts/moderation/test_incidents.py | 5 ++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index bd9e5b88e..f29cfcdd6 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -1,5 +1,6 @@ import asyncio import re +from datetime import datetime, timezone from enum import Enum from typing import Optional @@ -97,10 +98,9 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di colour = Colours.soft_red footer = f"Rejected by {actioned_by}" - day_timestamp = discord_timestamp(incident.created_at, TimestampFormats.DATE) - time_timestamp = discord_timestamp(incident.created_at, TimestampFormats.TIME) + reported_timestamp = discord_timestamp(incident.created_at) relative_timestamp = discord_timestamp(incident.created_at, TimestampFormats.RELATIVE) - reported_on_msg = f"__*Reported {day_timestamp} at {time_timestamp} ({relative_timestamp}).*__" + reported_on_msg = f"*Reported {reported_timestamp} ({relative_timestamp}).*" # If the description will be too long (>4096 total characters), truncate the incident content if len(incident.content) > (allowed_content_chars := 4096-len(reported_on_msg)-2): # -2 for the newlines @@ -111,6 +111,7 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di embed = discord.Embed( description=description, colour=colour, + timestamp=datetime.now(timezone.utc) ) embed.set_footer(text=footer, icon_url=actioned_by.display_avatar.url) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index ef33aa62b..da0a79ce8 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -119,14 +119,13 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): current_time = datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) incident = MockMessage(content="this is an incident", created_at=current_time) - day_timestamp = discord_timestamp(current_time, TimestampFormats.DATE) - time_timestamp = discord_timestamp(current_time, TimestampFormats.TIME) + reported_timestamp = discord_timestamp(current_time) relative_timestamp = discord_timestamp(current_time, TimestampFormats.RELATIVE) embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) self.assertEqual( - f"{incident.content}\n\n__*Reported {day_timestamp} at {time_timestamp} ({relative_timestamp}).*__", + f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*", embed.description ) -- cgit v1.2.3 From 32d44d10635d9d721c5bbce22400449052853aa5 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Wed, 13 Jul 2022 22:44:26 +0100 Subject: Update snekbox tests to reflect current behaviour --- tests/bot/exts/utils/test_snekbox.py | 60 ++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 23 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index b870a9945..2fff20fd9 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -6,9 +6,10 @@ from discord import AllowedMentions from discord.ext import commands from bot import constants +from bot.errors import LockedResourceError from bot.exts.utils import snekbox from bot.exts.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser +from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser class SnekboxTests(unittest.IsolatedAsyncioTestCase): @@ -26,7 +27,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_job("import random"), "return") + self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -179,9 +180,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.send_job = AsyncMock(return_value=response) self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) - self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval') - self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command) + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" @@ -192,23 +193,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.continue_job = AsyncMock() self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) self.cog.send_job.assert_called_with( - ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' + ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' ) - self.cog.continue_job.assert_called_with(ctx, response, ctx.command) + self.cog.continue_job.assert_called_with(ctx, response, 'eval') async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" ctx = MockContext() ctx.author.id = 42 - ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.send = AsyncMock() - self.cog.jobs = (42,) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - ctx.send.assert_called_once_with( - "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" - ) + + async def delay_with_side_effect(*args, **kwargs) -> dict: + """Delay the post_job call to ensure the job runs long enough to conflict.""" + await asyncio.sleep(1) + return {'stdout': '', 'returncode': 0} + + self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) + with self.assertRaises(LockedResourceError): + await asyncio.gather( + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + ) async def test_send_job(self): """Test the send_job function.""" @@ -226,7 +232,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -237,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') @@ -258,7 +264,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -267,7 +273,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') @@ -287,7 +293,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -295,7 +301,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @@ -303,9 +309,17 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.utils.snekbox.partial") async def test_continue_job_does_continue(self, partial_mock): """Test that the continue_job function does continue if required conditions are met.""" - ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) - response = MockMessage(delete=AsyncMock()) + ctx = MockContext( + message=MockMessage( + id=4, + add_reaction=AsyncMock(), + clear_reactions=AsyncMock() + ), + author=MockMember(id=14) + ) + response = MockMessage(id=42, delete=AsyncMock()) new_msg = MockMessage() + self.cog.jobs = {4: 42} self.bot.wait_for.side_effect = ((None, new_msg), None) expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) -- cgit v1.2.3 From 28d91d4ab8f560593b2e5fad728e5a74c62e9e4c Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sun, 17 Jul 2022 13:16:36 +0100 Subject: Update snekbox tests to expect new output --- tests/bot/exts/utils/test_snekbox.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 2fff20fd9..b1f32c210 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -85,28 +85,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval') + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), - ('Your eval job has completed with return code 127', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127', '') ) @patch('bot.exts.utils.snekbox.Signals') def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), - ('Your eval job has completed with return code 127 (SIGTEST)', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') ) def test_get_status_emoji(self): @@ -245,7 +245,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') self.cog.format_output.assert_called_once_with('') async def test_send_job_with_paste_link(self): @@ -275,7 +275,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') + self.cog.get_results_message.assert_called_once_with( + {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' + ) self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_job_with_non_zero_eval(self): @@ -303,7 +305,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") -- cgit v1.2.3 From f599c7bb945a4d0e26ff3e9f5f234f3f34f5ff16 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:44:58 +0100 Subject: Remove call to get_event_loop in tests get_event_loop is deprecated as of 3.10 if there is no running loop. --- tests/bot/exts/moderation/test_silence.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..82ec138db 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -16,20 +16,19 @@ from tests.helpers import ( ) redis_session = None -redis_loop = asyncio.get_event_loop() def setUpModule(): # noqa: N802 """Create and connect to the fakeredis session.""" global redis_session redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) + asyncio.run(redis_session.connect()) def tearDownModule(): # noqa: N802 """Close the fakeredis session.""" if redis_session: - redis_loop.run_until_complete(redis_session.close()) + asyncio.run(redis_session.client.close()) # Have to subclass it because builtins can't be patched. -- cgit v1.2.3 From c906daa2250558962f00be1c423a9a0cff98f905 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:46:18 +0100 Subject: Remove warnings in error handler tests These warnings were caused by the setup coro from error_handler.py being imported directly, causing a warning about an un-awaited coro whenever the Cog was accessed from the same file. --- tests/bot/exts/backend/test_error_handler.py | 103 ++++++++++++--------------- 1 file changed, 47 insertions(+), 56 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError from discord.ext.commands import errors from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock(return_value=False) + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() - cog.try_run_eval.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - await setup(bot) + await error_handler.setup(bot) bot.add_cog.assert_awaited_once() -- cgit v1.2.3 From 7782c196830098f81f39d235354636cd0d4a481d Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:52:52 +0100 Subject: No longer use the removed RedisSession connection object This has been abstracted away, the correct way to do this now is to directly access the client. --- bot/exts/info/doc/_redis_cache.py | 92 ++++++++++++++--------------- tests/bot/exts/moderation/test_incidents.py | 5 +- 2 files changed, 46 insertions(+), 51 deletions(-) (limited to 'tests') diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py index 8e08e7ae4..0f4d663d1 100644 --- a/bot/exts/info/doc/_redis_cache.py +++ b/bot/exts/info/doc/_redis_cache.py @@ -34,55 +34,52 @@ class DocRedisCache(RedisObject): redis_key = f"{self.namespace}:{item_key(item)}" needs_expire = False - with await self._get_pool_connection() as connection: - set_expire = self._set_expires.get(redis_key) - if set_expire is None: - # An expire is only set if the key didn't exist before. - ttl = await connection.ttl(redis_key) - log.debug(f"Checked TTL for `{redis_key}`.") - - if ttl == -1: - log.warning(f"Key `{redis_key}` had no expire set.") - if ttl < 0: # not set or didn't exist - needs_expire = True - else: - log.debug(f"Key `{redis_key}` has a {ttl} TTL.") - self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - - elif time.monotonic() > set_expire: - # If we got here the key expired in redis and we can be sure it doesn't exist. + set_expire = self._set_expires.get(redis_key) + if set_expire is None: + # An expire is only set if the key didn't exist before. + ttl = await self.redis_session.client.ttl(redis_key) + log.debug(f"Checked TTL for `{redis_key}`.") + + if ttl == -1: + log.warning(f"Key `{redis_key}` had no expire set.") + if ttl < 0: # not set or didn't exist needs_expire = True - log.debug(f"Key `{redis_key}` expired in internal key cache.") + else: + log.debug(f"Key `{redis_key}` has a {ttl} TTL.") + self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - await connection.hset(redis_key, item.symbol_id, value) - if needs_expire: - self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS - await connection.expire(redis_key, WEEK_SECONDS) - log.info(f"Set {redis_key} to expire in a week.") + elif time.monotonic() > set_expire: + # If we got here the key expired in redis and we can be sure it doesn't exist. + needs_expire = True + log.debug(f"Key `{redis_key}` expired in internal key cache.") + + await self.redis_session.client.hset(redis_key, item.symbol_id, value) + if needs_expire: + self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS + await self.redis_session.client.expire(redis_key, WEEK_SECONDS) + log.info(f"Set {redis_key} to expire in a week.") @namespace_lock async def get(self, item: DocItem) -> Optional[str]: """Return the Markdown content of the symbol `item` if it exists.""" - with await self._get_pool_connection() as connection: - return await connection.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id, encoding="utf8") + return await self.redis_session.client.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" pattern = f"{self.namespace}:{package}:*" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=pattern) - ] - if package_keys: - await connection.delete(*package_keys) - log.info(f"Deleted keys from redis: {package_keys}.") - self._set_expires = { - key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) - } - return True - return False + package_keys = [ + package_key async for package_key in self.redis_session.client.iscan(match=pattern) + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + log.info(f"Deleted keys from redis: {package_keys}.") + self._set_expires = { + key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) + } + return True + return False class StaleItemCounter(RedisObject): @@ -96,21 +93,20 @@ class StaleItemCounter(RedisObject): If the counter didn't exist, initialize it with 1. """ key = f"{self.namespace}:{item_key(item)}:{item.symbol_id}" - with await self._get_pool_connection() as connection: - await connection.expire(key, WEEK_SECONDS * 3) - return int(await connection.incr(key)) + await self.redis_session.client.expire(key, WEEK_SECONDS * 3) + return int(await self.redis_session.client.incr(key)) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*") - ] - if package_keys: - await connection.delete(*package_keys) - return True - return False + package_keys = [ + package_key + async for package_key in self.redis_session.client.iscan(match=f"{self.namespace}:{package}:*") + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + return True + return False def item_key(item: DocItem) -> str: diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..f60c177c5 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -283,8 +283,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def flush(self): """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() + await self.session.client.flushall() async def asyncSetUp(self): # noqa: N802 self.session = RedisSession(use_fakeredis=True) @@ -293,7 +292,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def asyncTearDown(self): # noqa: N802 if self.session: - await self.session.close() + await self.session.client.close() def setUp(self): """ -- cgit v1.2.3 From 46da1ecf621a64e6d8f0a37572378ae363ba76a2 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 25 Jul 2022 23:24:25 +0100 Subject: Stop creating futures in tests with no event loop running --- tests/bot/exts/moderation/test_silence.py | 6 ------ 1 file changed, 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 82ec138db..03b7b2fdb 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -250,8 +250,6 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @@ -413,8 +411,6 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() @@ -686,8 +682,6 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache -- cgit v1.2.3 From 9cf3de3e9bf6725b2baa2e7adb77e058c216b332 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 26 Jul 2022 00:56:12 +0100 Subject: Remove unneeded N802 noqas pep-naming now supports these functions being in camel case. --- tests/bot/exts/moderation/test_incidents.py | 6 +++--- tests/bot/exts/moderation/test_silence.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index f60c177c5..211eb1bf8 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -285,12 +285,12 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): """Flush everything from the database to prevent carry-overs between tests.""" await self.session.client.flushall() - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): self.session = RedisSession(use_fakeredis=True) await self.session.connect() await self.flush() - async def asyncTearDown(self): # noqa: N802 + async def asyncTearDown(self): if self.session: await self.session.client.close() @@ -655,7 +655,7 @@ class TestOnRawReactionAdd(TestIncidents): emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 03b7b2fdb..f5caefdca 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -18,14 +18,14 @@ from tests.helpers import ( redis_session = None -def setUpModule(): # noqa: N802 +def setUpModule(): """Create and connect to the fakeredis session.""" global redis_session redis_session = RedisSession(use_fakeredis=True) asyncio.run(redis_session.connect()) -def tearDownModule(): # noqa: N802 +def tearDownModule(): """Close the fakeredis session.""" if redis_session: asyncio.run(redis_session.client.close()) -- cgit v1.2.3 From 4a47c816641332fbb49f8c88c8a7720849cabf06 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 30 Jul 2022 19:28:36 +0100 Subject: Add a new test helper for managing redis sessions This helper ensures that a fresh RedisSession is given to each test case that inherits from it. --- tests/base.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Dict import discord +from async_rediscache import RedisSession from discord.ext import commands from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() -- cgit v1.2.3 From f044f36833e9dc003e89dd81868ea3f48a9da002 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 30 Jul 2022 19:29:27 +0100 Subject: Use RedisTestCase helper class for both Incidents and Silence test cases. --- tests/bot/exts/moderation/test_incidents.py | 19 ++----------------- tests/bot/exts/moderation/test_silence.py | 23 ++++------------------- 2 files changed, 6 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 211eb1bf8..97682163f 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp import discord -from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -279,21 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - await self.session.client.flushall() - - async def asyncSetUp(self): - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): - if self.session: - await self.session.client.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index f5caefdca..98547e2bc 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -6,30 +6,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None - - -def setUpModule(): - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - asyncio.run(redis_session.connect()) - - -def tearDownModule(): - """Close the fakeredis session.""" - if redis_session: - asyncio.run(redis_session.client.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -104,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(RedisTestCase): """Tests for the general functionality of the Silence cog.""" @autospec(silence, "Scheduler", pass_mocks=False) @@ -244,7 +229,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(RedisTestCase): """Tests for the silence argument parser utility function.""" def setUp(self): @@ -403,7 +388,7 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(RedisTestCase): """Tests for the silence command and its related helper methods.""" @autospec(silence.Silence, "_reschedule", pass_mocks=False) -- cgit v1.2.3 From 1df6034a1723fd3ff1bd88047ca6a62f920767e6 Mon Sep 17 00:00:00 2001 From: Izan Date: Mon, 15 Aug 2022 12:04:33 +0100 Subject: Fix incident tests. --- tests/bot/exts/moderation/test_incidents.py | 38 +++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 11fe565fc..53d98360c 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -20,6 +20,8 @@ from tests.helpers import ( MockUser ) +CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) + class MockAsyncIterable: """ @@ -102,25 +104,32 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_actioned(self): """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_green) self.assertIn("Actioned", embed.footer.text) async def test_make_embed_not_actioned(self): """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.NOT_ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_red) self.assertIn("Rejected", embed.footer.text) async def test_make_embed_content(self): """Incident content appears as embed description.""" - current_time = datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) - incident = MockMessage(content="this is an incident", created_at=current_time) + incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME) - reported_timestamp = discord_timestamp(current_time) - relative_timestamp = discord_timestamp(current_time, TimestampFormats.RELATIVE) + reported_timestamp = discord_timestamp(CURRENT_TIME) + relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE) embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) @@ -133,7 +142,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): """Incident's attachment is downloaded and displayed in the embed's image field.""" file = MagicMock(discord.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return our `file` with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): @@ -145,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_fails(self): """Incident's attachment fails to download, proxy url is linked instead.""" attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return None as if the download failed with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): @@ -359,7 +368,6 @@ class TestCrawlIncidents(TestIncidents): class TestArchive(TestIncidents): """Tests for the `Incidents.archive` coroutine.""" - async def test_archive_webhook_not_found(self): """ Method recovers and returns False when the webhook is not found. @@ -369,7 +377,11 @@ class TestArchive(TestIncidents): """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) self.assertFalse( - await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + await self.cog_instance.archive( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=MagicMock(), + actioned_by=MockMember() + ) ) async def test_archive_relays_incident(self): @@ -416,7 +428,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great")) + message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME) await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -520,7 +532,7 @@ class TestProcessEvent(TestIncidents): with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(author=mock_member, id=123), + incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME), member=mock_member ) @@ -540,7 +552,7 @@ class TestProcessEvent(TestIncidents): with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), + incident=MockMessage(id=123, created_at=CURRENT_TIME), member=MockMember(roles=[MockRole(id=1)]) ) except asyncio.TimeoutError: -- cgit v1.2.3 From e0b593318eba77d6fe93f2145b43838d6eb09278 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 15 Aug 2022 22:14:32 +0100 Subject: Correctly initialise redis tests Calling the cog_load from within the setUp function resulted in interaction with a RedisSession before it was initialised. This wasn't noticed in CI as it only error under certain concurrency timings due to xdist. To resolve this, we moved the setup and async setup logic to a base class. Co-authored-by: Hassan Abouelela --- tests/bot/exts/moderation/test_silence.py | 79 +++++++++++++++---------------- 1 file changed, 37 insertions(+), 42 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 98547e2bc..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio import itertools import unittest from datetime import datetime, timezone @@ -23,8 +22,24 @@ class PatchedDatetime(datetime): now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): + """A base class for Silence tests that correctly sets up the cog and redis.""" + + @autospec(silence, "Scheduler", pass_mocks=False) + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) + self.cog = silence.Silence(self.bot) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest): def setUp(self) -> None: + super().setUp() self.alert_channel = MockTextChannel() self.notifier = silence.SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() @@ -89,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(RedisTestCase): +class SilenceCogTests(SilenceTest): """Tests for the general functionality of the Silence cog.""" - @autospec(silence, "Scheduler", pass_mocks=False) - def setUp(self) -> None: - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_cog_load_got_channels(self): """Got channels from bot.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @@ -229,13 +234,9 @@ class SilenceCogTests(RedisTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(RedisTestCase): +class SilenceArgumentParserTests(SilenceTest): """Tests for the silence argument parser utility function.""" - def setUp(self): - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @autospec(silence.Silence, "parse_silence_args") @@ -303,17 +304,19 @@ class SilenceArgumentParserTests(RedisTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase): """Tests for the rescheduling of cached unsilences.""" - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) - def setUp(self): + @autospec(silence, "Scheduler", pass_mocks=False) + def setUp(self) -> None: self.bot = MockBot() self.cog = silence.Silence(self.bot) self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) - with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog.cog_load()) # Populate instance attributes. + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -388,20 +391,14 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(RedisTestCase): +class SilenceTests(SilenceTest): """Tests for the silence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) + super().setUp() # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - - asyncio.run(self.cog.cog_load()) # Populate instance attributes. - self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( send_messages=True, @@ -659,22 +656,13 @@ class SilenceTests(RedisTestCase): @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest): """Tests for the unsilence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) - - overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) - self.cog.previous_overwrites = overwrites_cache - - asyncio.run(self.cog.cog_load()) # Populate instance attributes. + super().setUp() self.cog.scheduler.__contains__.return_value = True - overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -683,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) self.voice_channel.overwrites_for.return_value = self.voice_overwrite + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) + self.cog.previous_overwrites = overwrites_cache + + overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' + async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) -- cgit v1.2.3 From dada405211eac996196cdfb0496f4ff22f9a656a Mon Sep 17 00:00:00 2001 From: arl Date: Thu, 18 Aug 2022 19:01:22 -0400 Subject: fix: don't include replied mentions in mention filter (#2017) Co-authored-by: Izan Co-authored-by: TizzySaurus <47674925+TizzySaurus@users.noreply.github.com> Co-authored-by: Xithrius <15021300+Xithrius@users.noreply.github.com> --- bot/rules/mentions.py | 56 +++++++++++++++++++++++++++++++++----- tests/bot/rules/test_mentions.py | 58 ++++++++++++++++++++++++++++++++++++---- tests/helpers.py | 22 +++++++++++++++ 3 files changed, 124 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 6f5addad1..ca1d0c01c 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -1,23 +1,65 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from discord import DeletedReferencedMessage, Member, Message, MessageType, NotFound + +import bot +from bot.log import get_logger + +log = get_logger(__name__) async def apply( last_message: Message, recent_messages: List[Message], config: Dict[str, int] ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - """Detects total mentions exceeding the limit sent by a single user.""" + """ + Detects total mentions exceeding the limit sent by a single user. + + Excludes mentions that are bots, themselves, or replied users. + + In very rare cases, may not be able to determine a + mention was to a reply, in which case it is not ignored. + """ relevant_messages = tuple( msg for msg in recent_messages if msg.author == last_message.author ) + # We use `msg.mentions` here as that is supplied by the api itself, to determine who was mentioned. + # Additionally, `msg.mentions` includes the user replied to, even if the mention doesn't occur in the body. + # In order to exclude users who are mentioned as a reply, we check if the msg has a reference + # + # While we could use regex to parse the message content, and get a list of + # the mentions, that solution is very prone to breaking. + # We would need to deal with codeblocks, escaping markdown, and any discrepancies between + # our implementation and discord's markdown parser which would cause false positives or false negatives. + total_recent_mentions = 0 + for msg in relevant_messages: + # We check if the message is a reply, and if it is try to get the author + # since we ignore mentions of a user that we're replying to + reply_author = None - total_recent_mentions = sum( - not user.bot - for msg in relevant_messages - for user in msg.mentions - ) + if msg.type == MessageType.reply: + ref = msg.reference + + if not (resolved := ref.resolved): + # It is possible, in a very unusual situation, for a message to have a reference + # that is both not in the cache, and deleted while running this function. + # In such a situation, this will throw an error which we catch. + try: + resolved = await bot.instance.get_partial_messageable(resolved.channel_id).fetch_message( + resolved.message_id + ) + except NotFound: + log.info('Could not fetch the reference message as it has been deleted.') + + if resolved and not isinstance(resolved, DeletedReferencedMessage): + reply_author = resolved.author + + for user in msg.mentions: + # Don't count bot or self mentions, or the user being replied to (if applicable) + if user.bot or user in {msg.author, reply_author}: + continue + total_recent_mentions += 1 if total_recent_mentions > config['max']: return ( diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index f8805ac48..e1f904917 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,15 +1,32 @@ -from typing import Iterable +from typing import Iterable, Optional + +import discord from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage +from tests.helpers import MockMember, MockMessage, MockMessageReference -def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: - """Makes a message with `total_mentions` mentions.""" +def make_msg( + author: str, + total_user_mentions: int, + total_bot_mentions: int = 0, + *, + reference: Optional[MockMessageReference] = None +) -> MockMessage: + """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" user_mentions = [MockMember() for _ in range(total_user_mentions)] bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - return MockMessage(author=author, mentions=user_mentions+bot_mentions) + + mentions = user_mentions + bot_mentions + if reference is not None: + # For the sake of these tests we assume that all references are mentions. + mentions.append(reference.resolved.author) + msg_type = discord.MessageType.reply + else: + msg_type = discord.MessageType.default + + return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) class TestMentions(RuleTest): @@ -56,6 +73,16 @@ class TestMentions(RuleTest): ("bob",), 3, ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference())], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], + ("bob",), + 3 + ) ) await self.run_disallowed(cases) @@ -71,6 +98,27 @@ class TestMentions(RuleTest): await self.run_allowed(cases) + async def test_ignore_reply_mentions(self): + """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" + cases = ( + [ + make_msg("bob", 2, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) + ], + [ + make_msg("bob", 2, reference=MockMessageReference()), + make_msg("bob", 0, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), + make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) + ] + ) + + await self.run_allowed(cases) + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: last_message = case.recent_messages[0] return tuple( diff --git a/tests/helpers.py b/tests/helpers.py index 17214553c..687e15b96 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -492,6 +492,28 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): spec_set = attachment_instance +message_reference_instance = discord.MessageReference( + message_id=unittest.mock.MagicMock(id=1), + channel_id=unittest.mock.MagicMock(id=2), + guild_id=unittest.mock.MagicMock(id=3) +) + + +class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock MessageReference objects. + + Instances of this class will follow the specification of `discord.MessageReference` instances. + For more information, see the `MockGuild` docstring. + """ + spec_set = message_reference_instance + + def __init__(self, *, reference_author_is_bot: bool = False, **kwargs): + super().__init__(**kwargs) + referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot) + self.resolved = MockMessage(author=referenced_msg_author) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. -- cgit v1.2.3