From 3f04611ddfc2e6d750d4c4e0a19d3cf154e7c5a9 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Thu, 20 May 2021 16:09:39 -0400 Subject: chore: Update tests to correspond with the timeit command --- tests/bot/exts/utils/test_snekbox.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 321a92445..1b3d61094 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -161,7 +161,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') + self.cog.send_eval.assert_called_once_with( + ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ) self.cog.continue_eval.assert_called_once_with(ctx, response) async def test_eval_command_evaluate_twice(self): @@ -171,11 +173,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + self.cog.continue_eval.side_effect = ('MyAwesomeFormattedCode', None) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') + self.cog.send_eval.assert_called_with( + ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ) self.cog.continue_eval.assert_called_with(ctx, response) async def test_eval_command_reject_two_eval_at_the_same_time(self): @@ -190,12 +194,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_eval_command_call_help(self): - """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext(command="sentinel") - await self.cog.eval_command(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with(ctx.command) - async def test_send_eval(self): """Test the send_eval function.""" ctx = MockContext() @@ -212,11 +210,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.format_output.assert_called_once_with('') @@ -237,12 +235,12 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :yay!: Return code 0.' '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.format_output.assert_called_once_with('Way too long beard') @@ -262,11 +260,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once_with( '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.format_output.assert_not_called() @@ -282,7 +280,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) actual = await self.cog.continue_eval(ctx, response) - self.cog.get_code.assert_awaited_once_with(new_msg) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) self.assertEqual(actual, expected) self.bot.wait_for.assert_has_awaits( ( @@ -327,7 +325,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.return_value = MockContext(command=command) message = MockMessage(content=content) - actual_code = await self.cog.get_code(message) + actual_code = await self.cog.get_code(message, self.cog.eval_command) self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) -- cgit v1.2.3 From af3c1459ba491e748339545687a8939b4dd70e43 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 15 Jul 2021 15:25:12 -0700 Subject: Add util function to send an infraction using an Infraction dict There was some redundant pre-processing of arguments happening before calling `notify_infraction`. --- bot/exts/moderation/infraction/_scheduler.py | 18 +++------- bot/exts/moderation/infraction/_utils.py | 38 +++++++++++++++++++++- bot/exts/moderation/infraction/superstarify.py | 4 +-- tests/bot/exts/moderation/infraction/test_utils.py | 4 +-- 4 files changed, 45 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 8286d3635..19402d01d 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -162,20 +162,12 @@ class InfractionScheduler: # apply kick/ban infractions first, this would mean that we'd make it # impossible for us to deliver a DM. See python-discord/bot#982. if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + if await _utils.notify_infraction(infraction, user, user_reason): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" else: - # Accordingly display whether the user was successfully notified via DM. - if await _utils.notify_infraction(user, infr_type.replace("_", " ").title(), expiry, user_reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" end_msg = "" if infraction["actor"] == self.bot.user.id: diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index adbc641fa..a6f180c8c 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -5,9 +5,11 @@ from datetime import datetime import discord from discord.ext.commands import Context +import bot from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.errors import InvalidInfractedUserError +from bot.utils import time log = logging.getLogger(__name__) @@ -152,7 +154,7 @@ async def get_active_infraction( log.trace(f"{user} does not have active infractions of type {infr_type}.") -async def notify_infraction( +async def send_infraction_embed( user: UserObject, infr_type: str, expires_at: t.Optional[str] = None, @@ -188,6 +190,40 @@ async def notify_infraction( return await send_private_embed(user, embed) +async def notify_infraction( + infraction: Infraction, + user: t.Optional[UserSnowflake] = None, + reason: t.Optional[str] = None +) -> bool: + """ + DM a user about their new infraction and return True if the DM is successful. + + `user` and `reason` can be used to override what is in `infraction`. Otherwise, this data will + be retrieved from `infraction`. + + Also return False if the user needs to be fetched but fails to be fetched. + """ + if user is None: + user = discord.Object(infraction["user"]) + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await bot.instance.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + return False + + type_ = infraction["type"].replace("_", " ").title() + icon = INFRACTION_ICONS[infraction["type"]][0] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + + if reason is None: + reason = infraction["reason"] + + return await send_infraction_embed(user, type_, expiry, reason, icon) + + async def notify_pardon( user: UserObject, title: str, diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 07e79b9fe..6dd9924ad 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -70,15 +70,13 @@ class Superstarify(InfractionScheduler, Cog): ) notified = await _utils.notify_infraction( + infraction=infraction, user=after, - infr_type="Superstarify", - expires_at=format_infraction(infraction["expires_at"]), reason=( "You have tried to change your nickname on the **Python Discord** server " f"from **{before.display_name}** to **{after.display_name}**, but as you " "are currently in superstar-prison, you do not have permission to do so." ), - icon_url=_utils.INFRACTION_ICONS["superstar"][0] ) if not notified: diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 50a717bb5..d35120992 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -124,7 +124,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_not_awaited() @patch("bot.exts.moderation.infraction._utils.send_private_embed") - async def test_notify_infraction(self, send_private_embed_mock): + async def test_send_infraction_embed(self, send_private_embed_mock): """ Should send an embed of a certain format as a DM and return `True` if DM successful. @@ -230,7 +230,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): send_private_embed_mock.reset_mock() send_private_embed_mock.return_value = case["send_result"] - result = await utils.notify_infraction(*case["args"]) + result = await utils.send_infraction_embed(*case["args"]) self.assertEqual(case["send_result"], result) -- cgit v1.2.3 From 6b280b19ed5c564e824e55a1ec9bb13120c0193d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 5 Aug 2021 16:13:49 -0700 Subject: Time: remove RFC1123 support It's not used anywhere and hasn't been for a very long time. --- bot/utils/time.py | 6 ------ tests/bot/utils/test_time.py | 7 ------- 2 files changed, 13 deletions(-) (limited to 'tests') diff --git a/bot/utils/time.py b/bot/utils/time.py index eaa9b72e9..545e50859 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -7,7 +7,6 @@ import arrow import dateutil.parser from dateutil.relativedelta import relativedelta -RFC1123_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" DISCORD_TIMESTAMP_REGEX = re.compile(r"") _DURATION_REGEX = re.compile( @@ -167,11 +166,6 @@ def time_since(past_datetime: datetime.datetime) -> str: return discord_timestamp(past_datetime, TimestampFormats.RELATIVE) -def parse_rfc1123(stamp: str) -> datetime.datetime: - """Parse RFC1123 time string into datetime.""" - return datetime.datetime.strptime(stamp, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc) - - def format_infraction(timestamp: str) -> str: """Format an infraction timestamp to a discord timestamp.""" return discord_timestamp(dateutil.parser.isoparse(timestamp)) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index a3dcbfc0a..9c52fed27 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -43,13 +43,6 @@ class TimeTests(unittest.TestCase): time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) self.assertEqual(str(error.exception), 'max_units must be positive') - def test_parse_rfc1123(self): - """Testing parse_rfc1123.""" - self.assertEqual( - time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), - datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) - ) - def test_format_infraction(self): """Testing format_infraction.""" self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '') -- cgit v1.2.3 From 93742d718dcb4aee72ef5d20ca570b6200f07d2d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 5 Aug 2021 19:14:08 -0700 Subject: Time: rename format_infraction_with_duration It's not necessarily tied to infractions anymore. --- bot/exts/moderation/infraction/_scheduler.py | 4 ++-- bot/exts/moderation/infraction/management.py | 2 +- bot/exts/moderation/stream.py | 2 +- bot/utils/time.py | 18 +++++++++--------- tests/bot/utils/test_time.py | 18 +++++++++--------- 5 files changed, 22 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 57aa2d9b6..9d4d58e2e 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -136,7 +136,7 @@ class InfractionScheduler: infr_type = infraction["type"] icon = _utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] - expiry = time.format_infraction_with_duration(infraction["expires_at"]) + expiry = time.format_with_duration(infraction["expires_at"]) id_ = infraction['id'] if user_reason is None: @@ -387,7 +387,7 @@ class InfractionScheduler: log.info(f"Marking infraction #{id_} as inactive (expired).") expiry = dateutil.parser.isoparse(expiry) if expiry else None - created = time.format_infraction_with_duration(inserted_at, expiry) + created = time.format_with_duration(inserted_at, expiry) log_content = None log_text = { diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index fb5af9eaa..dd994a2d2 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -150,7 +150,7 @@ class ModManagement(commands.Cog): confirm_messages.append("marked as permanent") elif duration is not None: request_data['expires_at'] = duration.isoformat() - expiry = time.format_infraction_with_duration(request_data['expires_at']) + expiry = time.format_with_duration(request_data['expires_at']) confirm_messages.append(f"set to expire on {expiry}") else: confirm_messages.append("expiry unchanged") diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 5a7b12295..bc9d35714 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -133,7 +133,7 @@ class Stream(commands.Cog): await ctx.send(f"{Emojis.check_mark} {member.mention} can now stream until {time.discord_timestamp(duration)}.") # Convert here for nicer logging - revoke_time = time.format_infraction_with_duration(str(duration)) + revoke_time = time.format_with_duration(str(duration)) log.debug(f"Successfully gave {member} ({member.id}) permission to stream until {revoke_time}.") @commands.command(aliases=("pstream",)) diff --git a/bot/utils/time.py b/bot/utils/time.py index 60720031a..13dfc6fb7 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -169,26 +169,26 @@ def format_infraction(timestamp: str) -> str: return discord_timestamp(dateutil.parser.isoparse(timestamp)) -def format_infraction_with_duration( - date_to: Optional[str], - date_from: Optional[datetime.datetime] = None, +def format_with_duration( + timestamp: Optional[str], + other_timestamp: Optional[datetime.datetime] = None, max_units: int = 2, ) -> Optional[str]: """ - Return `date_to` formatted as a discord timestamp with the timestamp duration since `date_from`. + Return `timestamp` formatted as a discord timestamp with the timestamp duration since `other_timestamp`. `max_units` specifies the maximum number of units of time to include in the duration. For example, a value of 1 may include days but not hours. """ - if not date_to: + if not timestamp: return None - date_to_formatted = format_infraction(date_to) + date_to_formatted = format_infraction(timestamp) - date_from = date_from or datetime.datetime.now(datetime.timezone.utc) - date_to = dateutil.parser.isoparse(date_to).replace(microsecond=0) + other_timestamp = other_timestamp or datetime.datetime.now(datetime.timezone.utc) + timestamp = dateutil.parser.isoparse(timestamp).replace(microsecond=0) - delta = abs(relativedelta(date_to, date_from)) + delta = abs(relativedelta(timestamp, other_timestamp)) duration = humanize_delta(delta, max_units=max_units) duration_formatted = f" ({duration})" if duration else "" diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 9c52fed27..02b5f8c17 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -47,8 +47,8 @@ class TimeTests(unittest.TestCase): """Testing format_infraction.""" self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '') - def test_format_infraction_with_duration_none_expiry(self): - """format_infraction_with_duration should work for None expiry.""" + def test_format_with_duration_none_expiry(self): + """format_with_duration should work for None expiry.""" test_cases = ( (None, None, None, None), @@ -60,10 +60,10 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) - def test_format_infraction_with_duration_custom_units(self): - """format_infraction_with_duration should work for custom max_units.""" + def test_format_with_duration_custom_units(self): + """format_with_duration should work for custom max_units.""" test_cases = ( ('3000-12-12T00:01:00Z', datetime(3000, 12, 11, 12, 5, 5, tzinfo=timezone.utc), 6, ' (11 hours, 55 minutes and 55 seconds)'), @@ -73,10 +73,10 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) - def test_format_infraction_with_duration_normal_usage(self): - """format_infraction_with_duration should work for normal usage, across various durations.""" + def test_format_with_duration_normal_usage(self): + """format_with_duration should work for normal usage, across various durations.""" utc = timezone.utc test_cases = ( ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5, tzinfo=utc), 2, @@ -98,7 +98,7 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) def test_until_expiration_with_duration_none_expiry(self): """until_expiration should work for None expiry.""" -- cgit v1.2.3 From ea7fc62ddc8d08b6acdb00ac2d9a024fee8ad634 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 5 Aug 2021 19:26:26 -0700 Subject: Time: support more timestamp formats as arguments Remove the burden of conversion from the caller to clean up and simplify the call sites. Handle timestamp conversions internally with arrow.get. Remove format_infraction and get_time_delta because they're now obsolete. Replace the former with discord_timestamp and the latter with format_relative. --- bot/exts/moderation/infraction/_scheduler.py | 7 +- bot/exts/moderation/infraction/management.py | 4 +- bot/exts/moderation/infraction/superstarify.py | 4 +- bot/exts/moderation/stream.py | 2 +- bot/exts/moderation/watchchannels/_watchchannel.py | 4 +- bot/exts/recruitment/talentpool/_cog.py | 8 +- bot/exts/recruitment/talentpool/_review.py | 4 +- bot/exts/utils/reminders.py | 5 +- bot/utils/time.py | 94 +++++++++++----------- tests/bot/utils/test_time.py | 4 - 10 files changed, 63 insertions(+), 73 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 9d4d58e2e..47b639421 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -381,20 +381,15 @@ class InfractionScheduler: actor = infraction["actor"] type_ = infraction["type"] id_ = infraction["id"] - inserted_at = infraction["inserted_at"] - expiry = infraction["expires_at"] log.info(f"Marking infraction #{id_} as inactive (expired).") - expiry = dateutil.parser.isoparse(expiry) if expiry else None - created = time.format_with_duration(inserted_at, expiry) - log_content = None log_text = { "Member": f"<@{user_id}>", "Actor": f"<@{actor}>", "Reason": infraction["reason"], - "Created": created, + "Created": time.format_with_duration(infraction["inserted_at"], infraction["expires_at"]), } try: diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index dd994a2d2..23c6e8b92 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -150,7 +150,7 @@ class ModManagement(commands.Cog): confirm_messages.append("marked as permanent") elif duration is not None: request_data['expires_at'] = duration.isoformat() - expiry = time.format_with_duration(request_data['expires_at']) + expiry = time.format_with_duration(duration) confirm_messages.append(f"set to expire on {expiry}") else: confirm_messages.append("expiry unchanged") @@ -351,7 +351,7 @@ class ModManagement(commands.Cog): active = infraction["active"] user = infraction["user"] expires_at = infraction["expires_at"] - created = time.format_infraction(infraction["inserted_at"]) + created = time.discord_timestamp(infraction["inserted_at"]) dm_sent = infraction["dm_sent"] # Format the user string. diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 2e272dbb0..a037ca1be 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -73,7 +73,7 @@ class Superstarify(InfractionScheduler, Cog): notified = await _utils.notify_infraction( user=after, infr_type="Superstarify", - expires_at=time.format_infraction(infraction["expires_at"]), + expires_at=time.discord_timestamp(infraction["expires_at"]), reason=( "You have tried to change your nickname on the **Python Discord** server " f"from **{before.display_name}** to **{after.display_name}**, but as you " @@ -150,7 +150,7 @@ class Superstarify(InfractionScheduler, Cog): id_ = infraction["id"] forced_nick = self.get_nick(id_, member.id) - expiry_str = time.format_infraction(infraction["expires_at"]) + expiry_str = time.discord_timestamp(infraction["expires_at"]) # Apply the infraction async def action() -> None: diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index bc9d35714..4dccc8a7e 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -133,7 +133,7 @@ class Stream(commands.Cog): await ctx.send(f"{Emojis.check_mark} {member.mention} can now stream until {time.discord_timestamp(duration)}.") # Convert here for nicer logging - revoke_time = time.format_with_duration(str(duration)) + revoke_time = time.format_with_duration(duration) log.debug(f"Successfully gave {member} ({member.id}) permission to stream until {revoke_time}.") @commands.command(aliases=("pstream",)) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index 106483527..ee9b6ba45 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -285,7 +285,7 @@ class WatchChannel(metaclass=CogABCMeta): actor = actor.display_name if actor else self.watched_users[user_id]['actor'] inserted_at = self.watched_users[user_id]['inserted_at'] - time_delta = time.get_time_delta(inserted_at) + time_delta = time.format_relative(inserted_at) reason = self.watched_users[user_id]['reason'] @@ -359,7 +359,7 @@ class WatchChannel(metaclass=CogABCMeta): if member: line += f" ({member.name}#{member.discriminator})" inserted_at = user_data['inserted_at'] - line += f", added {time.get_time_delta(inserted_at)}" + line += f", added {time.format_relative(inserted_at)}" if not member: # Cross off users who left the server. line = f"~~{line}~~" list_data["info"][user_id] = line diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 80274eaea..bbc135454 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -180,7 +180,7 @@ class TalentPool(Cog, name="Talentpool"): if member: line += f" ({member.name}#{member.discriminator})" inserted_at = user_data['inserted_at'] - line += f", added {time.get_time_delta(inserted_at)}" + line += f", added {time.format_relative(inserted_at)}" if not member: # Cross off users who left the server. line = f"~~{line}~~" if user_data['reviewed']: @@ -561,7 +561,7 @@ class TalentPool(Cog, name="Talentpool"): actor = await get_or_fetch_member(guild, actor_id) reason = site_entry["reason"] or "*None*" - created = time.format_infraction(site_entry["inserted_at"]) + created = time.discord_timestamp(site_entry["inserted_at"]) entries.append( f"Actor: {actor.mention if actor else actor_id}\nCreated: {created}\nReason: {reason}" ) @@ -570,7 +570,7 @@ class TalentPool(Cog, name="Talentpool"): active = nomination_object["active"] - start_date = time.format_infraction(nomination_object["inserted_at"]) + start_date = time.discord_timestamp(nomination_object["inserted_at"]) if active: lines = textwrap.dedent( f""" @@ -584,7 +584,7 @@ class TalentPool(Cog, name="Talentpool"): """ ) else: - end_date = time.format_infraction(nomination_object["ended_at"]) + end_date = time.discord_timestamp(nomination_object["ended_at"]) lines = textwrap.dedent( f""" =============== diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index 474f669c6..b4d177622 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -321,7 +321,7 @@ class Reviewer: infractions += ", with the last infraction issued " # Infractions were ordered by time since insertion descending. - infractions += time.get_time_delta(infraction_list[0]['inserted_at']) + infractions += time.format_relative(infraction_list[0]['inserted_at']) return f"They have {infractions}." @@ -365,7 +365,7 @@ class Reviewer: nomination_times = f"{num_entries} times" if num_entries > 1 else "once" rejection_times = f"{len(history)} times" if len(history) > 1 else "once" - end_time = time.format_relative(isoparse(history[0]['ended_at'])) + end_time = time.format_relative(history[0]['ended_at']) review = ( f"They were nominated **{nomination_times}** before" diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index bfa294809..289d00356 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -168,7 +168,7 @@ class Reminders(Cog): self.schedule_reminder(reminder) @lock_arg(LOCK_NAMESPACE, "reminder", itemgetter("id"), raise_error=True) - async def send_reminder(self, reminder: dict, expected_time: datetime = None) -> None: + async def send_reminder(self, reminder: dict, expected_time: t.Optional[time.Timestamp] = None) -> None: """Send the reminder.""" is_valid, user, channel = self.ensure_valid_reminder(reminder) if not is_valid: @@ -347,8 +347,7 @@ class Reminders(Cog): for content, remind_at, id_, mentions in reminders: # Parse and humanize the time, make it pretty :D - remind_datetime = isoparse(remind_at) - expiry = time.format_relative(remind_datetime) + expiry = time.format_relative(remind_at) mentions = ", ".join([ # Both Role and User objects have the `name` attribute diff --git a/bot/utils/time.py b/bot/utils/time.py index 13dfc6fb7..e927a5e63 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,10 +1,10 @@ import datetime import re from enum import Enum +from time import struct_time from typing import Optional, Union import arrow -import dateutil.parser from dateutil.relativedelta import relativedelta DISCORD_TIMESTAMP_REGEX = re.compile(r"") @@ -19,8 +19,18 @@ _DURATION_REGEX = re.compile( r"((?P\d+?) ?(seconds|second|S|s))?" ) - -ValidTimestamp = Union[int, datetime.datetime, datetime.date] +# All supported types for the single-argument overload of arrow.get(). tzinfo is excluded because +# it's too implicit of a way for the caller to specify that they want the current time. +Timestamp = Union[ + arrow.Arrow, + datetime.datetime, + datetime.date, + struct_time, + int, # POSIX timestamp + float, # POSIX timestamp + str, # ISO 8601-formatted string + tuple[int, int, int], # ISO calendar tuple +] class TimestampFormats(Enum): @@ -60,15 +70,14 @@ def _stringify_time_unit(value: int, unit: str) -> str: return f"{value} {unit}" -def discord_timestamp(timestamp: ValidTimestamp, format: TimestampFormats = TimestampFormats.DATE_TIME) -> str: - """Create and format a Discord flavored markdown timestamp.""" - # Convert each possible timestamp class to an integer. - if isinstance(timestamp, datetime.datetime): - timestamp = (timestamp - arrow.get(0)).total_seconds() - elif isinstance(timestamp, datetime.date): - timestamp = (timestamp - arrow.get(0)).total_seconds() +def discord_timestamp(timestamp: Timestamp, format: TimestampFormats = TimestampFormats.DATE_TIME) -> str: + """ + Format a timestamp as a Discord-flavored Markdown timestamp. - return f"" + `timestamp` can be any type supported by the single-arg `arrow.get()`, except for a `tzinfo`. + """ + timestamp = int(arrow.get(timestamp).timestamp()) + return f"" def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str: @@ -115,14 +124,6 @@ def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: return humanized -def get_time_delta(time_string: str) -> str: - """Returns the time in human-readable time delta format.""" - date_time = dateutil.parser.isoparse(time_string) - time_delta = format_relative(date_time) - - return time_delta - - def parse_duration_string(duration: str) -> Optional[relativedelta]: """ Converts a `duration` string to a relativedelta object. @@ -154,64 +155,63 @@ def relativedelta_to_timedelta(delta: relativedelta) -> datetime.timedelta: return utcnow + delta - utcnow -def format_relative(timestamp: ValidTimestamp) -> str: +def format_relative(timestamp: Timestamp) -> str: """ Format `timestamp` as a relative Discord timestamp. A relative timestamp describes how much time has elapsed since `timestamp` or how much time - remains until `timestamp` is reached. See `time.discord_timestamp`. + remains until `timestamp` is reached. + + `timestamp` can be any type supported by the single-arg `arrow.get()`, except for a `tzinfo`. """ return discord_timestamp(timestamp, TimestampFormats.RELATIVE) -def format_infraction(timestamp: str) -> str: - """Format an infraction timestamp to a discord timestamp.""" - return discord_timestamp(dateutil.parser.isoparse(timestamp)) - - def format_with_duration( - timestamp: Optional[str], - other_timestamp: Optional[datetime.datetime] = None, + timestamp: Optional[Timestamp], + other_timestamp: Optional[Timestamp] = None, max_units: int = 2, ) -> Optional[str]: """ Return `timestamp` formatted as a discord timestamp with the timestamp duration since `other_timestamp`. + `timestamp` and `other_timestamp` can be any type supported by the single-arg `arrow.get()`, + except for a `tzinfo`. Use the current time if `other_timestamp` is falsy or unspecified. + `max_units` specifies the maximum number of units of time to include in the duration. For example, a value of 1 may include days but not hours. + + Return None if `timestamp` is falsy. """ if not timestamp: return None - date_to_formatted = format_infraction(timestamp) - - other_timestamp = other_timestamp or datetime.datetime.now(datetime.timezone.utc) - timestamp = dateutil.parser.isoparse(timestamp).replace(microsecond=0) + timestamp = arrow.get(timestamp) + if not other_timestamp: + other_timestamp = arrow.utcnow() + else: + other_timestamp = arrow.get(other_timestamp) - delta = abs(relativedelta(timestamp, other_timestamp)) + formatted_timestamp = discord_timestamp(timestamp) + delta = abs(relativedelta(timestamp.datetime, other_timestamp.datetime)) duration = humanize_delta(delta, max_units=max_units) - duration_formatted = f" ({duration})" if duration else "" - return f"{date_to_formatted}{duration_formatted}" + return f"{formatted_timestamp} ({duration})" -def until_expiration( - expiry: Optional[str] -) -> Optional[str]: +def until_expiration(expiry: Optional[Timestamp]) -> Optional[str]: """ - Get the remaining time until infraction's expiration, in a discord timestamp. + Get the remaining time until an infraction's expiration as a Discord timestamp. - Returns a human-readable version of the remaining duration between arrow.utcnow() and an expiry. - Similar to format_relative, except that this function doesn't error on a null input - and return null if the expiry is in the paste + `expiry` can be any type supported by the single-arg `arrow.get()`, except for a `tzinfo`. + + Return None if `expiry` is falsy or is in the past. """ if not expiry: return None - now = arrow.utcnow() - since = dateutil.parser.isoparse(expiry).replace(microsecond=0) - - if since < now: + expiry = arrow.get(expiry) + if expiry < arrow.utcnow(): return None - return format_relative(since) + return format_relative(expiry) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 02b5f8c17..027e2052e 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -43,10 +43,6 @@ class TimeTests(unittest.TestCase): time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) self.assertEqual(str(error.exception), 'max_units must be positive') - def test_format_infraction(self): - """Testing format_infraction.""" - self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '') - def test_format_with_duration_none_expiry(self): """format_with_duration should work for None expiry.""" test_cases = ( -- cgit v1.2.3 From 2004477e12c72e4739ea1b1f192fb2c12eac69d0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 6 Aug 2021 14:20:52 -0700 Subject: Time: add overload to pass 2 timestamps to humanize_delta Remove the need for the caller to create a `relativedelta` from 2 timestamps before calling `humanize_delta`. This is especially convenient for cases where the original inputs aren't `datetime`s since `relativedelta` only accepts those. --- bot/exts/moderation/defcon.py | 4 +- bot/exts/moderation/infraction/management.py | 6 +- bot/exts/moderation/modlog.py | 2 +- bot/utils/time.py | 110 ++++++++++++++++++++++----- tests/bot/utils/test_time.py | 13 ++-- 5 files changed, 105 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 263e8136e..178be734d 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -254,8 +254,8 @@ class Defcon(Cog): expiry_message = "" if expiry: - activity_duration = relativedelta(expiry, arrow.utcnow().datetime) - expiry_message = f" for the next {time.humanize_delta(activity_duration, max_units=2)}" + formatted_expiry = time.humanize_delta(expiry, max_units=2) + expiry_message = f" for the next {formatted_expiry}" if self.threshold: channel_message = ( diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index fa1ebdadc..0dfd2d759 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -1,9 +1,7 @@ import textwrap import typing as t -import arrow import discord -from dateutil.relativedelta import relativedelta from discord.ext import commands from discord.ext.commands import Context from discord.utils import escape_markdown @@ -371,9 +369,7 @@ class ModManagement(commands.Cog): if expires_at is None: duration = "*Permanent*" else: - start = arrow.get(inserted_at).datetime - end = arrow.get(expires_at).datetime - duration = time.humanize_delta(relativedelta(start, end)) + duration = time.humanize_delta(inserted_at, expires_at) # Format `dm_sent` if dm_sent is None: diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index d5e209d81..2c01a4a21 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -713,7 +713,7 @@ class ModLog(Cog, name="ModLog"): # datetime as the baseline and create a human-readable delta between this edit event # and the last time the message was edited timestamp = msg_before.edited_at - delta = time.humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) + delta = time.humanize_delta(msg_after.edited_at, msg_before.edited_at) footer = f"Last edited {delta} ago" else: # Message was not previously edited, use the created_at datetime as the baseline, no diff --git a/bot/utils/time.py b/bot/utils/time.py index 21d26db7d..7e314a870 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -2,7 +2,7 @@ import datetime import re from enum import Enum from time import struct_time -from typing import Optional, Union +from typing import Optional, Union, overload import arrow from dateutil.relativedelta import relativedelta @@ -78,15 +78,99 @@ def discord_timestamp(timestamp: Timestamp, format: TimestampFormats = Timestamp return f"" -def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str: +@overload +def humanize_delta( + arg1: Union[relativedelta, Timestamp], + /, + *, + precision: str = "seconds", + max_units: int = 6, + absolute: bool = True, +) -> str: + ... + + +@overload +def humanize_delta( + end: Timestamp, + start: Timestamp, + /, + *, + precision: str = "seconds", + max_units: int = 6, + absolute: bool = True, +) -> str: + ... + + +def humanize_delta( + *args, + precision: str = "seconds", + max_units: int = 6, + absolute: bool = True, +) -> str: """ - Returns a human-readable version of the relativedelta. + Return a human-readable version of a time duration. - precision specifies the smallest unit of time to include (e.g. "seconds", "minutes"). - max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + `precision` is the smallest unit of time to include (e.g. "seconds", "minutes"). + + `max_units` is the maximum number of units of time to include. + Count units from largest to smallest (e.g. count days before months). + + Use the absolute value of the duration if `absolute` is True. + + Usage: + + **One** `relativedelta` object, to humanize the duration represented by it: + + >>> humanize_delta(relativedelta(years=12, months=6)) + '12 years and 6 months' + + Note that `leapdays` and absolute info (singular names) will be ignored during humanization. + + **One** timestamp of a type supported by the single-arg `arrow.get()`, except for `tzinfo`, + to humanize the duration between it and the current time: + + >>> humanize_delta('2021-08-06T12:43:01Z', absolute=True) # now = 2021-08-06T12:33:33Z + '9 minutes and 28 seconds' + + >>> humanize_delta('2021-08-06T12:43:01Z', absolute=False) # now = 2021-08-06T12:33:33Z + '-9 minutes and -28 seconds' + + **Two** timestamps, each of a type supported by the single-arg `arrow.get()`, except for + `tzinfo`, to humanize the duration between them: + + >>> humanize_delta(datetime.datetime(2020, 1, 1), '2021-01-01T12:00:00Z', absolute=False) + '1 year and 12 hours' + + >>> humanize_delta('2021-01-01T12:00:00Z', datetime.datetime(2020, 1, 1), absolute=False) + '-1 years and -12 hours' + + Note that order of the arguments can result in a different output even if `absolute` is True: + + >>> x = datetime.datetime(3000, 11, 1) + >>> y = datetime.datetime(3000, 9, 2) + >>> humanize_delta(y, x, absolute=True), humanize_delta(x, y, absolute=True) + ('1 month and 30 days', '1 month and 29 days') + + This is due to the nature of `relativedelta`; it does not represent a fixed period of time. + Instead, it's relative to the `datetime` to which it's added to get the other `datetime`. + In the example, the difference arises because all months don't have the same number of days. """ + if len(args) == 1 and isinstance(args[0], relativedelta): + delta = args[0] + elif 1 <= len(args) <= 2: + end = arrow.get(args[0]) + start = arrow.get(args[1]) if len(args) == 2 else arrow.utcnow() + + delta = relativedelta(end.datetime, start.datetime) + if absolute: + delta = abs(delta) + else: + raise ValueError(f"Received {len(args)} positional arguments, but expected 1 or 2.") + if max_units <= 0: - raise ValueError("max_units must be positive") + raise ValueError("max_units must be positive.") units = ( ("years", delta.years), @@ -174,25 +258,17 @@ def format_with_duration( Return `timestamp` formatted as a discord timestamp with the timestamp duration since `other_timestamp`. `timestamp` and `other_timestamp` can be any type supported by the single-arg `arrow.get()`, - except for a `tzinfo`. Use the current time if `other_timestamp` is falsy or unspecified. + except for a `tzinfo`. Use the current time if `other_timestamp` is None or unspecified. - `max_units` specifies the maximum number of units of time to include in the duration. For - example, a value of 1 may include days but not hours. + `max_units` is forwarded to `time.humanize_delta`. See its documentation for more information. Return None if `timestamp` is falsy. """ if not timestamp: return None - timestamp = arrow.get(timestamp) - if not other_timestamp: - other_timestamp = arrow.utcnow() - else: - other_timestamp = arrow.get(other_timestamp) - formatted_timestamp = discord_timestamp(timestamp) - delta = abs(relativedelta(timestamp.datetime, other_timestamp.datetime)) - duration = humanize_delta(delta, max_units=max_units) + duration = humanize_delta(timestamp, other_timestamp, max_units=max_units) return f"{formatted_timestamp} ({duration})" diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 027e2052e..e235f9b70 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -13,13 +13,15 @@ class TimeTests(unittest.TestCase): """humanize_delta should be able to handle unknown units, and will not abort.""" # Does not abort for unknown units, as the unit name is checked # against the attribute of the relativedelta instance. - self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + actual = time.humanize_delta(relativedelta(days=2, hours=2), precision='elephants', max_units=2) + self.assertEqual(actual, '2 days and 2 hours') def test_humanize_delta_handle_high_units(self): """humanize_delta should be able to handle very high units.""" # Very high maximum units, but it only ever iterates over # each value the relativedelta might have. - self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + actual = time.humanize_delta(relativedelta(days=2, hours=2), precision='hours', max_units=20) + self.assertEqual(actual, '2 days and 2 hours') def test_humanize_delta_should_normal_usage(self): """Testing humanize delta.""" @@ -32,7 +34,8 @@ class TimeTests(unittest.TestCase): for delta, precision, max_units, expected in test_cases: with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + actual = time.humanize_delta(delta, precision=precision, max_units=max_units) + self.assertEqual(actual, expected) def test_humanize_delta_raises_for_invalid_max_units(self): """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" @@ -40,8 +43,8 @@ class TimeTests(unittest.TestCase): for max_units in test_cases: with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: - time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - self.assertEqual(str(error.exception), 'max_units must be positive') + time.humanize_delta(relativedelta(days=2, hours=2), precision='hours', max_units=max_units) + self.assertEqual(str(error.exception), 'max_units must be positive.') def test_format_with_duration_none_expiry(self): """format_with_duration should work for None expiry.""" -- cgit v1.2.3 From 2209b9f1b95cbe2366ea2b316046ddd35ff6d3a9 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 7 Aug 2021 10:48:57 -0700 Subject: Fix create_user_embed tests Mock User.created_at and User.joined_at because `arrow.get()` doesn't work with Mock objects. The old implementation of `time.discord_timestamp` accepted mocks because it just did `int()` on any type it didn't explicitly check for. --- tests/bot/exts/info/test_information.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 30e5258fb..d896b7652 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -1,6 +1,7 @@ import textwrap import unittest import unittest.mock +from datetime import datetime import discord @@ -288,6 +289,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.nick = None user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 + user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) @@ -309,6 +311,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.nick = "Cat lover" user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 + user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) @@ -329,6 +332,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): # A `MockMember` has the @Everyone role by default; we add the Admins to that. user = helpers.MockMember(roles=[admins_role], colour=100) + user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) @@ -355,6 +359,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): nomination_counts.return_value = ("Nominations", "nomination info") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) + user.created_at = user.joined_at = datetime.utcfromtimestamp(1) embed = await self.cog.create_user_embed(ctx, user, False) infraction_counts.assert_called_once_with(user) @@ -394,6 +399,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user_messages.return_value = ("Messages", "user message counts") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) + user.created_at = user.joined_at = datetime.utcfromtimestamp(1) embed = await self.cog.create_user_embed(ctx, user, False) infraction_counts.assert_called_once_with(user) @@ -440,6 +446,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): moderators_role = helpers.MockRole(name='Moderators') user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) + user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.colour, discord.Colour(100)) @@ -457,6 +464,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ctx = helpers.MockContext() user = helpers.MockMember(id=217, colour=discord.Colour.default()) + user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.colour, discord.Colour.og_blurple()) @@ -474,6 +482,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ctx = helpers.MockContext() user = helpers.MockMember(id=217, colour=0) + user.created_at = user.joined_at = datetime.utcnow() user.display_avatar.url = "avatar url" embed = await self.cog.create_user_embed(ctx, user, False) -- cgit v1.2.3 From 1af466753975b70effd5e600d0afc8b21f272dd0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 7 Aug 2021 10:57:24 -0700 Subject: Time: return strings from until_expiration instead of ambiguous None None was returned for two separate cases: permanent infractions and expired infractions. This resulted in an ambiguity. --- bot/exts/moderation/infraction/management.py | 6 +++--- bot/utils/time.py | 8 ++++---- tests/bot/utils/test_time.py | 5 ++--- 3 files changed, 9 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index 0dfd2d759..dda3fadae 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -179,8 +179,8 @@ class ModManagement(commands.Cog): self.infractions_cog.schedule_expiration(new_infraction) log_text += f""" - Previous expiry: {time.until_expiration(infraction['expires_at']) or "Permanent"} - New expiry: {time.until_expiration(new_infraction['expires_at']) or "Permanent"} + Previous expiry: {time.until_expiration(infraction['expires_at'])} + New expiry: {time.until_expiration(new_infraction['expires_at'])} """.rstrip() changes = ' & '.join(confirm_messages) @@ -362,7 +362,7 @@ class ModManagement(commands.Cog): user_str = f"<@{user['id']}> ({name}#{user['discriminator']:04})" if active: - remaining = time.until_expiration(expires_at) or "Expired" + remaining = time.until_expiration(expires_at) else: remaining = "Inactive" diff --git a/bot/utils/time.py b/bot/utils/time.py index 8ba49a455..4b2fbae2c 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -303,19 +303,19 @@ def format_with_duration( return f"{formatted_timestamp} ({duration})" -def until_expiration(expiry: Optional[Timestamp]) -> Optional[str]: +def until_expiration(expiry: Optional[Timestamp]) -> str: """ Get the remaining time until an infraction's expiration as a Discord timestamp. `expiry` can be any type supported by the single-arg `arrow.get()`, except for a `tzinfo`. - Return None if `expiry` is falsy or is in the past. + Return "Permanent" if `expiry` is falsy. Return "Expired" if `expiry` is in the past. """ if not expiry: - return None + return "Permanent" expiry = arrow.get(expiry) if expiry < arrow.utcnow(): - return None + return "Expired" return format_relative(expiry) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index e235f9b70..120d65176 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -100,8 +100,8 @@ class TimeTests(unittest.TestCase): self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) def test_until_expiration_with_duration_none_expiry(self): - """until_expiration should work for None expiry.""" - self.assertEqual(time.until_expiration(None), None) + """until_expiration should return "Permanent" is expiry is None.""" + self.assertEqual(time.until_expiration(None), "Permanent") def test_until_expiration_with_duration_custom_units(self): """until_expiration should work for custom max_units.""" @@ -122,7 +122,6 @@ class TimeTests(unittest.TestCase): ('3000-12-12T00:00:00Z', ''), ('3000-11-23T20:09:00Z', ''), ('3000-11-23T20:09:00Z', ''), - (None, None), ) for expiry, expected in test_cases: -- cgit v1.2.3 From 1f327a54640a781026dc223597f8e2a306751460 Mon Sep 17 00:00:00 2001 From: Izan Date: Tue, 16 Nov 2021 09:40:31 +0000 Subject: Fix tests --- tests/bot/exts/moderation/infraction/test_utils.py | 27 +++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 72eebb254..999dbd1c6 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -19,6 +19,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.member = MockMember(id=1234) self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) + self.maxDiff = None async def test_post_user(self): """Should POST a new user and return the response if successful or otherwise send an error message.""" @@ -132,7 +133,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """ test_cases = [ { - "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), + "args": (self.bot, self.user, 0, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -150,7 +151,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "send_result": True }, { - "args": (self.user, "warning", None, "Test reason."), + "args": (self.bot, self.user, 0, "warning", None, "Test reason."), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -170,7 +171,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): # Note that this test case asserts that the DM that *would* get sent to the user is formatted # correctly, even though that message is deliberately never sent. { - "args": (self.user, "note", None, None, Icons.defcon_denied), + "args": (self.bot, self.user, 0, "note", None, None, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -188,7 +189,15 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "send_result": False }, { - "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), + "args": ( + self.bot, + self.user, + 0, + "mute", + "2020-02-26 09:20 (23 hours and 59 minutes)", + "Test", + Icons.defcon_denied + ), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -206,7 +215,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): "send_result": False }, { - "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied), + "args": (self.bot, self.user, 0, "mute", None, "foo bar" * 4000, Icons.defcon_denied), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -238,7 +247,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.to_dict(), case["expected_output"].to_dict()) - send_private_embed_mock.assert_awaited_once_with(case["args"][0], embed) + send_private_embed_mock.assert_awaited_once_with(case["args"][1], embed) @patch("bot.exts.moderation.infraction._utils.send_private_embed") async def test_notify_pardon(self, send_private_embed_mock): @@ -313,7 +322,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "type": "ban", "user": self.member.id, "active": False, - "expires_at": now.isoformat() + "expires_at": now.isoformat(), + "dm_sent": False } self.ctx.bot.api_client.post.return_value = "foo" @@ -350,7 +360,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "reason": "Test reason", "type": "mute", "user": self.user.id, - "active": True + "active": True, + "dm_sent": False } self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] -- cgit v1.2.3 From a6be95385edc1caccd84dc83a8d11ece86847c8b Mon Sep 17 00:00:00 2001 From: Izan Date: Thu, 25 Nov 2021 19:55:33 +0000 Subject: Remove debug `maxDiff` assignment. --- tests/bot/exts/moderation/infraction/test_utils.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 999dbd1c6..350274ecd 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -19,7 +19,6 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): self.member = MockMember(id=1234) self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) - self.maxDiff = None async def test_post_user(self): """Should POST a new user and return the response if successful or otherwise send an error message.""" -- cgit v1.2.3 From 94f5c99c1ff5815341862431d02129e80ceb6850 Mon Sep 17 00:00:00 2001 From: ChrisJL Date: Tue, 28 Dec 2021 18:11:52 +0000 Subject: Include message counts in all channels (#2016) Co-authored-by: Xithrius <15021300+Xithrius@users.noreply.github.com> --- bot/exts/info/information.py | 11 +++------ bot/exts/moderation/voice_gate.py | 5 +--- tests/bot/exts/info/test_information.py | 43 +++++++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 73357211e..d0e1eae74 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -298,11 +298,11 @@ class Information(Cog): "Member information", membership ), + await self.user_messages(user), ] # Show more verbose output in moderation channels for infractions and nominations if is_mod_channel(ctx.channel): - fields.append(await self.user_messages(user)) fields.append(await self.expanded_user_infraction_counts(user)) fields.append(await self.user_nomination_counts(user)) else: @@ -420,13 +420,8 @@ class Information(Cog): if e.status == 404: activity_output = "No activity" else: - activity_output.append(user_activity["total_messages"] or "No messages") - - if (activity_blocks := user_activity.get("activity_blocks")) is not None: - # activity_blocks is not included in the response if the user has a lot of messages - activity_output.append(activity_blocks or "No activity") # Special case when activity_blocks is 0. - else: - activity_output.append("Too many to count!") + activity_output.append(f"{user_activity['total_messages']:,}" or "No messages") + activity_output.append(f"{user_activity['activity_blocks']:,}" or "No activity") activity_output = "\n".join( f"{name}: {metric}" for name, metric in zip(["Messages", "Activity blocks"], activity_output) diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index ae55a03a0..a382b13d1 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -171,11 +171,8 @@ class VoiceGate(Cog): ), "total_messages": data["total_messages"] < GateConf.minimum_messages, "voice_banned": data["voice_banned"], + "activity_blocks": data["activity_blocks"] < GateConf.minimum_activity_blocks, } - if activity_blocks := data.get("activity_blocks"): - # activity_blocks is not included in the response if the user has a lot of messages. - # Only check if the user has enough activity blocks if it is included. - checks["activity_blocks"] = activity_blocks < GateConf.minimum_activity_blocks failed = any(checks.values()) failed_reasons = [MESSAGE_FIELD_MAP[key] for key, value in checks.items() if value is True] diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 632287322..724456b04 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -276,6 +276,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -293,6 +297,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -310,6 +318,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -325,6 +337,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_expanded_information_in_moderation_channels( self, nomination_counts, @@ -363,13 +379,19 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ) @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - async def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): + @unittest.mock.patch(f"{COG_PATH}.user_messages", new_callable=unittest.mock.AsyncMock) + async def test_create_user_embed_basic_information_outside_of_moderation_channels( + self, + user_messages, + infraction_counts, + ): """The embed should contain only basic infraction data outside of mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) moderators_role = helpers.MockRole(name='Moderators') infraction_counts.return_value = ("Infractions", "basic infractions info") + user_messages.return_value = ("Messages", "user message counts") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) embed = await self.cog.create_user_embed(ctx, user) @@ -394,14 +416,23 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ) self.assertEqual( - "basic infractions info", + "user message counts", embed.fields[2].value ) + self.assertEqual( + "basic infractions info", + embed.fields[3].value + ) + @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -417,6 +448,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_uses_og_blurple_colour_when_user_has_no_roles(self): """The embed should be created with the og blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() @@ -430,6 +465,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) + @unittest.mock.patch( + f"{COG_PATH}.user_messages", + new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count")) + ) async def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() -- cgit v1.2.3 From 681771b945ad9c3968323c083c3ed45a32ba37bf Mon Sep 17 00:00:00 2001 From: TizzySaurus <47674925+TizzySaurus@users.noreply.github.com> Date: Wed, 29 Dec 2021 20:38:05 +0000 Subject: Add text indicating when user fetched by message (#2013) Co-authored-by: Xithrius <15021300+Xithrius@users.noreply.github.com> --- bot/exts/info/information.py | 10 ++++++---- tests/bot/exts/info/test_information.py | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index d0e1eae74..1f95c460f 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -227,7 +227,7 @@ class Information(Cog): @command(name="user", aliases=["user_info", "member", "member_info", "u"]) async def user_info(self, ctx: Context, user_or_message: Union[MemberOrUser, Message] = None) -> None: """Returns info about a user.""" - if isinstance(user_or_message, Message): + if passed_as_message := isinstance(user_or_message, Message): user = user_or_message.author else: user = user_or_message @@ -242,10 +242,10 @@ class Information(Cog): # Will redirect to #bot-commands if it fails. if in_whitelist_check(ctx, roles=constants.STAFF_PARTNERS_COMMUNITY_ROLES): - embed = await self.create_user_embed(ctx, user) + embed = await self.create_user_embed(ctx, user, passed_as_message) await ctx.send(embed=embed) - async def create_user_embed(self, ctx: Context, user: MemberOrUser) -> Embed: + async def create_user_embed(self, ctx: Context, user: MemberOrUser, passed_as_message: bool) -> Embed: """Creates an embed containing information on the `user`.""" on_server = bool(await get_or_fetch_member(ctx.guild, user.id)) @@ -256,6 +256,9 @@ class Information(Cog): name = f"{user.nick} ({name})" name = escape_markdown(name) + if passed_as_message: + name += " - From Message" + if user.public_flags.verified_bot: name += f" {constants.Emojis.verified_bot}" elif user.bot: @@ -282,7 +285,6 @@ class Information(Cog): membership = textwrap.dedent("\n".join([f"{key}: {value}" for key, value in membership.items()])) else: - roles = None membership = "The user is not a member of the server" fields = [ diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 724456b04..30e5258fb 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -289,7 +289,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.title, "Mr. Hemlock") @@ -310,7 +310,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") @@ -330,7 +330,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): # A `MockMember` has the @Everyone role by default; we add the Admins to that. user = helpers.MockMember(roles=[admins_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertIn("&Admins", embed.fields[1].value) self.assertNotIn("&Everyone", embed.fields[1].value) @@ -355,7 +355,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): nomination_counts.return_value = ("Nominations", "nomination info") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) infraction_counts.assert_called_once_with(user) nomination_counts.assert_called_once_with(user) @@ -394,7 +394,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user_messages.return_value = ("Messages", "user message counts") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) infraction_counts.assert_called_once_with(user) @@ -440,7 +440,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): moderators_role = helpers.MockRole(name='Moderators') user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.colour, discord.Colour(100)) @@ -457,7 +457,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ctx = helpers.MockContext() user = helpers.MockMember(id=217, colour=discord.Colour.default()) - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.colour, discord.Colour.og_blurple()) @@ -475,7 +475,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user = helpers.MockMember(id=217, colour=0) user.display_avatar.url = "avatar url" - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.thumbnail.url, "avatar url") @@ -528,7 +528,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx) - create_embed.assert_called_once_with(ctx, self.author) + create_embed.assert_called_once_with(ctx, self.author, False) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") @@ -539,7 +539,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx, self.author) - create_embed.assert_called_once_with(ctx, self.author) + create_embed.assert_called_once_with(ctx, self.author, False) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") @@ -550,7 +550,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx) - create_embed.assert_called_once_with(ctx, self.moderator) + create_embed.assert_called_once_with(ctx, self.moderator, False) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") @@ -562,5 +562,5 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx, self.target) - create_embed.assert_called_once_with(ctx, self.target) + create_embed.assert_called_once_with(ctx, self.target, False) ctx.send.assert_called_once() -- cgit v1.2.3 From f6b50c17b59f6ec02c9f7e8a7cf6f7ef1a426b7a Mon Sep 17 00:00:00 2001 From: Ben Soyka Date: Sat, 8 Jan 2022 14:39:15 -0700 Subject: Fix snekbox tests with new allowed_mentions --- tests/bot/exts/utils/test_snekbox.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 321a92445..8bdeedd27 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -2,6 +2,7 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch +from discord import AllowedMentions from discord.ext import commands from bot import constants @@ -201,7 +202,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() - ctx.author.mention = '@LemonLemonishBeard#0042' + ctx.author = MockUser(mention='@LemonLemonishBeard#0042') self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) @@ -213,9 +214,16 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode') - ctx.send.assert_called_once_with( + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' ) + allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions'] + expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) + self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) @@ -238,10 +246,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode') - ctx.send.assert_called_once_with( + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], '@LemonLemonishBeard#0042 :yay!: Return code 0.' '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) @@ -263,9 +275,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode') - ctx.send.assert_called_once_with( + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) -- cgit v1.2.3 From b7e49a5fb1adb541db2cf5632a460a37ddda6d0a Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Thu, 13 Jan 2022 21:26:58 -0500 Subject: chore: Suppress output in the setup code, not the code that gets timed. If multiple formatted codeblocks are passed to the command, the first one will be used as the setup code that does not get timed. --- bot/exts/utils/snekbox.py | 70 +++++++++++++++++++++++++++--------- tests/bot/exts/utils/test_snekbox.py | 6 ++-- 2 files changed, 56 insertions(+), 20 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index bd521a4ee..0d8da5e56 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -2,9 +2,9 @@ import asyncio import contextlib import datetime import re -import textwrap from functools import partial from signal import Signals +from textwrap import dedent from typing import Awaitable, Callable, Optional, Tuple from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User @@ -36,13 +36,35 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) -TIMEIT_EVAL_WRAPPER = """ -from contextlib import redirect_stdout -from io import StringIO +TIMEIT_SETUP_WRAPPER = """ +import atexit +import sys +from collections import deque -with redirect_stdout(StringIO()): - del redirect_stdout, StringIO -{code} +if not hasattr(sys, "_setup_finished"): + class Writer(deque): + def __init__(self): + super().__init__(maxlen=1) + + def write(self, string): + if string.strip(): + self.append(string) + + def read(self): + return self.pop() + + def flush(self): + pass + + sys.stdout = Writer() + + def print_last_line(): + if sys.stdout: + print(sys.stdout.read(), file=sys.__stdout__) + + atexit.register(print_last_line) + sys._setup_finished = None +{setup} """ TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop") @@ -90,34 +112,37 @@ class Snekbox(Cog): return await send_to_paste_service(output, extension="txt") @staticmethod - def prepare_input(code: str) -> str: + def prepare_input(code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. If there is any code block, ignore text outside the code block. Use the first code block, but prefer a fenced code block. If there are several fenced code blocks, concatenate only the fenced code blocks. + + Retrun a list of code blocks if any, otherwise return a list with a single string of code. """ if match := list(FORMATTED_CODE_REGEX.finditer(code)): blocks = [block for block in match if block.group("block")] if len(blocks) > 1: - code = '\n'.join(block.group("code") for block in blocks) + codeblocks = [block.group("code") for block in blocks] info = "several code blocks" else: match = match[0] if len(blocks) == 0 else blocks[0] code, block, lang, delim = match.group("code", "block", "lang", "delim") + codeblocks = [dedent(code)] if block: info = (f"'{lang}' highlighted" if lang else "plain") + " code block" else: info = f"{delim}-enclosed inline code" else: - code = RAW_CODE_REGEX.fullmatch(code).group("code") + codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] info = "unformatted or badly formatted code" - code = textwrap.dedent(code) + code = "\n".join(codeblocks) log.trace(f"Extracted {info} for evaluation:\n{code}") - return code + return codeblocks @staticmethod def get_results_message(results: dict) -> Tuple[str, str]: @@ -248,7 +273,7 @@ class Snekbox(Cog): log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") return response - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]: """ Check if the eval session should continue. @@ -380,7 +405,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = self.prepare_input(code) + code = "\n".join(self.prepare_input(code)) await self.run_eval(ctx, code, format_func=self.format_output) @command(name="timeit", aliases=("ti",)) @@ -400,13 +425,24 @@ class Snekbox(Cog): block. Code can be re-evaluated by editing the original message within 10 seconds and clicking the reaction that subsequently appears. + If multiple formatted codeblocks are provided, the first one will be the setup code, which will + not be timed. The remaining codeblocks will be joined together and timed. + We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = self.prepare_input(code) + args = ["-m", "timeit"] + setup = "" + codeblocks = self.prepare_input(code) + + if len(codeblocks) > 1: + setup = codeblocks.pop(0) + + code = "\n".join(codeblocks) + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + await self.run_eval( - ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")), - format_func=self.format_timeit_output, args=["-m", "timeit"] + ctx, code=code, format_func=self.format_timeit_output, args=args ) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index cbffaa6b0..ebab71e71 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -61,7 +61,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) + self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -156,7 +156,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock(return_value=None) @@ -297,7 +297,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_eval(ctx, response) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, expected) + self.assertEqual(actual, [expected]) self.bot.wait_for.assert_has_awaits( ( call( -- cgit v1.2.3 From 8594a3535413e662ae519f31989459a953e8a726 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Mon, 17 Jan 2022 14:43:58 -0500 Subject: fix: Modify tests to correspond with Snekbox.continue_eval --- tests/bot/exts/utils/test_snekbox.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index ebab71e71..4245de8a3 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -156,32 +156,35 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() + ctx.command = MagicMock() + self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=None) + self.cog.continue_eval = AsyncMock(return_value=(None, None)) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') self.cog.send_eval.assert_called_once_with( ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output ) - self.cog.continue_eval.assert_called_once_with(ctx, response) + self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() + ctx.command = MagicMock() self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeFormattedCode', None) + self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) self.cog.send_eval.assert_called_with( ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output ) - self.cog.continue_eval.assert_called_with(ctx, response) + self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -295,9 +298,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response) + actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, [expected]) + self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( @@ -316,8 +319,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage()) - self.assertEqual(actual, None) + actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command) + self.assertEqual(actual, (None, None)) ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) async def test_get_code(self): -- cgit v1.2.3 From 54e4f3777372ef526667885f4392030bab1b5b07 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Mon, 17 Jan 2022 17:42:37 -0500 Subject: chore: Apply suggestions and adjust tests --- bot/exts/utils/snekbox.py | 52 +++++++++++++----------------------- tests/bot/exts/utils/test_snekbox.py | 24 ++++++++--------- 2 files changed, 29 insertions(+), 47 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 49f1be17b..1d9646113 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -36,6 +36,8 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) +# The timeit command should only output the very last line, so all other output should be suppressed. +# This will be used as the setup code along with any setup code provided. TIMEIT_SETUP_WRAPPER = """ import atexit import sys @@ -163,19 +165,19 @@ class Snekbox(Cog): return code, args @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: + def get_results_message(results: dict, job_name: str) -> Tuple[str, str]: """Return a user-friendly message and error corresponding to the process's return code.""" stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" + msg = f"Your {job_name} job has completed with return code {returncode}" error = "" if returncode is None: - msg = "Your eval job has failed" + msg = f"Your {job_name} job has failed" error = stdout.strip() elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" + msg = f"Your {job_name} job timed out or ran out of memory" elif returncode == 255: - msg = "Your eval job has failed" + msg = f"Your {job_name} job has failed" error = "A fatal NsJail error occurred" else: # Try to append signal's name if one exists @@ -249,7 +251,7 @@ class Snekbox(Cog): code: str, *, args: Optional[list[str]] = None, - format_func: FormatFunc + job_name: str ) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. @@ -258,13 +260,13 @@ class Snekbox(Cog): """ async with ctx.typing(): results = await self.post_eval(code, args=args) - msg, error = self.get_results_message(results) + msg, error = self.get_results_message(results, job_name) if error: output, paste_link = error, None else: log.trace("Formatting output...") - output, paste_link = await format_func(results["stdout"]) + output, paste_link = await self.format_output(results["stdout"]) icon = self.get_status_emoji(results) msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" @@ -288,12 +290,12 @@ class Snekbox(Cog): response = await ctx.send(msg, allowed_mentions=allowed_mentions) scheduling.create_task(wait_for_deletion(response, (ctx.author.id,)), event_loop=self.bot.loop) - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") return response async def continue_eval( self, ctx: Context, response: Message, command: Command - ) -> Optional[tuple[str, Optional[list[str]]]]: + ) -> tuple[Optional[str], Optional[list[str]]]: """ Check if the eval session should continue. @@ -355,19 +357,15 @@ class Snekbox(Cog): return code - async def run_eval( + async def run_job( self, + job_name: str, ctx: Context, code: str, - format_func: FormatFunc, *, args: Optional[list[str]] = None, ) -> None: - """ - Handles checks, stats and re-evaluation of an eval. - - `format_func` is an async callable that takes a string (the output) and formats it to show to the user. - """ + """Handles checks, stats and re-evaluation of a snekbox job.""" if ctx.author.id in self.jobs: await ctx.send( f"{ctx.author.mention} You've already got a job running - " @@ -392,7 +390,7 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() try: - response = await self.send_eval(ctx, code, args=args, format_func=format_func) + response = await self.send_eval(ctx, code, args=args, job_name=job_name) finally: del self.jobs[ctx.author.id] @@ -401,18 +399,6 @@ class Snekbox(Cog): break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") - async def format_timeit_output(self, output: str) -> tuple[str, str]: - """ - Parses the time from the end of the output given by timeit. - - If an error happened, then it won't contain the time and instead proceed with regular formatting. - """ - split_output = output.rstrip("\n").rsplit("\n", 1) - if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]): - return split_output[1], None - - return await self.format_output(output) - @command(name="eval", aliases=("e",)) @guild_only() @redirect_output( @@ -434,7 +420,7 @@ class Snekbox(Cog): issue with it! """ code = "\n".join(self.prepare_input(code)) - await self.run_eval(ctx, code, format_func=self.format_output) + await self.run_job("eval", ctx, code) @command(name="timeit", aliases=("ti",)) @guild_only() @@ -462,9 +448,7 @@ class Snekbox(Cog): codeblocks = self.prepare_input(code) code, args = self.prepare_timeit_input(codeblocks) - await self.run_eval( - ctx, code=code, format_func=self.format_timeit_output, args=args - ) + await self.run_job("timeit", ctx, code=code, args=args) def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 4245de8a3..339cdaaa4 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -72,13 +72,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), ('Your eval job has completed with return code 127', '') ) @@ -86,7 +86,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), ('Your eval job has completed with return code 127 (SIGTEST)', '') ) @@ -164,9 +164,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with( - ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output - ) + self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval') self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): @@ -182,7 +180,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) self.cog.send_eval.assert_called_with( - ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' ) self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) @@ -214,7 +212,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -227,7 +225,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') async def test_send_eval_with_paste_link(self): @@ -246,7 +244,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -257,7 +255,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_eval_with_non_zero_eval(self): @@ -275,7 +273,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -285,7 +283,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") -- cgit v1.2.3 From 61d652a32ce23373e67bb0e1cf985dd4ffc99a18 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Fri, 21 Jan 2022 21:30:41 +0000 Subject: Rename voice_ban type to voice_mute This commit changes all of the back-end so that it is in line with the new site API (see this PR https://github.com/python-discord/site/pull/608). This comes with no changes to commands, or functions definitions. --- bot/exts/moderation/infraction/_utils.py | 2 +- bot/exts/moderation/infraction/infractions.py | 20 ++++++++++---------- bot/exts/moderation/voice_gate.py | 4 ++-- .../exts/moderation/infraction/test_infractions.py | 8 ++++---- 4 files changed, 17 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index e683c9db4..4df833ffb 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -21,7 +21,7 @@ INFRACTION_ICONS = { "note": (Icons.user_warn, None), "superstar": (Icons.superstarify, Icons.unsuperstarify), "warning": (Icons.user_warn, None), - "voice_ban": (Icons.voice_state_red, Icons.voice_state_green), + "voice_mute": (Icons.voice_state_red, Icons.voice_state_green), } RULES_URL = "https://pythondiscord.com/pages/rules" diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index e495a94b3..72e09cbf4 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -27,7 +27,7 @@ class Infractions(InfractionScheduler, commands.Cog): category_description = "Server moderation tools." def __init__(self, bot: Bot): - super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning", "voice_ban"}) + super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning", "voice_mute"}) self.category = "Moderation" self._muted_role = discord.Object(constants.Roles.muted) @@ -273,7 +273,7 @@ class Infractions(InfractionScheduler, commands.Cog): @command(aliases=("uvban",)) async def unvoiceban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None: """Prematurely end the active voice ban infraction for the user.""" - await self.pardon_infraction(ctx, "voice_ban", user) + await self.pardon_infraction(ctx, "voice_mute", user) # endregion # region: Base apply functions @@ -397,10 +397,10 @@ class Infractions(InfractionScheduler, commands.Cog): @respect_role_hierarchy(member_arg=2) async def apply_voice_ban(self, ctx: Context, user: MemberOrUser, reason: t.Optional[str], **kwargs) -> None: """Apply a voice ban infraction with kwargs passed to `post_infraction`.""" - if await _utils.get_active_infraction(ctx, user, "voice_ban"): + if await _utils.get_active_infraction(ctx, user, "voice_mute"): return - infraction = await _utils.post_infraction(ctx, user, "voice_ban", reason, active=True, **kwargs) + infraction = await _utils.post_infraction(ctx, user, "voice_mute", reason, active=True, **kwargs) if infraction is None: return @@ -414,7 +414,7 @@ class Infractions(InfractionScheduler, commands.Cog): if not isinstance(user, Member): return - await user.move_to(None, reason="Disconnected from voice to apply voiceban.") + await user.move_to(None, reason="Disconnected from voice to apply voice mute.") await user.remove_roles(self._voice_verified_role, reason=reason) await self.apply_infraction(ctx, infraction, user, action()) @@ -487,9 +487,9 @@ class Infractions(InfractionScheduler, commands.Cog): # DM user about infraction expiration notified = await _utils.notify_pardon( user=user, - title="Voice ban ended", - content="You have been unbanned and can verify yourself again in the server.", - icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] + title="Voice mute ended", + content="You have been unmuted and can verify yourself again in the server.", + icon_url=_utils.INFRACTION_ICONS["voice_mute"][1] ) log_text["DM"] = "Sent" if notified else "**Failed**" @@ -514,8 +514,8 @@ class Infractions(InfractionScheduler, commands.Cog): return await self.pardon_mute(user_id, guild, reason, notify=notify) elif infraction["type"] == "ban": return await self.pardon_ban(user_id, guild, reason) - elif infraction["type"] == "voice_ban": - return await self.pardon_voice_ban(user_id, guild, notify=notify) + elif infraction["type"] == "voice_mute": + return await self.pardon_voice_mute(user_id, guild, notify=notify) # endregion diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index a382b13d1..42505b8e7 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -30,7 +30,7 @@ FAILED_MESSAGE = ( MESSAGE_FIELD_MAP = { "joined_at": f"have been on the server for less than {GateConf.minimum_days_member} days", - "voice_banned": "have an active voice ban infraction", + "voice_muted": "have an active voice mute infraction", "total_messages": f"have sent less than {GateConf.minimum_messages} messages", "activity_blocks": f"have been active for fewer than {GateConf.minimum_activity_blocks} ten-minute blocks", } @@ -170,7 +170,7 @@ class VoiceGate(Cog): ctx.author.joined_at > arrow.utcnow() - timedelta(days=GateConf.minimum_days_member) ), "total_messages": data["total_messages"] < GateConf.minimum_messages, - "voice_banned": data["voice_banned"], + "voice_muted": data["voice_muted"], "activity_blocks": data["activity_blocks"] < GateConf.minimum_activity_blocks, } diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 4d01e18a5..a796fd049 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -89,7 +89,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): """Should call infraction pardoning function.""" self.cog.pardon_infraction = AsyncMock() self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) - self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") @@ -97,7 +97,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): """Should return early when user already have Voice Ban infraction.""" get_active_infraction.return_value = {"foo": "bar"} self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) - get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") + get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_mute") post_infraction_mock.assert_not_awaited() @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @@ -120,7 +120,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): post_infraction_mock.return_value = None self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) post_infraction_mock.assert_awaited_once_with( - self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 + self.ctx, self.user, "voice_mute", "foobar", active=True, my_kwarg=23 ) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @@ -187,7 +187,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): user = MockUser() await self.cog.voiceban(self.cog, self.ctx, user, reason=None) - post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_ban", None, active=True, expires_at=None) + post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY) # Test action -- cgit v1.2.3 From 32d77fa9839eb9d373106700dcc4927851d94635 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Fri, 21 Jan 2022 21:34:41 +0000 Subject: Refactor voice_ban function definitions to voice_mute This changes all functions that reference voice_ban to voice_mute instead, which comes with breaking front-end changes. These front end changes are desirable, so that moderators get used to use voice_mute now, rather than voice_ban, in preparation for when we roll out real voice_bans. --- bot/exts/moderation/infraction/infractions.py | 42 ++++++------ .../exts/moderation/infraction/test_infractions.py | 78 +++++++++++----------- 2 files changed, 60 insertions(+), 60 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 72e09cbf4..d6580bc14 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -107,8 +107,8 @@ class Infractions(InfractionScheduler, commands.Cog): """ await self.apply_ban(ctx, user, reason, 1, expires_at=duration) - @command(aliases=('vban',)) - async def voiceban( + @command(aliases=("vmute",)) + async def voicemute( self, ctx: Context, user: UnambiguousMemberOrUser, @@ -117,11 +117,11 @@ class Infractions(InfractionScheduler, commands.Cog): reason: t.Optional[str] ) -> None: """ - Permanently ban user from using voice channels. + Permanently mute user in voice channels. - If duration is specified, it temporarily voice bans that user for the given duration. + If duration is specified, it temporarily voice mutes that user for the given duration. """ - await self.apply_voice_ban(ctx, user, reason, expires_at=duration) + await self.apply_voice_mute(ctx, user, reason, expires_at=duration) # endregion # region: Temporary infractions @@ -185,17 +185,17 @@ class Infractions(InfractionScheduler, commands.Cog): """ await self.apply_ban(ctx, user, reason, expires_at=duration) - @command(aliases=("tempvban", "tvban")) - async def tempvoiceban( - self, - ctx: Context, - user: UnambiguousMemberOrUser, - duration: Expiry, - *, - reason: t.Optional[str] + @command(aliases=("tempvmute", "tvmute")) + async def tempvoicemute( + self, + ctx: Context, + user: UnambiguousMemberOrUser, + duration: Expiry, + *, + reason: t.Optional[str] ) -> None: """ - Temporarily voice ban a user for the given reason and duration. + Temporarily voice mute a user for the given reason and duration. A unit of time should be appended to the duration. Units (∗case-sensitive): @@ -209,7 +209,7 @@ class Infractions(InfractionScheduler, commands.Cog): Alternatively, an ISO 8601 timestamp can be provided for the duration. """ - await self.apply_voice_ban(ctx, user, reason, expires_at=duration) + await self.apply_voice_mute(ctx, user, reason, expires_at=duration) # endregion # region: Permanent shadow infractions @@ -270,9 +270,9 @@ class Infractions(InfractionScheduler, commands.Cog): """Prematurely end the active ban infraction for the user.""" await self.pardon_infraction(ctx, "ban", user) - @command(aliases=("uvban",)) - async def unvoiceban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None: - """Prematurely end the active voice ban infraction for the user.""" + @command(aliases=("uvmute",)) + async def unvoicemute(self, ctx: Context, user: UnambiguousMemberOrUser) -> None: + """Prematurely end the active voice mute infraction for the user.""" await self.pardon_infraction(ctx, "voice_mute", user) # endregion @@ -395,8 +395,8 @@ class Infractions(InfractionScheduler, commands.Cog): await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) @respect_role_hierarchy(member_arg=2) - async def apply_voice_ban(self, ctx: Context, user: MemberOrUser, reason: t.Optional[str], **kwargs) -> None: - """Apply a voice ban infraction with kwargs passed to `post_infraction`.""" + async def apply_voice_mute(self, ctx: Context, user: MemberOrUser, reason: t.Optional[str], **kwargs) -> None: + """Apply a voice mute infraction with kwargs passed to `post_infraction`.""" if await _utils.get_active_infraction(ctx, user, "voice_mute"): return @@ -471,7 +471,7 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def pardon_voice_ban( + async def pardon_voice_mute( self, user_id: int, guild: discord.Guild, diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index a796fd049..f89465f84 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -62,8 +62,8 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) -class VoiceBanTests(unittest.IsolatedAsyncioTestCase): - """Tests for voice ban related functions and commands.""" +class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): + """Tests for voice mute related functions and commands.""" def setUp(self): self.bot = MockBot() @@ -73,59 +73,59 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(bot=self.bot, author=self.mod) self.cog = Infractions(self.bot) - async def test_permanent_voice_ban(self): - """Should call voice ban applying function without expiry.""" - self.cog.apply_voice_ban = AsyncMock() - self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) - self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) + async def test_permanent_voice_mute(self): + """Should call voice mute applying function without expiry.""" + self.cog.apply_voice_mute = AsyncMock() + self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) - async def test_temporary_voice_ban(self): - """Should call voice ban applying function with expiry.""" - self.cog.apply_voice_ban = AsyncMock() - self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) - self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + async def test_temporary_voice_mute(self): + """Should call voice mute applying function with expiry.""" + self.cog.apply_voice_mute = AsyncMock() + self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") - async def test_voice_unban(self): + async def test_voice_unmute(self): """Should call infraction pardoning function.""" self.cog.pardon_infraction = AsyncMock() - self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) + self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user)) self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): - """Should return early when user already have Voice Ban infraction.""" + async def test_voice_mute_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): + """Should return early when user already have Voice Mute infraction.""" get_active_infraction.return_value = {"foo": "bar"} - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar")) get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_mute") post_infraction_mock.assert_not_awaited() @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): + async def test_voice_mute_infraction_post_failed(self, get_active_infraction, post_infraction_mock): """Should return early when posting infraction fails.""" self.cog.mod_log.ignore = MagicMock() get_active_infraction.return_value = None post_infraction_mock.return_value = None - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar")) post_infraction_mock.assert_awaited_once() self.cog.mod_log.ignore.assert_not_called() @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): - """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" + async def test_voice_mute_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): + """Should pass all kwargs passed to apply_voice_mute to post_infraction.""" get_active_infraction.return_value = None # We don't want that this continue yet post_infraction_mock.return_value = None - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar", my_kwarg=23)) post_infraction_mock.assert_awaited_once_with( self.ctx, self.user, "voice_mute", "foobar", active=True, my_kwarg=23 ) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): + async def test_voice_mute_mod_log_ignore(self, get_active_infraction, post_infraction_mock): """Should ignore Voice Verified role removing.""" self.cog.mod_log.ignore = MagicMock() self.cog.apply_infraction = AsyncMock() @@ -134,11 +134,11 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): get_active_infraction.return_value = None post_infraction_mock.return_value = {"foo": "bar"} - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar")) self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) async def action_tester(self, action, reason: str) -> None: - """Helper method to test voice ban action.""" + """Helper method to test voice mute action.""" self.assertTrue(inspect.iscoroutine(action)) await action @@ -147,7 +147,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): + async def test_voice_mute_apply_infraction(self, get_active_infraction, post_infraction_mock): """Should ignore Voice Verified role removing.""" self.cog.mod_log.ignore = MagicMock() self.cog.apply_infraction = AsyncMock() @@ -156,22 +156,22 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): post_infraction_mock.return_value = {"foo": "bar"} reason = "foobar" - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, reason)) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, reason)) self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, ANY) await self.action_tester(self.cog.apply_infraction.call_args[0][-1], reason) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): - """Should truncate reason for voice ban.""" + async def test_voice_mute_truncate_reason(self, get_active_infraction, post_infraction_mock): + """Should truncate reason for voice mute.""" self.cog.mod_log.ignore = MagicMock() self.cog.apply_infraction = AsyncMock() get_active_infraction.return_value = None post_infraction_mock.return_value = {"foo": "bar"} - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar" * 3000)) self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, ANY) # Test action @@ -180,13 +180,13 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): @autospec(_utils, "post_infraction", "get_active_infraction", return_value=None) @autospec(Infractions, "apply_infraction") - async def test_voice_ban_user_left_guild(self, apply_infraction_mock, post_infraction_mock, _): - """Should voice ban user that left the guild without throwing an error.""" + async def test_voice_mute_user_left_guild(self, apply_infraction_mock, post_infraction_mock, _): + """Should voice mute user that left the guild without throwing an error.""" infraction = {"foo": "bar"} post_infraction_mock.return_value = {"foo": "bar"} user = MockUser() - await self.cog.voiceban(self.cog, self.ctx, user, reason=None) + await self.cog.voicemute(self.cog, self.ctx, user, reason=None) post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY) @@ -195,22 +195,22 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(inspect.iscoroutine(action)) await action - async def test_voice_unban_user_not_found(self): + async def test_voice_unmute_user_not_found(self): """Should include info to return dict when user was not found from guild.""" self.guild.get_member.return_value = None self.guild.fetch_member.side_effect = NotFound(Mock(status=404), "Not found") - result = await self.cog.pardon_voice_ban(self.user.id, self.guild) + result = await self.cog.pardon_voice_mute(self.user.id, self.guild) self.assertEqual(result, {"Info": "User was not found in the guild."}) @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @patch("bot.exts.moderation.infraction.infractions.format_user") - async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): + async def test_voice_unmute_user_found(self, format_user_mock, notify_pardon_mock): """Should add role back with ignoring, notify user and return log dictionary..""" self.guild.get_member.return_value = self.user notify_pardon_mock.return_value = True format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild) + result = await self.cog.pardon_voice_mute(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "Sent" @@ -219,13 +219,13 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @patch("bot.exts.moderation.infraction.infractions.format_user") - async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): + async def test_voice_unmute_dm_fail(self, format_user_mock, notify_pardon_mock): """Should add role back with ignoring, notify user and return log dictionary..""" self.guild.get_member.return_value = self.user notify_pardon_mock.return_value = False format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild) + result = await self.cog.pardon_voice_mute(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "**Failed**" -- cgit v1.2.3 From d0cf7f2f1573e883cf1f6aaf8c54b3701496722b Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Wed, 26 Jan 2022 16:12:15 -0500 Subject: chore: Remove the naming of 'eval' in certain places Since the !eval command is no longer the only snekbox command, make the naming more generic. --- bot/exts/filters/filtering.py | 4 +- bot/exts/utils/snekbox.py | 72 ++++++++++++++++++------------------ tests/bot/exts/utils/test_snekbox.py | 6 +-- 3 files changed, 41 insertions(+), 41 deletions(-) (limited to 'tests') diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index ad904d147..e49cf4f82 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -267,9 +267,9 @@ class Filtering(Cog): # Update time when alert sent await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - async def filter_eval(self, result: str, msg: Message) -> bool: + async def filter_snekbox_job(self, result: str, msg: Message) -> bool: """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. Also requires the original message, to check whether to filter and for mod logs. Returns whether a filter was triggered or not. diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index e7712eee5..86993b7f1 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -71,15 +71,15 @@ if not hasattr(sys, "_setup_finished"): MAX_PASTE_LEN = 10000 -# `!eval` command whitelists and blacklists. -NO_EVAL_CHANNELS = (Channels.python_general,) -NO_EVAL_CATEGORIES = () -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) +# The Snekbox commands' whitelists and blacklists. +NO_SNEKBOX_CHANNELS = (Channels.python_general,) +NO_SNEKBOX_CATEGORIES = () +SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) SIGKILL = 9 -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 +REDO_EMOJI = '\U0001f501' # :repeat: +REDO_TIMEOUT = 30 class Snekbox(Cog): @@ -89,7 +89,7 @@ class Snekbox(Cog): self.bot = bot self.jobs = {} - async def post_eval(self, code: str, *, args: Optional[list[str]] = None) -> dict: + async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: """Send a POST request to the Snekbox API to evaluate code and return the results.""" url = URLs.snekbox_eval_api data = {"input": code} @@ -101,7 +101,7 @@ class Snekbox(Cog): return await resp.json() async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" + """Upload the job's output to a paste service and return a URL to it if successful.""" log.trace("Uploading full output to paste service...") if len(output) > MAX_PASTE_LEN: @@ -241,7 +241,7 @@ class Snekbox(Cog): return output, paste_link - async def send_eval( + async def send_job( self, ctx: Context, code: str, @@ -255,7 +255,7 @@ class Snekbox(Cog): Return the bot response. """ async with ctx.typing(): - results = await self.post_eval(code, args=args) + results = await self.post_job(code, args=args) msg, error = self.get_results_message(results, job_name) if error: @@ -269,7 +269,7 @@ class Snekbox(Cog): if paste_link: msg = f"{msg}\nFull output: {paste_link}" - # Collect stats of eval fails + successes + # Collect stats of job fails + successes if icon == ":x:": self.bot.stats.incr("snekbox.python.fail") else: @@ -278,7 +278,7 @@ class Snekbox(Cog): filter_cog = self.bot.get_cog("Filtering") filter_triggered = False if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + filter_triggered = await filter_cog.filter_snekbox_job(msg, ctx.message) if filter_triggered: response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: @@ -289,26 +289,26 @@ class Snekbox(Cog): log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") return response - async def continue_eval( + async def continue_job( self, ctx: Context, response: Message, command: Command ) -> tuple[Optional[str], Optional[list[str]]]: """ - Check if the eval session should continue. + Check if the job's session should continue. If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command. - Otherwise return (None, None) if the eval session should be terminated. + Otherwise return (None, None) if the job's session should be terminated. """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + _predicate_message_edit = partial(predicate_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) with contextlib.suppress(NotFound): try: _, new_message = await self.bot.wait_for( 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT + check=_predicate_message_edit, + timeout=REDO_TIMEOUT ) - await ctx.message.add_reaction(REEVAL_EMOJI) + await ctx.message.add_reaction(REDO_EMOJI) await self.bot.wait_for( 'reaction_add', check=_predicate_emoji_reaction, @@ -316,7 +316,7 @@ class Snekbox(Cog): ) code = await self.get_code(new_message, ctx.command) - await ctx.message.clear_reaction(REEVAL_EMOJI) + await ctx.message.clear_reaction(REDO_EMOJI) with contextlib.suppress(HTTPException): await response.delete() @@ -324,7 +324,7 @@ class Snekbox(Cog): return None, None except asyncio.TimeoutError: - await ctx.message.clear_reaction(REEVAL_EMOJI) + await ctx.message.clear_reaction(REDO_EMOJI) return None, None codeblocks = self.prepare_input(code) @@ -347,11 +347,11 @@ class Snekbox(Cog): new_ctx = await self.bot.get_context(message) if new_ctx.command is command: - log.trace(f"Message {message.id} invokes eval command.") + log.trace(f"Message {message.id} invokes {command} command.") split = message.content.split(maxsplit=1) code = split[1] if len(split) > 1 else None else: - log.trace(f"Message {message.id} does not invoke eval command.") + log.trace(f"Message {message.id} does not invoke {command} command.") code = message.content return code @@ -389,11 +389,11 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() try: - response = await self.send_eval(ctx, code, args=args, job_name=job_name) + response = await self.send_job(ctx, code, args=args, job_name=job_name) finally: del self.jobs[ctx.author.id] - code, args = await self.continue_eval(ctx, response, ctx.command) + code, args = await self.continue_job(ctx, response, ctx.command) if not code: break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") @@ -402,9 +402,9 @@ class Snekbox(Cog): @guild_only() @redirect_output( destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, ping_user=False ) async def eval_command(self, ctx: Context, *, code: str) -> None: @@ -425,9 +425,9 @@ class Snekbox(Cog): @guild_only() @redirect_output( destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, ping_user=False ) async def timeit_command(self, ctx: Context, *, code: str) -> str: @@ -450,14 +450,14 @@ class Snekbox(Cog): await self.run_job("timeit", ctx, code=code, args=args) -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: +def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: """Return True if the edited message is the context message and the content was indeed modified.""" return new_msg.id == ctx.message.id and old_msg.content != new_msg.content -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI +def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 339cdaaa4..5d213a883 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -209,7 +209,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') @@ -241,7 +241,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') @@ -270,7 +270,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') -- cgit v1.2.3 From 85a6f430aa68f59ce6958ecb6450eca0736628e4 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Thu, 27 Jan 2022 10:31:12 -0500 Subject: chore: Switch Snekbox.prepare_input with a CodeblockConverter As per @Numerlor's suggestion --- bot/exts/filters/filtering.py | 2 +- bot/exts/utils/snekbox.py | 74 ++++++++++++------------ tests/bot/exts/utils/test_snekbox.py | 105 +++++++++++++++++------------------ 3 files changed, 91 insertions(+), 90 deletions(-) (limited to 'tests') diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 599302576..375e9dca8 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -267,7 +267,7 @@ class Filtering(Cog): # Update time when alert sent await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - async def filter_snekbox_job(self, result: str, msg: Message) -> bool: + async def filter_snekbox_output(self, result: str, msg: Message) -> bool: """ Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 15599208f..41f6bf8ad 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -9,7 +9,7 @@ from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Command, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs @@ -68,35 +68,11 @@ REDO_EMOJI = '\U0001f501' # :repeat: REDO_TIMEOUT = 30 -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - - if args is not None: - data["args"] = args +class CodeblockConverter(Converter): + """Attempts to extract code from a codeblock, if provided.""" - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the job's output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - return await send_to_paste_service(output, extension="txt") - - @staticmethod - def prepare_input(code: str) -> list[str]: + @classmethod + async def convert(cls, ctx: Context, code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. @@ -128,6 +104,34 @@ class Snekbox(Cog): log.trace(f"Extracted {info} for evaluation:\n{code}") return codeblocks + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + + if args is not None: + data["args"] = args + + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the job's output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + return await send_to_paste_service(output, extension="txt") + @staticmethod def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: """ @@ -313,7 +317,7 @@ class Snekbox(Cog): await ctx.message.clear_reaction(REDO_EMOJI) return None, None - codeblocks = self.prepare_input(code) + codeblocks = await CodeblockConverter.convert(ctx, code) if command is self.timeit_command: return self.prepare_timeit_input(codeblocks) @@ -393,7 +397,7 @@ class Snekbox(Cog): channels=NO_SNEKBOX_CHANNELS, ping_user=False ) - async def eval_command(self, ctx: Context, *, code: str) -> None: + async def eval_command(self, ctx: Context, *, code: CodeblockConverter) -> None: """ Run Python code and get the results. @@ -404,8 +408,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = "\n".join(self.prepare_input(code)) - await self.run_job("eval", ctx, code) + await self.run_job("eval", ctx, "\n".join(code)) @command(name="timeit", aliases=("ti",)) @guild_only() @@ -416,7 +419,7 @@ class Snekbox(Cog): channels=NO_SNEKBOX_CHANNELS, ping_user=False ) - async def timeit_command(self, ctx: Context, *, code: str) -> str: + async def timeit_command(self, ctx: Context, *, code: CodeblockConverter) -> str: """ Profile Python Code to find execution time. @@ -430,8 +433,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - codeblocks = self.prepare_input(code) - code, args = self.prepare_timeit_input(codeblocks) + code, args = self.prepare_timeit_input(code) await self.run_job("timeit", ctx, code=code, args=args) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 5d213a883..75da0c860 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -17,7 +17,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = Snekbox(bot=self.bot) - async def test_post_eval(self): + async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() resp.json = AsyncMock(return_value="return") @@ -26,7 +26,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_eval("import random"), "return") + self.assertEqual(await self.cog.post_job("import random"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -45,7 +45,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.upload_output("Test output.") mock_paste_util.assert_called_once_with("Test output.", extension="txt") - def test_prepare_input(self): + async def test_codeblock_converter(self): + ctx = MockContext() cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), @@ -61,7 +62,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) + self.assertEqual( + '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + ) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -158,31 +161,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): response = MockMessage() ctx.command = MagicMock() - self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=(None, None)) + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval') - self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) + await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() ctx.command = MagicMock() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock() + self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with( + await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + self.cog.send_job.assert_called_with( ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' ) - self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) + self.cog.continue_job.assert_called_with(ctx, response, ctx.command) async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -196,14 +195,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_send_eval(self): - """Test the send_eval function.""" + async def test_send_job(self): + """Test the send_job function.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) @@ -212,7 +211,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -223,19 +222,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') - async def test_send_eval_with_paste_link(self): - """Test the send_eval function with a too long output that generate a paste link.""" + async def test_send_job_with_paste_link(self): + """Test the send_job function with a too long output that generate a paste link.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) @@ -244,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -253,18 +252,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') - async def test_send_eval_with_non_zero_eval(self): - """Test the send_eval function with a code returning a non-zero code.""" + async def test_send_job_with_non_zero_eval(self): + """Test the send_job function with a code returning a non-zero code.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called @@ -273,7 +272,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -281,14 +280,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") - async def test_continue_eval_does_continue(self, partial_mock): - """Test that the continue_eval function does continue if required conditions are met.""" + async def test_continue_job_does_continue(self, partial_mock): + """Test that the continue_job function does continue if required conditions are met.""" ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) response = MockMessage(delete=AsyncMock()) new_msg = MockMessage() @@ -296,30 +295,30 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command) + actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( 'message_edit', - check=partial_mock(snekbox.predicate_eval_message_edit, ctx), - timeout=snekbox.REEVAL_TIMEOUT, + check=partial_mock(snekbox.predicate_message_edit, ctx), + timeout=snekbox.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) response.delete.assert_called_once() - async def test_continue_eval_does_not_continue(self): + async def test_continue_job_does_not_continue(self): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command) + actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) self.assertEqual(actual, (None, None)) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -347,8 +346,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) - def test_predicate_eval_message_edit(self): - """Test the predicate_eval_message_edit function.""" + def test_predicate_message_edit(self): + """Test the predicate_message_edit function.""" msg0 = MockMessage(id=1, content='abc') msg1 = MockMessage(id=2, content='abcdef') msg2 = MockMessage(id=1, content='abcdef') @@ -361,18 +360,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) - def test_predicate_eval_emoji_reaction(self): - """Test the predicate_eval_emoji_reaction function.""" + def test_predicate_emoji_reaction(self): + """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_reaction.__str__.return_value = snekbox.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -385,7 +384,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) -- cgit v1.2.3 From b689f059e45e68f75e3a97277ac3bf58263fa769 Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Fri, 4 Feb 2022 20:54:49 -0500 Subject: fix: Use filter_snekbox_output rather than job --- bot/exts/utils/snekbox.py | 2 +- tests/bot/exts/utils/test_snekbox.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index ce3dd7c24..a932b96ff 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -268,7 +268,7 @@ class Snekbox(Cog): filter_cog = self.bot.get_cog("Filtering") filter_triggered = False if filter_cog: - filter_triggered = await filter_cog.filter_snekbox_job(msg, ctx.message) + filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) if filter_triggered: response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 75da0c860..2eaed0446 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -208,7 +208,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') @@ -240,7 +240,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') @@ -269,7 +269,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') -- cgit v1.2.3 From 7ef9b2163e7f0da5a7abd17decac2b0b2d53defb Mon Sep 17 00:00:00 2001 From: ToxicKidz Date: Sat, 5 Feb 2022 23:21:21 -0500 Subject: tests: Add a test for timeit command codeblock preparation --- tests/bot/exts/utils/test_snekbox.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 2eaed0446..f68a20089 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -66,6 +66,21 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected ) + def test_prepare_timeit_input(self): + """Test the prepare_timeit_input codeblock detection.""" + base_args = ('-m', 'timeit', '-s') + cases = ( + (['print("Hello World")'], '', 'single block of code'), + (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), + (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') + ) + + for case, setup_code, testname in cases: + setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + self.assertEqual(self.cog.prepare_timeit_input(case), expected) + def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( -- cgit v1.2.3 From 993529aa945a1f9ec8d769c770399dbe2cd8bd25 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sun, 2 Jan 2022 20:13:53 +0000 Subject: Add tests for new CleanBan and Clean functionality --- .../exts/moderation/infraction/test_infractions.py | 90 +++++++++++++++++- tests/bot/exts/moderation/test_clean.py | 104 +++++++++++++++++++++ 2 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 tests/bot/exts/moderation/test_clean.py (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index f89465f84..57235ec6d 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,13 +1,15 @@ import inspect import textwrap import unittest -from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch +from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch from discord.errors import NotFound from bot.constants import Event +from bot.exts.moderation.clean import Clean from bot.exts.moderation.infraction import _utils from bot.exts.moderation.infraction.infractions import Infractions +from bot.exts.moderation.infraction.management import ModManagement from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockUser, autospec @@ -231,3 +233,89 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): "DM": "**Failed**" }) notify_pardon_mock.assert_awaited_once() + + +class CleanBanTests(unittest.IsolatedAsyncioTestCase): + """Tests for cleanban functionality.""" + + def setUp(self): + self.bot = MockBot() + self.mod = MockMember(roles=[MockRole(id=7890123, position=10)]) + self.user = MockMember(roles=[MockRole(id=123456, position=1)]) + self.guild = MockGuild() + self.ctx = MockContext(bot=self.bot, author=self.mod) + self.cog = Infractions(self.bot) + self.clean_cog = Clean(self.bot) + self.management_cog = ModManagement(self.bot) + + self.cog.apply_ban = AsyncMock(return_value={"id": 42}) + self.log_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" + self.clean_cog._clean_messages = AsyncMock(return_value=self.log_url) + + def mock_get_cog(self, enable_clean, enable_manage): + def inner(name): + if name == "ModManagement": + return self.management_cog if enable_manage else None + elif name == "Clean": + return self.clean_cog if enable_clean else None + else: + return DEFAULT + return inner + + async def test_cleanban_falls_back_to_native_purge_without_clean_cog(self): + """Should fallback to native purge if the Clean cog is not available.""" + self.bot.get_cog.side_effect = self.mock_get_cog(False, False) + + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + self.cog.apply_ban.assert_awaited_once_with( + self.ctx, + self.user, + "FooBar", + 1, + expires_at=None, + ) + + async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): + """Cleanban command should use the native purge messages if the clean cog is available.""" + self.bot.get_cog.side_effect = self.mock_get_cog(True, False) + + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + self.cog.apply_ban.assert_awaited_once_with( + self.ctx, + self.user, + "FooBar", + expires_at=None, + ) + + @patch("bot.exts.moderation.infraction.infractions.Age") + async def test_cleanban_uses_clean_cog_when_available(self, mocked_age_converter): + """Test cleanban uses the clean cog to clean messages if it's available.""" + self.bot.api_client.patch = AsyncMock() + self.bot.get_cog.side_effect = self.mock_get_cog(True, False) + + mocked_age_converter.return_value.convert = AsyncMock(return_value="81M") + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + + self.clean_cog._clean_messages.assert_awaited_once_with( + self.ctx, + users=[self.user], + channels="*", + first_limit="81M", + attempt_delete_invocation=False, + ) + + @patch("bot.exts.moderation.infraction.infractions.Infraction") + async def test_cleanban_edits_infraction_reason(self, mocked_infraction_converter): + """Ensure cleanban edits the ban reason with a link to the clean log.""" + self.bot.get_cog.side_effect = self.mock_get_cog(True, True) + + self.management_cog.infraction_append = AsyncMock() + mocked_infraction_converter.return_value.convert = AsyncMock(return_value=42) + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + + self.management_cog.infraction_append.assert_awaited_once_with( + self.ctx, + 42, + None, + reason=f"[Clean log]({self.log_url})" + ) diff --git a/tests/bot/exts/moderation/test_clean.py b/tests/bot/exts/moderation/test_clean.py new file mode 100644 index 000000000..83489ea00 --- /dev/null +++ b/tests/bot/exts/moderation/test_clean.py @@ -0,0 +1,104 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from bot.exts.moderation.clean import Clean +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockMessage, MockRole, MockTextChannel + + +class CleanTests(unittest.IsolatedAsyncioTestCase): + """Tests for clean cog functionality.""" + + def setUp(self): + self.bot = MockBot() + self.mod = MockMember(roles=[MockRole(id=7890123, position=10)]) + self.user = MockMember(roles=[MockRole(id=123456, position=1)]) + self.guild = MockGuild() + self.ctx = MockContext(bot=self.bot, author=self.mod) + self.cog = Clean(self.bot) + + self.log_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" + self.cog._modlog_cleaned_messages = AsyncMock(return_value=self.log_url) + + self.cog._use_cache = MagicMock(return_value=True) + self.cog._delete_found = AsyncMock(return_value=[42, 84]) + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_deletes_invocation_in_non_mod_channel(self, mod_channel_check): + """Clean command should delete the invocation message if ran in a non mod channel.""" + mod_channel_check.return_value = False + self.ctx.message.delete = AsyncMock() + + self.assertIsNone(await self.cog._delete_invocation(self.ctx)) + + self.ctx.message.delete.assert_awaited_once() + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_doesnt_delete_invocation_in_mod_channel(self, mod_channel_check): + """Clean command should not delete the invocation message if ran in a mod channel.""" + mod_channel_check.return_value = True + self.ctx.message.delete = AsyncMock() + + self.assertIsNone(await self.cog._delete_invocation(self.ctx)) + + self.ctx.message.delete.assert_not_awaited() + + async def test_clean_doesnt_attempt_deletion_when_attempt_delete_invocation_is_false(self): + """Clean command should not attempt to delete the invocation message if attempt_delete_invocation is false.""" + self.cog._delete_invocation = AsyncMock() + self.bot.get_channel = MagicMock(return_value=False) + + self.assertEqual( + await self.cog._clean_messages( + self.ctx, + None, + first_limit=MockMessage(), + attempt_delete_invocation=False, + ), + self.log_url, + ) + + self.cog._delete_invocation.assert_not_awaited() + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_replies_with_success_message_when_ran_in_mod_channel(self, mod_channel_check): + """Clean command should reply to the message with a confirmation message if invoked in a mod channel.""" + mod_channel_check.return_value = True + self.ctx.reply = AsyncMock() + + self.assertEqual( + await self.cog._clean_messages( + self.ctx, + None, + first_limit=MockMessage(), + attempt_delete_invocation=False, + ), + self.log_url, + ) + + self.ctx.reply.assert_awaited_once() + sent_message = self.ctx.reply.await_args[0][0] + self.assertIn(self.log_url, sent_message) + self.assertIn("2 messages", sent_message) + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_send_success_message__to_mods_when_ran_in_non_mod_channel(self, mod_channel_check): + """Clean command should send a confirmation message to #mods if invoked in a non-mod channel.""" + mod_channel_check.return_value = False + mocked_mods = MockTextChannel(id=1234567) + mocked_mods.send = AsyncMock() + self.bot.get_channel = MagicMock(return_value=mocked_mods) + + self.assertEqual( + await self.cog._clean_messages( + self.ctx, + None, + first_limit=MockMessage(), + attempt_delete_invocation=False, + ), + self.log_url, + ) + + mocked_mods.send.assert_awaited_once() + sent_message = mocked_mods.send.await_args[0][0] + self.assertIn(self.log_url, sent_message) + self.assertIn("2 messages", sent_message) -- cgit v1.2.3 From 6c139905cca53f7810a100435955ec0c5fbc30e1 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 14 Feb 2022 01:51:23 +0000 Subject: Send error when cleanban fails to ban Co-authored-by: GDWR --- bot/exts/moderation/infraction/infractions.py | 4 +++- tests/bot/exts/moderation/infraction/test_infractions.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 32ff376cf..09ee1a7b4 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -113,12 +113,14 @@ class Infractions(InfractionScheduler, commands.Cog): clean_cog: t.Optional[Clean] = self.bot.get_cog("Clean") if clean_cog is None: # If we can't get the clean cog, fall back to native purgeban. - await self.apply_ban(ctx, user, reason, 1, expires_at=duration) + await self.apply_ban(ctx, user, reason, purge_days=1, expires_at=duration) return infraction = await self.apply_ban(ctx, user, reason, expires_at=duration) if not infraction or not infraction.get("id"): # Ban was unsuccessful, quit early. + await ctx.send(":x: Failed to apply ban.") + log.error("Failed to apply ban to user %d", user.id) return # Calling commands directly skips Discord.py's convertors, so we need to convert args manually. diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 57235ec6d..8845fb382 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -271,7 +271,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): self.ctx, self.user, "FooBar", - 1, + purge_days=1, expires_at=None, ) -- cgit v1.2.3 From 762b107056145d44b5219a929302455c9e6ed1d0 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 14 Feb 2022 01:53:33 +0000 Subject: Typo and docstrings in clean ban tests Co-authored-by: GDWR --- tests/bot/exts/moderation/infraction/test_infractions.py | 1 + tests/bot/exts/moderation/test_clean.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 8845fb382..8bed1e386 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -253,6 +253,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): self.clean_cog._clean_messages = AsyncMock(return_value=self.log_url) def mock_get_cog(self, enable_clean, enable_manage): + """Mock get cog factory that allows the user to specify whether clean and manage cogs are enabled.""" def inner(name): if name == "ModManagement": return self.management_cog if enable_manage else None diff --git a/tests/bot/exts/moderation/test_clean.py b/tests/bot/exts/moderation/test_clean.py index 83489ea00..d7647fa48 100644 --- a/tests/bot/exts/moderation/test_clean.py +++ b/tests/bot/exts/moderation/test_clean.py @@ -81,7 +81,7 @@ class CleanTests(unittest.IsolatedAsyncioTestCase): self.assertIn("2 messages", sent_message) @patch("bot.exts.moderation.clean.is_mod_channel") - async def test_clean_send_success_message__to_mods_when_ran_in_non_mod_channel(self, mod_channel_check): + async def test_clean_send_success_message_to_mods_when_ran_in_non_mod_channel(self, mod_channel_check): """Clean command should send a confirmation message to #mods if invoked in a non-mod channel.""" mod_channel_check.return_value = False mocked_mods = MockTextChannel(id=1234567) -- cgit v1.2.3 From 7e8e95f07e343f1d7d9a8069b6cdb1a9fcbb00d7 Mon Sep 17 00:00:00 2001 From: ChrisJL Date: Wed, 16 Feb 2022 22:48:09 +0000 Subject: Remove unnecessary Infraction conversion in clean ban (#2092) --- bot/exts/moderation/infraction/infractions.py | 3 +-- tests/bot/exts/moderation/infraction/test_infractions.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 09ee1a7b4..af42ab1b8 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -9,7 +9,7 @@ from discord.ext.commands import Context, command from bot import constants from bot.bot import Bot from bot.constants import Event -from bot.converters import Age, Duration, Expiry, Infraction, MemberOrUser, UnambiguousMemberOrUser +from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser from bot.decorators import respect_role_hierarchy from bot.exts.moderation.infraction import _utils from bot.exts.moderation.infraction._scheduler import InfractionScheduler @@ -125,7 +125,6 @@ class Infractions(InfractionScheduler, commands.Cog): # Calling commands directly skips Discord.py's convertors, so we need to convert args manually. clean_time = await Age().convert(ctx, "1h") - infraction = await Infraction().convert(ctx, infraction["id"]) log_url = await clean_cog._clean_messages( ctx, diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 8bed1e386..052048053 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -305,18 +305,16 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): attempt_delete_invocation=False, ) - @patch("bot.exts.moderation.infraction.infractions.Infraction") - async def test_cleanban_edits_infraction_reason(self, mocked_infraction_converter): + async def test_cleanban_edits_infraction_reason(self): """Ensure cleanban edits the ban reason with a link to the clean log.""" self.bot.get_cog.side_effect = self.mock_get_cog(True, True) self.management_cog.infraction_append = AsyncMock() - mocked_infraction_converter.return_value.convert = AsyncMock(return_value=42) self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) self.management_cog.infraction_append.assert_awaited_once_with( self.ctx, - 42, + {"id": 42}, None, reason=f"[Clean log]({self.log_url})" ) -- cgit v1.2.3 From 960619c23300c56c8aaa454edc7241e2badf80ad Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 21 Feb 2022 02:14:07 +0000 Subject: Update all references of discord.py to disnake All of the tag content is out of scope for this PR. --- bot/__init__.py | 2 +- bot/bot.py | 20 ++-- bot/converters.py | 24 ++-- bot/decorators.py | 10 +- bot/errors.py | 2 +- bot/exts/backend/branding/_cog.py | 18 +-- bot/exts/backend/config_verifier.py | 2 +- bot/exts/backend/error_handler.py | 4 +- bot/exts/backend/logging.py | 4 +- bot/exts/backend/sync/_cog.py | 6 +- bot/exts/backend/sync/_syncers.py | 4 +- bot/exts/events/code_jams/_channels.py | 42 +++---- bot/exts/events/code_jams/_cog.py | 24 ++-- bot/exts/filters/antimalware.py | 4 +- bot/exts/filters/antispam.py | 8 +- bot/exts/filters/filter_lists.py | 4 +- bot/exts/filters/filtering.py | 26 ++--- bot/exts/filters/security.py | 2 +- bot/exts/filters/token_remover.py | 6 +- bot/exts/filters/webhook_remover.py | 4 +- bot/exts/fun/duck_pond.py | 22 ++-- bot/exts/fun/off_topic_names.py | 6 +- bot/exts/help_channels/_caches.py | 14 +-- bot/exts/help_channels/_channel.py | 18 +-- bot/exts/help_channels/_cog.py | 64 +++++------ bot/exts/help_channels/_message.py | 36 +++--- bot/exts/help_channels/_name.py | 6 +- bot/exts/info/code_snippets.py | 8 +- bot/exts/info/codeblock/_cog.py | 22 ++-- bot/exts/info/doc/_batch_parser.py | 4 +- bot/exts/info/doc/_cog.py | 18 +-- bot/exts/info/help.py | 4 +- bot/exts/info/information.py | 8 +- bot/exts/info/pep.py | 4 +- bot/exts/info/pypi.py | 6 +- bot/exts/info/python_news.py | 12 +- bot/exts/info/source.py | 4 +- bot/exts/info/stats.py | 6 +- bot/exts/info/subscribe.py | 26 ++--- bot/exts/info/tags.py | 16 +-- bot/exts/moderation/clean.py | 10 +- bot/exts/moderation/defcon.py | 6 +- bot/exts/moderation/dm_relay.py | 6 +- bot/exts/moderation/incidents.py | 100 ++++++++--------- bot/exts/moderation/infraction/_scheduler.py | 14 +-- bot/exts/moderation/infraction/_utils.py | 14 +-- bot/exts/moderation/infraction/infractions.py | 26 ++--- bot/exts/moderation/infraction/management.py | 36 +++--- bot/exts/moderation/infraction/superstarify.py | 6 +- bot/exts/moderation/metabase.py | 2 +- bot/exts/moderation/modlog.py | 72 ++++++------ bot/exts/moderation/modpings.py | 8 +- bot/exts/moderation/silence.py | 8 +- bot/exts/moderation/slowmode.py | 4 +- bot/exts/moderation/stream.py | 24 ++-- bot/exts/moderation/verification.py | 16 +-- bot/exts/moderation/voice_gate.py | 42 +++---- bot/exts/moderation/watchchannels/_watchchannel.py | 12 +- bot/exts/moderation/watchchannels/bigbrother.py | 4 +- bot/exts/recruitment/talentpool/_cog.py | 8 +- bot/exts/recruitment/talentpool/_review.py | 4 +- bot/exts/utils/bot.py | 4 +- bot/exts/utils/extensions.py | 6 +- bot/exts/utils/internal.py | 17 +-- bot/exts/utils/ping.py | 4 +- bot/exts/utils/reminders.py | 32 +++--- bot/exts/utils/snekbox.py | 4 +- bot/exts/utils/thread_bumper.py | 24 ++-- bot/exts/utils/utils.py | 6 +- bot/log.py | 2 +- bot/monkey_patches.py | 6 +- bot/pagination.py | 18 +-- bot/rules/attachments.py | 2 +- bot/rules/burst.py | 2 +- bot/rules/burst_shared.py | 2 +- bot/rules/chars.py | 2 +- bot/rules/discord_emojis.py | 2 +- bot/rules/duplicates.py | 2 +- bot/rules/links.py | 2 +- bot/rules/mentions.py | 2 +- bot/rules/newlines.py | 2 +- bot/rules/role_mentions.py | 2 +- bot/utils/channel.py | 16 +-- bot/utils/checks.py | 4 +- bot/utils/function.py | 6 +- bot/utils/helpers.py | 2 +- bot/utils/members.py | 18 +-- bot/utils/message_cache.py | 2 +- bot/utils/messages.py | 48 ++++---- bot/utils/webhooks.py | 10 +- poetry.lock | 12 +- pyproject.toml | 2 +- tests/README.md | 12 +- tests/base.py | 8 +- tests/bot/exts/backend/sync/test_cog.py | 6 +- tests/bot/exts/backend/sync/test_roles.py | 6 +- tests/bot/exts/backend/sync/test_users.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 2 +- tests/bot/exts/events/test_code_jams.py | 4 +- tests/bot/exts/filters/test_antimalware.py | 2 +- tests/bot/exts/filters/test_security.py | 2 +- tests/bot/exts/filters/test_token_remover.py | 2 +- tests/bot/exts/info/test_information.py | 22 ++-- .../exts/moderation/infraction/test_infractions.py | 2 +- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- tests/bot/exts/moderation/test_incidents.py | 20 ++-- tests/bot/exts/moderation/test_modlog.py | 4 +- tests/bot/exts/moderation/test_silence.py | 4 +- tests/bot/exts/test_cogs.py | 4 +- tests/bot/exts/utils/test_snekbox.py | 4 +- tests/bot/test_converters.py | 2 +- tests/bot/utils/test_checks.py | 2 +- tests/helpers.py | 124 ++++++++++----------- tests/test_helpers.py | 28 ++--- 114 files changed, 735 insertions(+), 734 deletions(-) (limited to 'tests') diff --git a/bot/__init__.py b/bot/__init__.py index 17d99105a..b28513bff 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -3,7 +3,7 @@ import os from functools import partial, partialmethod from typing import TYPE_CHECKING -from discord.ext import commands +from disnake.ext import commands from bot import log, monkey_patches diff --git a/bot/bot.py b/bot/bot.py index 94783a466..2769b7dda 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -6,9 +6,9 @@ from contextlib import suppress from typing import Dict, List, Optional import aiohttp -import discord +import disnake from async_rediscache import RedisSession -from discord.ext import commands +from disnake.ext import commands from sentry_sdk import push_scope from bot import api, constants @@ -28,7 +28,7 @@ class StartupError(Exception): class Bot(commands.Bot): - """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" + """A subclass of `disnake.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, redis_session: RedisSession, **kwargs): if "connector" in kwargs: @@ -109,9 +109,9 @@ class Bot(commands.Bot): def create(cls) -> "Bot": """Create and return an instance of a Bot.""" loop = asyncio.get_event_loop() - allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}) + allowed_roles = list({disnake.Object(id_) for id_ in constants.MODERATION_ROLES}) - intents = discord.Intents.all() + intents = disnake.Intents.all() intents.presences = False intents.dm_typing = False intents.dm_reactions = False @@ -123,10 +123,10 @@ class Bot(commands.Bot): redis_session=_create_redis_session(loop), loop=loop, command_prefix=commands.when_mentioned_or(constants.Bot.prefix), - activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), + activity=disnake.Game(name=f"Commands: {constants.Bot.prefix}help"), case_insensitive=True, max_messages=10_000, - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), + allowed_mentions=disnake.AllowedMentions(everyone=False, roles=allowed_roles), intents=intents, ) @@ -258,7 +258,7 @@ class Bot(commands.Bot): await self.stats.create_socket() await super().login(*args, **kwargs) - async def on_guild_available(self, guild: discord.Guild) -> None: + async def on_guild_available(self, guild: disnake.Guild) -> None: """ Set the internal guild available event when constants.Guild.id becomes available. @@ -274,7 +274,7 @@ class Bot(commands.Bot): try: webhook = await self.fetch_webhook(constants.Webhooks.dev_log) - except discord.HTTPException as e: + except disnake.HTTPException as e: log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") else: await webhook.send(f"<@&{constants.Roles.admin}> {msg}") @@ -283,7 +283,7 @@ class Bot(commands.Bot): self._guild_available.set() - async def on_guild_unavailable(self, guild: discord.Guild) -> None: + async def on_guild_unavailable(self, guild: disnake.Guild) -> None: """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" if guild.id != constants.Guild.id: return diff --git a/bot/converters.py b/bot/converters.py index 3522a32aa..9d93428ca 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -6,12 +6,12 @@ from datetime import datetime, timezone from ssl import CertificateError import dateutil.parser -import discord +import disnake from aiohttp import ClientConnectorError from botcore.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter -from discord.utils import escape_markdown, snowflake_time +from disnake.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from disnake.utils import escape_markdown, snowflake_time from bot import exts from bot.api import ResponseCodeError @@ -505,14 +505,14 @@ AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Us class UnambiguousUser(UserConverter): """ - Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `disnake.User`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `UserConverter`, it doesn't allow conversion from a name. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> discord.User: - """Convert the `argument` to a `discord.User`.""" + async def convert(self, ctx: Context, argument: str) -> disnake.User: + """Convert the `argument` to a `disnake.User`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -521,14 +521,14 @@ class UnambiguousUser(UserConverter): class UnambiguousMember(MemberConverter): """ - Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `disnake.Member`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> discord.Member: - """Convert the `argument` to a `discord.Member`.""" + async def convert(self, ctx: Context, argument: str) -> disnake.Member: + """Convert the `argument` to a `disnake.Member`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -588,10 +588,10 @@ if t.TYPE_CHECKING: OffTopicName = str # noqa: F811 ISODateTime = datetime # noqa: F811 HushDurationConverter = int # noqa: F811 - UnambiguousUser = discord.User # noqa: F811 - UnambiguousMember = discord.Member # noqa: F811 + UnambiguousUser = disnake.User # noqa: F811 + UnambiguousMember = disnake.Member # noqa: F811 Infraction = t.Optional[dict] # noqa: F811 Expiry = t.Union[Duration, ISODateTime] -MemberOrUser = t.Union[discord.Member, discord.User] +MemberOrUser = t.Union[disnake.Member, disnake.User] UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] diff --git a/bot/decorators.py b/bot/decorators.py index f4331264f..9ae98442c 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -4,9 +4,9 @@ import types import typing as t from contextlib import suppress -from discord import Member, NotFound -from discord.ext import commands -from discord.ext.commands import Cog, Context +from disnake import Member, NotFound +from disnake.ext import commands +from disnake.ext.commands import Cog, Context from bot.constants import Channels, DEBUG_MODE, RedirectOutput from bot.log import get_logger @@ -179,7 +179,7 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: Ensure the highest role of the invoking member is greater than that of the target member. If the condition fails, a warning is sent to the invoking context. A target which is not an - instance of discord.Member will always pass. + instance of disnake.Member will always pass. `member_arg` is the keyword name or position index of the parameter of the decorated command whose value is the target member. @@ -195,7 +195,7 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: target = function.get_arg_value(member_arg, bound_args) if not isinstance(target, Member): - log.trace("The target is not a discord.Member; skipping role hierarchy check.") + log.trace("The target is not a disnake.Member; skipping role hierarchy check.") return await func(*args, **kwargs) ctx = function.get_arg_value(1, bound_args) diff --git a/bot/errors.py b/bot/errors.py index 078b645f1..298e7ac2d 100644 --- a/bot/errors.py +++ b/bot/errors.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Hashable, TYPE_CHECKING, Union -from discord.ext.commands import ConversionError, Converter +from disnake.ext.commands import ConversionError, Converter if TYPE_CHECKING: from bot.converters import MemberOrUser diff --git a/bot/exts/backend/branding/_cog.py b/bot/exts/backend/branding/_cog.py index 0c5839a7a..a07e70d58 100644 --- a/bot/exts/backend/branding/_cog.py +++ b/bot/exts/backend/branding/_cog.py @@ -7,10 +7,10 @@ from enum import Enum from operator import attrgetter import async_timeout -import discord +import disnake from arrow import Arrow from async_rediscache import RedisCache -from discord.ext import commands, tasks +from disnake.ext import commands, tasks from bot.bot import Bot from bot.constants import Branding as BrandingConfig, Channels, Colours, Guild, MODERATION_ROLES @@ -42,7 +42,7 @@ def compound_hash(objects: t.Iterable[RemoteObject]) -> str: return "-".join(item.sha for item in objects) -def make_embed(title: str, description: str, *, success: bool) -> discord.Embed: +def make_embed(title: str, description: str, *, success: bool) -> disnake.Embed: """ Construct simple response embed. @@ -51,7 +51,7 @@ def make_embed(title: str, description: str, *, success: bool) -> discord.Embed: For both `title` and `description`, empty string are valid values ~ fields will be empty. """ colour = Colours.soft_green if success else Colours.soft_red - return discord.Embed(title=title[:256], description=description[:4096], colour=colour) + return disnake.Embed(title=title[:256], description=description[:4096], colour=colour) def extract_event_duration(event: Event) -> str: @@ -147,13 +147,13 @@ class Branding(commands.Cog): return False await self.bot.wait_until_guild_available() - pydis: discord.Guild = self.bot.get_guild(Guild.id) + pydis: disnake.Guild = self.bot.get_guild(Guild.id) timeout = 10 # Seconds. try: with async_timeout.timeout(timeout): # Raise after `timeout` seconds. await pydis.edit(**{asset_type.value: file}) - except discord.HTTPException: + except disnake.HTTPException: log.exception("Asset upload to Discord failed.") return False except asyncio.TimeoutError: @@ -277,7 +277,7 @@ class Branding(commands.Cog): log.debug(f"Sending event information event to channel: {channel_id} ({is_notification=}).") await self.bot.wait_until_guild_available() - channel: t.Optional[discord.TextChannel] = self.bot.get_channel(channel_id) + channel: t.Optional[disnake.TextChannel] = self.bot.get_channel(channel_id) if channel is None: log.warning(f"Cannot send event information: channel {channel_id} not found!") @@ -294,7 +294,7 @@ class Branding(commands.Cog): else: content = "Python Discord is entering a new event!" if is_notification else None - embed = discord.Embed(description=description[:4096], colour=discord.Colour.og_blurple()) + embed = disnake.Embed(description=description[:4096], colour=disnake.Colour.og_blurple()) embed.set_footer(text=duration[:4096]) await channel.send(content=content, embed=embed) @@ -573,7 +573,7 @@ class Branding(commands.Cog): await ctx.send(embed=resp) return - embed = discord.Embed(title="Current event calendar", colour=discord.Colour.og_blurple()) + embed = disnake.Embed(title="Current event calendar", colour=disnake.Colour.og_blurple()) # Because Discord embeds can only contain up to 25 fields, we only show the first 25. first_25 = list(available_events.items())[:25] diff --git a/bot/exts/backend/config_verifier.py b/bot/exts/backend/config_verifier.py index dc85a65a2..1ade2bce7 100644 --- a/bot/exts/backend/config_verifier.py +++ b/bot/exts/backend/config_verifier.py @@ -1,4 +1,4 @@ -from discord.ext.commands import Cog +from disnake.ext.commands import Cog from bot import constants from bot.bot import Bot diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index c79c7b2a7..953843a77 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,7 +1,7 @@ import difflib -from discord import Embed -from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors +from disnake import Embed +from disnake.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from sentry_sdk import push_scope from bot.api import ResponseCodeError diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py index 2d03cd580..040fb5d37 100644 --- a/bot/exts/backend/logging.py +++ b/bot/exts/backend/logging.py @@ -1,5 +1,5 @@ -from discord import Embed -from discord.ext.commands import Cog +from disnake import Embed +from disnake.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index 80f5750bc..d08e56077 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,8 +1,8 @@ from typing import Any, Dict -from discord import Member, Role, User -from discord.ext import commands -from discord.ext.commands import Cog, Context +from disnake import Member, Role, User +from disnake.ext import commands +from disnake.ext.commands import Cog, Context from bot import constants from bot.api import ResponseCodeError diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 45301b098..48ee3c842 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,8 +2,8 @@ import abc import typing as t from collections import namedtuple -from discord import Guild -from discord.ext.commands import Context +from disnake import Guild +from disnake.ext.commands import Context from more_itertools import chunked import bot diff --git a/bot/exts/events/code_jams/_channels.py b/bot/exts/events/code_jams/_channels.py index e8cf5f7bf..fc4693bd4 100644 --- a/bot/exts/events/code_jams/_channels.py +++ b/bot/exts/events/code_jams/_channels.py @@ -1,6 +1,6 @@ import typing as t -import discord +import disnake from bot.constants import Categories, Channels, Roles from bot.log import get_logger @@ -11,7 +11,7 @@ MAX_CHANNELS = 50 CATEGORY_NAME = "Code Jam" -async def _get_category(guild: discord.Guild) -> discord.CategoryChannel: +async def _get_category(guild: disnake.Guild) -> disnake.CategoryChannel: """ Return a code jam category. @@ -24,13 +24,13 @@ async def _get_category(guild: discord.Guild) -> discord.CategoryChannel: return await _create_category(guild) -async def _create_category(guild: discord.Guild) -> discord.CategoryChannel: +async def _create_category(guild: disnake.Guild) -> disnake.CategoryChannel: """Create a new code jam category and return it.""" log.info("Creating a new code jam category.") category_overwrites = { - guild.default_role: discord.PermissionOverwrite(read_messages=False), - guild.me: discord.PermissionOverwrite(read_messages=True) + guild.default_role: disnake.PermissionOverwrite(read_messages=False), + guild.me: disnake.PermissionOverwrite(read_messages=True) } category = await guild.create_category_channel( @@ -47,17 +47,17 @@ async def _create_category(guild: discord.Guild) -> discord.CategoryChannel: def _get_overwrites( - members: list[tuple[discord.Member, bool]], - guild: discord.Guild, -) -> dict[t.Union[discord.Member, discord.Role], discord.PermissionOverwrite]: + members: list[tuple[disnake.Member, bool]], + guild: disnake.Guild, +) -> dict[t.Union[disnake.Member, disnake.Role], disnake.PermissionOverwrite]: """Get code jam team channels permission overwrites.""" team_channel_overwrites = { - guild.default_role: discord.PermissionOverwrite(read_messages=False), - guild.get_role(Roles.code_jam_event_team): discord.PermissionOverwrite(read_messages=True) + guild.default_role: disnake.PermissionOverwrite(read_messages=False), + guild.get_role(Roles.code_jam_event_team): disnake.PermissionOverwrite(read_messages=True) } for member, _ in members: - team_channel_overwrites[member] = discord.PermissionOverwrite( + team_channel_overwrites[member] = disnake.PermissionOverwrite( read_messages=True ) @@ -65,10 +65,10 @@ def _get_overwrites( async def create_team_channel( - guild: discord.Guild, + guild: disnake.Guild, team_name: str, - members: list[tuple[discord.Member, bool]], - team_leaders: discord.Role + members: list[tuple[disnake.Member, bool]], + team_leaders: disnake.Role ) -> None: """Create the team's text channel.""" await _add_team_leader_roles(members, team_leaders) @@ -84,29 +84,29 @@ async def create_team_channel( ) -async def create_team_leader_channel(guild: discord.Guild, team_leaders: discord.Role) -> None: +async def create_team_leader_channel(guild: disnake.Guild, team_leaders: disnake.Role) -> None: """Create the Team Leader Chat channel for the Code Jam team leaders.""" - category: discord.CategoryChannel = guild.get_channel(Categories.summer_code_jam) + category: disnake.CategoryChannel = guild.get_channel(Categories.summer_code_jam) team_leaders_chat = await category.create_text_channel( name="team-leaders-chat", overwrites={ - guild.default_role: discord.PermissionOverwrite(read_messages=False), - team_leaders: discord.PermissionOverwrite(read_messages=True) + guild.default_role: disnake.PermissionOverwrite(read_messages=False), + team_leaders: disnake.PermissionOverwrite(read_messages=True) } ) await _send_status_update(guild, f"Created {team_leaders_chat.mention} in the {category} category.") -async def _send_status_update(guild: discord.Guild, message: str) -> None: +async def _send_status_update(guild: disnake.Guild, message: str) -> None: """Inform the events lead with a status update when the command is ran.""" - channel: discord.TextChannel = guild.get_channel(Channels.code_jam_planning) + channel: disnake.TextChannel = guild.get_channel(Channels.code_jam_planning) await channel.send(f"<@&{Roles.events_lead}>\n\n{message}") -async def _add_team_leader_roles(members: list[tuple[discord.Member, bool]], team_leaders: discord.Role) -> None: +async def _add_team_leader_roles(members: list[tuple[disnake.Member, bool]], team_leaders: disnake.Role) -> None: """Assign the team leader role to the team leaders.""" for member, is_leader in members: if is_leader: diff --git a/bot/exts/events/code_jams/_cog.py b/bot/exts/events/code_jams/_cog.py index 452199f5f..5cb11826d 100644 --- a/bot/exts/events/code_jams/_cog.py +++ b/bot/exts/events/code_jams/_cog.py @@ -3,9 +3,9 @@ import csv import typing as t from collections import defaultdict -import discord -from discord import Colour, Embed, Guild, Member -from discord.ext import commands +import disnake +from disnake import Colour, Embed, Guild, Member +from disnake.ext import commands from bot.bot import Bot from bot.constants import Emojis, Roles @@ -85,7 +85,7 @@ class CodeJams(commands.Cog): A confirmation message is displayed with the categories and channels to be deleted.. Pressing the added reaction deletes those channels. """ - def predicate_deletion_emoji_reaction(reaction: discord.Reaction, user: discord.User) -> bool: + def predicate_deletion_emoji_reaction(reaction: disnake.Reaction, user: disnake.User) -> bool: """Return True if the reaction :boom: was added by the context message author on this message.""" return ( reaction.message.id == message.id @@ -124,14 +124,14 @@ class CodeJams(commands.Cog): @staticmethod async def _build_confirmation_message( - categories: dict[discord.CategoryChannel, list[discord.abc.GuildChannel]] + categories: dict[disnake.CategoryChannel, list[disnake.abc.GuildChannel]] ) -> str: """Sends details of the channels to be deleted to the pasting service, and formats the confirmation message.""" - def channel_repr(channel: discord.abc.GuildChannel) -> str: + def channel_repr(channel: disnake.abc.GuildChannel) -> str: """Formats the channel name and ID and a readable format.""" return f"{channel.name} ({channel.id})" - def format_category_info(category: discord.CategoryChannel, channels: list[discord.abc.GuildChannel]) -> str: + def format_category_info(category: disnake.CategoryChannel, channels: list[disnake.abc.GuildChannel]) -> str: """Displays the category and the channels within it in a readable format.""" return f"{channel_repr(category)}:\n" + "\n".join(" - " + channel_repr(channel) for channel in channels) @@ -187,7 +187,7 @@ class CodeJams(commands.Cog): await old_team_channel.set_permissions(member, overwrite=None, reason=f"Participant moved to {new_team_name}") await new_team_channel.set_permissions( member, - overwrite=discord.PermissionOverwrite(read_messages=True), + overwrite=disnake.PermissionOverwrite(read_messages=True), reason=f"Participant moved from {old_team_channel.name}" ) @@ -212,16 +212,16 @@ class CodeJams(commands.Cog): await ctx.send(f"Removed the participant from `{self.team_name(channel)}`.") @staticmethod - def jam_categories(guild: Guild) -> list[discord.CategoryChannel]: + def jam_categories(guild: Guild) -> list[disnake.CategoryChannel]: """Get all the code jam team categories.""" return [category for category in guild.categories if category.name == _channels.CATEGORY_NAME] @staticmethod - def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[discord.TextChannel]: + def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[disnake.TextChannel]: """Get a team channel through either a participant or the team name.""" for category in CodeJams.jam_categories(guild): for channel in category.channels: - if isinstance(channel, discord.TextChannel): + if isinstance(channel, disnake.TextChannel): if ( # If it's a string. criterion == channel.name or criterion == CodeJams.team_name(channel) @@ -231,6 +231,6 @@ class CodeJams(commands.Cog): return channel @staticmethod - def team_name(channel: discord.TextChannel) -> str: + def team_name(channel: disnake.TextChannel) -> str: """Retrieves the team name from the given channel.""" return channel.name.replace("-", " ").title() diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py index 6cccf3680..e55ece910 100644 --- a/bot/exts/filters/antimalware.py +++ b/bot/exts/filters/antimalware.py @@ -1,8 +1,8 @@ import typing as t from os.path import splitext -from discord import Embed, Message, NotFound -from discord.ext.commands import Cog +from disnake import Embed, Message, NotFound +from disnake.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, Filter, URLs diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index bcd845a43..c887cf5fc 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -8,8 +8,8 @@ from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set import arrow -from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Cog +from disnake import Colour, Member, Message, NotFound, Object, TextChannel +from disnake.ext.commands import Cog from bot import rules from bot.bot import Bot @@ -195,7 +195,7 @@ class AntiSpam(Cog): result = await rule_function(message, messages_for_rule, rule_config) # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` + # If it doesn't, it returns a tuple in the form `(str, Iterable[disnake.Member])` # which contains the reason for why the message violated the rule and # an iterable of all members that violated the rule. if result is not None: @@ -265,7 +265,7 @@ class AntiSpam(Cog): # In the rare case where we found messages matching the # spam filter across multiple channels, it is possible # that a single channel will only contain a single message - # to delete. If that should be the case, discord.py will + # to delete. If that should be the case, disnake will # use the "delete single message" endpoint instead of the # bulk delete endpoint, and the single message deletion # endpoint will complain if you give it that does not exist. diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py index a883ddf54..05910973a 100644 --- a/bot/exts/filters/filter_lists.py +++ b/bot/exts/filters/filter_lists.py @@ -1,8 +1,8 @@ import re from typing import Optional -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role +from disnake import Colour, Embed +from disnake.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role from bot import constants from bot.api import ResponseCodeError diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index f44b28125..e8c9bab62 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -6,15 +6,15 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union import arrow import dateutil.parser -import discord.errors +import disnake.errors import regex import tldextract from async_rediscache import RedisCache from botcore.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel -from discord.ext.commands import Cog -from discord.utils import escape_markdown +from disnake import Colour, HTTPException, Member, Message, NotFound, TextChannel +from disnake.ext.commands import Cog +from disnake.utils import escape_markdown from bot.api import ResponseCodeError from bot.bot import Bot @@ -63,14 +63,14 @@ AUTO_BAN_REASON = ( ) AUTO_BAN_DURATION = timedelta(days=4) -FilterMatch = Union[re.Match, dict, bool, List[discord.Embed]] +FilterMatch = Union[re.Match, dict, bool, List[disnake.Embed]] class Stats(NamedTuple): """Additional stats on a triggered filter to append to a mod log.""" message_content: str - additional_embeds: Optional[List[discord.Embed]] + additional_embeds: Optional[List[disnake.Embed]] class Filtering(Cog): @@ -339,7 +339,7 @@ class Filtering(Cog): match = result if match: - is_private = msg.channel.type is discord.ChannelType.private + is_private = msg.channel.type is disnake.ChannelType.private # If this is a filter (not a watchlist) and not in a DM, delete the message. if _filter["type"] == "filter" and not is_private: @@ -354,7 +354,7 @@ class Filtering(Cog): # In addition, to avoid sending two notifications to the user, the # logs, and mod_alert, we return if the message no longer exists. await msg.delete() - except discord.errors.NotFound: + except disnake.errors.NotFound: return # Notify the user if the filter specifies @@ -409,14 +409,14 @@ class Filtering(Cog): self, filter_name: str, _filter: Dict[str, Any], - msg: discord.Message, + msg: disnake.Message, stats: Stats, reason: Optional[str] = None, *, is_eval: bool = False, ) -> None: """Send a mod log for a triggered filter.""" - if msg.channel.type is discord.ChannelType.private: + if msg.channel.type is disnake.ChannelType.private: channel_str = "via DM" ping_everyone = False else: @@ -478,7 +478,7 @@ class Filtering(Cog): additional_embeds = [] for _, data in match.items(): reason = f"Reason: {data['reason']} | " if data.get('reason') else "" - embed = discord.Embed(description=( + embed = disnake.Embed(description=( f"**Members:**\n{data['members']}\n" f"**Active:**\n{data['active']}" )) @@ -626,7 +626,7 @@ class Filtering(Cog): return invite_data if invite_data else False @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: + async def _has_rich_embed(msg: Message) -> Union[bool, List[disnake.Embed]]: """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" if msg.embeds: for embed in msg.embeds: @@ -662,7 +662,7 @@ class Filtering(Cog): """ try: await filtered_member.send(reason) - except discord.errors.Forbidden: + except disnake.errors.Forbidden: await channel.send(f"{filtered_member.mention} {reason}") def schedule_msg_delete(self, msg: dict) -> None: diff --git a/bot/exts/filters/security.py b/bot/exts/filters/security.py index fe3918423..bbb15542f 100644 --- a/bot/exts/filters/security.py +++ b/bot/exts/filters/security.py @@ -1,4 +1,4 @@ -from discord.ext.commands import Cog, Context, NoPrivateMessage +from disnake.ext.commands import Cog, Context, NoPrivateMessage from bot.bot import Bot from bot.log import get_logger diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py index 520283ba3..da42bb0aa 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filters/token_remover.py @@ -3,8 +3,8 @@ import binascii import re import typing as t -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog +from disnake import Colour, Message, NotFound +from disnake.ext.commands import Cog from bot import utils from bot.bot import Bot @@ -53,7 +53,7 @@ class Token(t.NamedTuple): class TokenRemover(Cog): - """Scans messages for potential discord.py bot tokens and removes them.""" + """Scans messages for potential Discord bot tokens and removes them.""" def __init__(self, bot: Bot): self.bot = bot diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py index 96334317c..a5d51700c 100644 --- a/bot/exts/filters/webhook_remover.py +++ b/bot/exts/filters/webhook_remover.py @@ -1,7 +1,7 @@ import re -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog +from disnake import Colour, Message, NotFound +from disnake.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, Colours, Event, Icons diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index c51656343..55196cd65 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -1,9 +1,9 @@ import asyncio from typing import Union -import discord -from discord import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors -from discord.ext.commands import Cog, Context, command +import disnake +from disnake import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors +from disnake.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot @@ -34,7 +34,7 @@ class DuckPond(Cog): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: + except disnake.HTTPException: log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") @staticmethod @@ -67,7 +67,7 @@ class DuckPond(Cog): return False @staticmethod - def _is_duck_emoji(emoji: Union[str, discord.PartialEmoji, discord.Emoji]) -> bool: + def _is_duck_emoji(emoji: Union[str, disnake.PartialEmoji, disnake.Emoji]) -> bool: """Check if the emoji is a valid duck emoji.""" if isinstance(emoji, str): return emoji == "🦆" @@ -111,7 +111,7 @@ class DuckPond(Cog): username=message.author.display_name, avatar_url=message.author.display_avatar.url ) - except discord.HTTPException: + except disnake.HTTPException: log.exception("Failed to send an attachment to the webhook") async def locked_relay(self, message: Message) -> bool: @@ -133,7 +133,7 @@ class DuckPond(Cog): await message.add_reaction("✅") return True - def _payload_has_duckpond_emoji(self, emoji: discord.PartialEmoji) -> bool: + def _payload_has_duckpond_emoji(self, emoji: disnake.PartialEmoji) -> bool: """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" if emoji.is_unicode_emoji(): # For unicode PartialEmojis, the `name` attribute is just the string @@ -165,7 +165,7 @@ class DuckPond(Cog): if not self._payload_has_duckpond_emoji(payload.emoji): return - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + channel = disnake.utils.get(self.bot.get_all_channels(), id=payload.channel_id) if channel is None: return @@ -175,10 +175,10 @@ class DuckPond(Cog): try: message = await channel.fetch_message(payload.message_id) - except discord.NotFound: + except disnake.NotFound: return # Message was deleted. - member = discord.utils.get(message.guild.members, id=payload.user_id) + member = disnake.utils.get(message.guild.members, id=payload.user_id) if not member: return # Member left or wasn't in the cache. @@ -205,7 +205,7 @@ class DuckPond(Cog): if payload.guild_id != constants.Guild.id: return - channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + channel = disnake.utils.get(self.bot.get_all_channels(), id=payload.channel_id) if channel is None: return diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index 7df1d172d..d49f71320 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,9 +2,9 @@ import difflib from datetime import timedelta import arrow -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group, has_any_role -from discord.utils import sleep_until +from disnake import Colour, Embed +from disnake.ext.commands import Cog, Context, group, has_any_role +from disnake.utils import sleep_until from bot.api import ResponseCodeError from bot.bot import Bot diff --git a/bot/exts/help_channels/_caches.py b/bot/exts/help_channels/_caches.py index 8d45c2466..f4eaf3291 100644 --- a/bot/exts/help_channels/_caches.py +++ b/bot/exts/help_channels/_caches.py @@ -1,24 +1,24 @@ from async_rediscache import RedisCache # This dictionary maps a help channel to the time it was claimed -# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] +# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] claim_times = RedisCache(namespace="HelpChannels.claim_times") # This cache tracks which channels are claimed by which members. -# RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] +# RedisCache[disnake.TextChannel.id, t.Union[disnake.User.id, disnake.Member.id]] claimants = RedisCache(namespace="HelpChannels.help_channel_claimants") # Stores the timestamp of the last message from the claimant of a help channel -# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] +# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] claimant_last_message_times = RedisCache(namespace="HelpChannels.claimant_last_message_times") # This cache maps a help channel to the timestamp of the last non-claimant message. # This cache being empty for a given help channel indicates the question is unanswered. -# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] +# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] non_claimant_last_message_times = RedisCache(namespace="HelpChannels.non_claimant_last_message_times") # This cache maps a help channel to original question message in same channel. -# RedisCache[discord.TextChannel.id, discord.Message.id] +# RedisCache[disnake.TextChannel.id, disnake.Message.id] question_messages = RedisCache(namespace="HelpChannels.question_messages") # This cache keeps track of the dynamic message ID for @@ -26,10 +26,10 @@ question_messages = RedisCache(namespace="HelpChannels.question_messages") dynamic_message = RedisCache(namespace="HelpChannels.dynamic_message") # This cache keeps track of who has help-dms on. -# RedisCache[discord.User.id, bool] +# RedisCache[disnake.User.id, bool] help_dm = RedisCache(namespace="HelpChannels.help_dm") # This cache tracks member who are participating and opted in to help channel dms. # serialise the set as a comma separated string to allow usage with redis -# RedisCache[discord.TextChannel.id, str[set[discord.User.id]]] +# RedisCache[disnake.TextChannel.id, str[set[disnake.User.id]]] session_participants = RedisCache(namespace="HelpChannels.session_participants") diff --git a/bot/exts/help_channels/_channel.py b/bot/exts/help_channels/_channel.py index d9cebf215..3c4eaa2b2 100644 --- a/bot/exts/help_channels/_channel.py +++ b/bot/exts/help_channels/_channel.py @@ -4,7 +4,7 @@ from datetime import timedelta from enum import Enum import arrow -import discord +import disnake from arrow import Arrow import bot @@ -31,7 +31,7 @@ class ClosingReason(Enum): CLEANUP = "auto.cleanup" -def get_category_channels(category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: +def get_category_channels(category: disnake.CategoryChannel) -> t.Iterable[disnake.TextChannel]: """Yield the text channels of the `category` in an unsorted manner.""" log.trace(f"Getting text channels in the category '{category}' ({category.id}).") @@ -41,7 +41,7 @@ def get_category_channels(category: discord.CategoryChannel) -> t.Iterable[disco yield channel -async def get_closing_time(channel: discord.TextChannel, init_done: bool) -> t.Tuple[Arrow, ClosingReason]: +async def get_closing_time(channel: disnake.TextChannel, init_done: bool) -> t.Tuple[Arrow, ClosingReason]: """ Return the time at which the given help `channel` should be closed along with the reason. @@ -116,12 +116,12 @@ async def get_in_use_time(channel_id: int) -> t.Optional[timedelta]: return arrow.utcnow() - claimed -def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: +def is_excluded_channel(channel: disnake.abc.GuildChannel) -> bool: """Check if a channel should be excluded from the help channel system.""" - return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS + return not isinstance(channel, disnake.TextChannel) or channel.id in EXCLUDED_CHANNELS -async def move_to_bottom(channel: discord.TextChannel, category_id: int, **options) -> None: +async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **options) -> None: """ Move the `channel` to the bottom position of `category` and edit channel attributes. @@ -130,8 +130,8 @@ async def move_to_bottom(channel: discord.TextChannel, category_id: int, **optio really ends up at the bottom of the category. If `options` are provided, the channel will be edited after the move is completed. This is the - same order of operations that `discord.TextChannel.edit` uses. For information on available - options, see the documentation on `discord.TextChannel.edit`. While possible, position-related + same order of operations that `disnake.TextChannel.edit` uses. For information on available + options, see the documentation on `disnake.TextChannel.edit`. While possible, position-related options should be avoided, as it may interfere with the category move we perform. """ # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. @@ -161,7 +161,7 @@ async def move_to_bottom(channel: discord.TextChannel, category_id: int, **optio await channel.edit(**options) -async def ensure_cached_claimant(channel: discord.TextChannel) -> None: +async def ensure_cached_claimant(channel: disnake.TextChannel) -> None: """ Ensure there is a claimant cached for each help channel. diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index d3d70e252..fc55fa1df 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -5,9 +5,9 @@ from datetime import timedelta from operator import attrgetter import arrow -import discord -import discord.abc -from discord.ext import commands +import disnake +import disnake.abc +from disnake.ext import commands from bot import constants from bot.bot import Bot @@ -66,16 +66,16 @@ class HelpChannels(commands.Cog): self.bot = bot self.scheduler = scheduling.Scheduler(self.__class__.__name__) - self.guild: discord.Guild = None - self.cooldown_role: discord.Role = None + self.guild: disnake.Guild = None + self.cooldown_role: disnake.Role = None # Categories - self.available_category: discord.CategoryChannel = None - self.in_use_category: discord.CategoryChannel = None - self.dormant_category: discord.CategoryChannel = None + self.available_category: disnake.CategoryChannel = None + self.in_use_category: disnake.CategoryChannel = None + self.dormant_category: disnake.CategoryChannel = None # Queues - self.channel_queue: asyncio.Queue[discord.TextChannel] = None + self.channel_queue: asyncio.Queue[disnake.TextChannel] = None self.name_queue: t.Deque[str] = None # Notifications @@ -84,7 +84,7 @@ class HelpChannels(commands.Cog): self.last_running_low_notification = arrow.get('1815-12-10T18:00:00.00000+00:00') self.dynamic_message: t.Optional[int] = None - self.available_help_channels: t.Set[discord.TextChannel] = set() + self.available_help_channels: t.Set[disnake.TextChannel] = set() # Asyncio stuff self.queue_tasks: t.List[asyncio.Task] = [] @@ -104,7 +104,7 @@ class HelpChannels(commands.Cog): @lock.lock_arg(NAMESPACE, "message", attrgetter("channel.id")) @lock.lock_arg(NAMESPACE, "message", attrgetter("author.id")) @lock.lock_arg(f"{NAMESPACE}.unclaim", "message", attrgetter("author.id"), wait=True) - async def claim_channel(self, message: discord.Message) -> None: + async def claim_channel(self, message: disnake.Message) -> None: """ Claim the channel in which the question `message` was sent. @@ -116,7 +116,7 @@ class HelpChannels(commands.Cog): try: await self.move_to_in_use(message.channel) - except discord.DiscordServerError: + except disnake.DiscordServerError: try: await message.channel.send( "The bot encountered a Discord API error while trying to move this channel, please try again later." @@ -133,14 +133,14 @@ class HelpChannels(commands.Cog): self.bot.stats.incr("help.failed_claims.500_on_move") return - embed = discord.Embed( + embed = disnake.Embed( description=f"Channel claimed by {message.author.mention}.", color=constants.Colours.bright_green, ) await message.channel.send(embed=embed) - # Handle odd edge case of `message.author` not being a `discord.Member` (see bot#1839) - if not isinstance(message.author, discord.Member): + # Handle odd edge case of `message.author` not being a `disnake.Member` (see bot#1839) + if not isinstance(message.author, disnake.Member): log.debug(f"{message.author} ({message.author.id}) isn't a member. Not giving cooldown role or sending DM.") else: await members.handle_role_change(message.author, message.author.add_roles, self.cooldown_role) @@ -189,7 +189,7 @@ class HelpChannels(commands.Cog): return queue - async def create_dormant(self) -> t.Optional[discord.TextChannel]: + async def create_dormant(self) -> t.Optional[disnake.TextChannel]: """ Create and return a new channel in the Dormant category. @@ -234,12 +234,12 @@ class HelpChannels(commands.Cog): May only be invoked by the channel's claimant or by staff. """ - # Don't use a discord.py check because the check needs to fail silently. + # Don't use a disnake check because the check needs to fail silently. if await self.close_check(ctx): log.info(f"Close command invoked by {ctx.author} in #{ctx.channel}.") await self.unclaim_channel(ctx.channel, closed_on=_channel.ClosingReason.COMMAND) - async def get_available_candidate(self) -> discord.TextChannel: + async def get_available_candidate(self) -> disnake.TextChannel: """ Return a dormant channel to turn into an available channel. @@ -313,7 +313,7 @@ class HelpChannels(commands.Cog): self.dormant_category = await channel_utils.get_or_fetch_channel( constants.Categories.help_dormant ) - except discord.HTTPException: + except disnake.HTTPException: log.exception("Failed to get a category; cog will be removed") self.bot.remove_cog(self.qualified_name) @@ -355,7 +355,7 @@ class HelpChannels(commands.Cog): log.info("Cog is ready!") - async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: + async def move_idle_channel(self, channel: disnake.TextChannel, has_task: bool = True) -> None: """ Make the `channel` dormant if idle or schedule the move if still active. @@ -416,7 +416,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() - async def move_to_dormant(self, channel: discord.TextChannel) -> None: + async def move_to_dormant(self, channel: disnake.TextChannel) -> None: """Make the `channel` dormant.""" log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") await _channel.move_to_bottom( @@ -425,7 +425,7 @@ class HelpChannels(commands.Cog): ) log.trace(f"Sending dormant message for #{channel} ({channel.id}).") - embed = discord.Embed( + embed = disnake.Embed( description=_message.DORMANT_MSG.format( dormant=self.dormant_category.name, available=self.available_category.name, @@ -439,7 +439,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() @lock.lock_arg(f"{NAMESPACE}.unclaim", "channel") - async def unclaim_channel(self, channel: discord.TextChannel, *, closed_on: _channel.ClosingReason) -> None: + async def unclaim_channel(self, channel: disnake.TextChannel, *, closed_on: _channel.ClosingReason) -> None: """ Unclaim an in-use help `channel` to make it dormant. @@ -462,7 +462,7 @@ class HelpChannels(commands.Cog): async def _unclaim_channel( self, - channel: discord.TextChannel, + channel: disnake.TextChannel, claimant_id: t.Optional[int], closed_on: _channel.ClosingReason ) -> None: @@ -488,7 +488,7 @@ class HelpChannels(commands.Cog): if closed_on == _channel.ClosingReason.COMMAND: self.scheduler.cancel(channel.id) - async def move_to_in_use(self, channel: discord.TextChannel) -> None: + async def move_to_in_use(self, channel: disnake.TextChannel) -> None: """Make a channel in-use and schedule it to be made dormant.""" log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") @@ -504,7 +504,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() @commands.Cog.listener() - async def on_message(self, message: discord.Message) -> None: + async def on_message(self, message: disnake.Message) -> None: """Move an available channel to the In Use category and replace it with a dormant one.""" if message.author.bot: return # Ignore messages sent by bots. @@ -520,7 +520,7 @@ class HelpChannels(commands.Cog): await _message.update_message_caches(message) @commands.Cog.listener() - async def on_message_delete(self, msg: discord.Message) -> None: + async def on_message_delete(self, msg: disnake.Message) -> None: """ Reschedule an in-use channel to become dormant sooner if the channel is empty. @@ -542,7 +542,7 @@ class HelpChannels(commands.Cog): delay = constants.HelpChannels.deleted_idle_minutes * 60 self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) - async def wait_for_dormant_channel(self) -> discord.TextChannel: + async def wait_for_dormant_channel(self) -> disnake.TextChannel: """Wait for a dormant channel to become available in the queue and return it.""" log.trace("Waiting for a dormant channel.") @@ -569,7 +569,7 @@ class HelpChannels(commands.Cog): await self.bot.http.edit_message( constants.Channels.how_to_get_help, self.dynamic_message, content=available_channels, files=None ) - except discord.NotFound: + except disnake.NotFound: pass else: return @@ -593,7 +593,7 @@ class HelpChannels(commands.Cog): @lock.lock_arg(NAMESPACE, "message", attrgetter("channel.id")) @lock.lock_arg(NAMESPACE, "message", attrgetter("author.id")) - async def notify_session_participants(self, message: discord.Message) -> None: + async def notify_session_participants(self, message: disnake.Message) -> None: """ Check if the message author meets the requirements to be notified. @@ -615,7 +615,7 @@ class HelpChannels(commands.Cog): if message.author.id not in session_participants: session_participants.add(message.author.id) - embed = discord.Embed( + embed = disnake.Embed( title="Currently Helping", description=f"You're currently helping in {message.channel.mention}", color=constants.Colours.bright_green, @@ -625,7 +625,7 @@ class HelpChannels(commands.Cog): try: await message.author.send(embed=embed) - except discord.Forbidden: + except disnake.Forbidden: log.trace( f"Failed to send helpdm message to {message.author.id}. DMs Closed/Blocked. " "Removing user from helpdm." diff --git a/bot/exts/help_channels/_message.py b/bot/exts/help_channels/_message.py index 7ceed9b4d..e08043694 100644 --- a/bot/exts/help_channels/_message.py +++ b/bot/exts/help_channels/_message.py @@ -2,7 +2,7 @@ import textwrap import typing as t import arrow -import discord +import disnake from arrow import Arrow import bot @@ -41,7 +41,7 @@ through our guide for **[asking a good question]({ASKING_GUIDE_URL})**. """ -async def update_message_caches(message: discord.Message) -> None: +async def update_message_caches(message: disnake.Message) -> None: """Checks the source of new content in a help channel and updates the appropriate cache.""" channel = message.channel @@ -62,18 +62,18 @@ async def update_message_caches(message: discord.Message) -> None: await _caches.non_claimant_last_message_times.set(channel.id, timestamp) -async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: +async def get_last_message(channel: disnake.TextChannel) -> t.Optional[disnake.Message]: """Return the last message sent in the channel or None if no messages exist.""" log.trace(f"Getting the last message in #{channel} ({channel.id}).") try: return await channel.history(limit=1).next() # noqa: B305 - except discord.NoMoreItems: + except disnake.NoMoreItems: log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") return None -async def is_empty(channel: discord.TextChannel) -> bool: +async def is_empty(channel: disnake.TextChannel) -> bool: """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" log.trace(f"Checking if #{channel} ({channel.id}) is empty.") @@ -92,13 +92,13 @@ async def is_empty(channel: discord.TextChannel) -> bool: return False -async def dm_on_open(message: discord.Message) -> None: +async def dm_on_open(message: disnake.Message) -> None: """ DM claimant with a link to the claimed channel's first message, with a 100 letter preview of the message. Does nothing if the user has DMs disabled. """ - embed = discord.Embed( + embed = disnake.Embed( title="Help channel opened", description=f"You claimed {message.channel.mention}.", colour=bot.constants.Colours.bright_green, @@ -118,7 +118,7 @@ async def dm_on_open(message: discord.Message) -> None: try: await message.author.send(embed=embed) log.trace(f"Sent DM to {message.author.id} after claiming help channel.") - except discord.errors.Forbidden: + except disnake.errors.Forbidden: log.trace( f"Ignoring to send DM to {message.author.id} after claiming help channel: DMs disabled." ) @@ -146,7 +146,7 @@ async def notify_none_remaining(last_notification: Arrow) -> t.Optional[Arrow]: log.trace("Notifying about lack of channels.") mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_none_remaining_roles) - allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_none_remaining_roles] + allowed_roles = [disnake.Object(id_) for id_ in constants.HelpChannels.notify_none_remaining_roles] channel = bot.instance.get_channel(constants.HelpChannels.notify_channel) if channel is None: @@ -157,7 +157,7 @@ async def notify_none_remaining(last_notification: Arrow) -> t.Optional[Arrow]: f"{mentions} A new available help channel is needed but there " "are no more dormant ones. Consider freeing up some in-use channels manually by " f"using the `{constants.Bot.prefix}dormant` command within the channels.", - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + allowed_mentions=disnake.AllowedMentions(everyone=False, roles=allowed_roles) ) except Exception: # Handle it here cause this feature isn't critical for the functionality of the system. @@ -213,18 +213,18 @@ async def notify_running_low(number_of_channels_left: int, last_notification: Ar return arrow.utcnow() -async def pin(message: discord.Message) -> None: +async def pin(message: disnake.Message) -> None: """Pin an initial question `message` and store it in a cache.""" if await pin_wrapper(message.id, message.channel, pin=True): await _caches.question_messages.set(message.channel.id, message.id) -async def send_available_message(channel: discord.TextChannel) -> None: +async def send_available_message(channel: disnake.TextChannel) -> None: """Send the available message by editing a dormant message or sending a new message.""" channel_info = f"#{channel} ({channel.id})" log.trace(f"Sending available message in {channel_info}.") - embed = discord.Embed( + embed = disnake.Embed( color=constants.Colours.bright_green, description=AVAILABLE_MSG, ) @@ -240,7 +240,7 @@ async def send_available_message(channel: discord.TextChannel) -> None: await channel.send(embed=embed) -async def unpin(channel: discord.TextChannel) -> None: +async def unpin(channel: disnake.TextChannel) -> None: """Unpin the initial question message sent in `channel`.""" msg_id = await _caches.question_messages.pop(channel.id) if msg_id is None: @@ -249,19 +249,19 @@ async def unpin(channel: discord.TextChannel) -> None: await pin_wrapper(msg_id, channel, pin=False) -def _match_bot_embed(message: t.Optional[discord.Message], description: str) -> bool: +def _match_bot_embed(message: t.Optional[disnake.Message], description: str) -> bool: """Return `True` if the bot's `message`'s embed description matches `description`.""" if not message or not message.embeds: return False bot_msg_desc = message.embeds[0].description - if bot_msg_desc is discord.Embed.Empty: + if bot_msg_desc is disnake.Embed.Empty: log.trace("Last message was a bot embed but it was empty.") return False return message.author == bot.instance.user and bot_msg_desc.strip() == description.strip() -async def pin_wrapper(msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: +async def pin_wrapper(msg_id: int, channel: disnake.TextChannel, *, pin: bool) -> bool: """ Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. @@ -277,7 +277,7 @@ async def pin_wrapper(msg_id: int, channel: discord.TextChannel, *, pin: bool) - try: await func(channel.id, msg_id) - except discord.HTTPException as e: + except disnake.HTTPException as e: if e.code == 10008: log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") else: diff --git a/bot/exts/help_channels/_name.py b/bot/exts/help_channels/_name.py index a9d9b2df1..50b250cb5 100644 --- a/bot/exts/help_channels/_name.py +++ b/bot/exts/help_channels/_name.py @@ -3,7 +3,7 @@ import typing as t from collections import deque from pathlib import Path -import discord +import disnake from bot import constants from bot.exts.help_channels._channel import MAX_CHANNELS_PER_CATEGORY, get_category_channels @@ -12,7 +12,7 @@ from bot.log import get_logger log = get_logger(__name__) -def create_name_queue(*categories: discord.CategoryChannel) -> deque: +def create_name_queue(*categories: disnake.CategoryChannel) -> deque: """ Return a queue of food names to use for creating new channels. @@ -50,7 +50,7 @@ def _get_names() -> t.List[str]: return all_names[:count] -def _get_used_names(*categories: discord.CategoryChannel) -> t.Set[str]: +def _get_used_names(*categories: disnake.CategoryChannel) -> t.Set[str]: """Return names which are already being used by channels in `categories`.""" log.trace("Getting channel names which are already being used.") diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py index f2f29020f..68eb52a59 100644 --- a/bot/exts/info/code_snippets.py +++ b/bot/exts/info/code_snippets.py @@ -4,9 +4,9 @@ import textwrap from typing import Any from urllib.parse import quote_plus -import discord +import disnake from aiohttp import ClientResponseError -from discord.ext.commands import Cog +from disnake.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels @@ -241,7 +241,7 @@ class CodeSnippets(Cog): return '\n'.join(map(lambda x: x[1], sorted(all_snippets))) @Cog.listener() - async def on_message(self, message: discord.Message) -> None: + async def on_message(self, message: disnake.Message) -> None: """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" if message.author.bot: return @@ -255,7 +255,7 @@ class CodeSnippets(Cog): if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: try: await message.edit(suppress=True) - except discord.NotFound: + except disnake.NotFound: # Don't send snippets if the original message was deleted. return diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index a859d8cef..cf8c7d0be 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -1,9 +1,9 @@ import time from typing import Optional -import discord -from discord import Message, RawMessageUpdateEvent -from discord.ext.commands import Cog +import disnake +from disnake import Message, RawMessageUpdateEvent +from disnake.ext.commands import Cog from bot import constants from bot.bot import Bot @@ -62,9 +62,9 @@ class CodeBlockCog(Cog, name="Code Block"): self.codeblock_message_ids = {} @staticmethod - def create_embed(instructions: str) -> discord.Embed: + def create_embed(instructions: str) -> disnake.Embed: """Return an embed which displays code block formatting `instructions`.""" - return discord.Embed(description=instructions) + return disnake.Embed(description=instructions) async def get_sent_instructions(self, payload: RawMessageUpdateEvent) -> Optional[Message]: """ @@ -78,11 +78,11 @@ class CodeBlockCog(Cog, name="Code Block"): try: return await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - except discord.NotFound: + except disnake.NotFound: log.debug("Could not find instructions message; it was probably deleted.") return None - def is_on_cooldown(self, channel: discord.TextChannel) -> bool: + def is_on_cooldown(self, channel: disnake.TextChannel) -> bool: """ Return True if an embed was sent too recently for `channel`. @@ -93,7 +93,7 @@ class CodeBlockCog(Cog, name="Code Block"): cooldown = constants.CodeBlock.cooldown_seconds return (time.time() - self.channel_cooldowns.get(channel.id, 0)) < cooldown - def is_valid_channel(self, channel: discord.TextChannel) -> bool: + def is_valid_channel(self, channel: disnake.TextChannel) -> bool: """Return True if `channel` is a help channel, may be on a cooldown, or is whitelisted.""" log.trace(f"Checking if #{channel} qualifies for code block detection.") return ( @@ -102,7 +102,7 @@ class CodeBlockCog(Cog, name="Code Block"): or channel.id in constants.CodeBlock.channel_whitelist ) - async def send_instructions(self, message: discord.Message, instructions: str) -> None: + async def send_instructions(self, message: disnake.Message, instructions: str) -> None: """ Send an embed with `instructions` on fixing an incorrect code block in a `message`. @@ -119,7 +119,7 @@ class CodeBlockCog(Cog, name="Code Block"): # Increase amount of codeblock correction in stats self.bot.stats.incr("codeblock_corrections") - def should_parse(self, message: discord.Message) -> bool: + def should_parse(self, message: disnake.Message) -> bool: """ Return True if `message` should be parsed. @@ -185,5 +185,5 @@ class CodeBlockCog(Cog, name="Code Block"): else: log.info("Message edited but still has invalid code blocks; editing instructions.") await bot_message.edit(embed=self.create_embed(instructions)) - except discord.NotFound: + except disnake.NotFound: log.debug("Could not find instructions message; it was probably deleted.") diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py index c27f28eac..487a0fd21 100644 --- a/bot/exts/info/doc/_batch_parser.py +++ b/bot/exts/info/doc/_batch_parser.py @@ -7,7 +7,7 @@ from contextlib import suppress from operator import attrgetter from typing import Deque, Dict, List, NamedTuple, Optional, Union -import discord +import disnake from bs4 import BeautifulSoup import bot @@ -48,7 +48,7 @@ class StaleInventoryNotifier: if await self.symbol_counter.increment_for(doc_item) < 3: self._warned_urls.add(doc_item.url) await self._init_task - embed = discord.Embed( + embed = disnake.Embed( description=f"Doc item `{doc_item.symbol_id=}` present in loaded documentation inventories " f"not found on [site]({doc_item.url}), inventories may need to be refreshed." ) diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 4dc5276d9..77fc61389 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -9,8 +9,8 @@ from types import SimpleNamespace from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp -import discord -from discord.ext import commands +import disnake +from disnake.ext import commands from bot.api import ResponseCodeError from bot.bot import Bot @@ -275,7 +275,7 @@ class DocCog(commands.Cog): return "Unable to parse the requested symbol." return markdown - async def create_symbol_embed(self, symbol_name: str) -> Optional[discord.Embed]: + async def create_symbol_embed(self, symbol_name: str) -> Optional[disnake.Embed]: """ Attempt to scrape and fetch the data for the given `symbol_name`, and build an embed from its contents. @@ -304,8 +304,8 @@ class DocCog(commands.Cog): else: footer_text = "" - embed = discord.Embed( - title=discord.utils.escape_markdown(symbol_name), + embed = disnake.Embed( + title=disnake.utils.escape_markdown(symbol_name), url=f"{doc_item.url}#{doc_item.symbol_id}", description=await self.get_symbol_markdown(doc_item) ) @@ -331,9 +331,9 @@ class DocCog(commands.Cog): !docs getdoc aiohttp.ClientSession """ if not symbol_name: - inventory_embed = discord.Embed( + inventory_embed = disnake.Embed( title=f"All inventories (`{len(self.base_urls)}` total)", - colour=discord.Colour.blue() + colour=disnake.Colour.blue() ) lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) @@ -355,7 +355,7 @@ class DocCog(commands.Cog): # Make sure that we won't cause a ghost-ping by deleting the message if not (ctx.message.mentions or ctx.message.role_mentions): - with suppress(discord.NotFound): + with suppress(disnake.NotFound): await ctx.message.delete() await error_message.delete() @@ -449,7 +449,7 @@ class DocCog(commands.Cog): if removed := ", ".join(old_inventories - new_inventories): removed = "- " + removed - embed = discord.Embed( + embed = disnake.Embed( title="Inventories refreshed", description=f"```diff\n{added}\n{removed}```" if added or removed else "" ) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py index 864e7edd2..29d73c564 100644 --- a/bot/exts/info/help.py +++ b/bot/exts/info/help.py @@ -6,8 +6,8 @@ from collections import namedtuple from contextlib import suppress from typing import List, Optional, Union -from discord import ButtonStyle, Colour, Embed, Emoji, Interaction, PartialEmoji, ui -from discord.ext.commands import Bot, Cog, Command, CommandError, Context, DisabledCommand, Group, HelpCommand +from disnake import ButtonStyle, Colour, Embed, Emoji, Interaction, PartialEmoji, ui +from disnake.ext.commands import Bot, Cog, Command, CommandError, Context, DisabledCommand, Group, HelpCommand from rapidfuzz import fuzz, process from rapidfuzz.utils import default_process diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index e616b9208..44a9b8f1a 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,9 +6,9 @@ from textwrap import shorten from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union import rapidfuzz -from discord import AllowedMentions, Colour, Embed, Guild, Message, Role -from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role -from discord.utils import escape_markdown +from disnake import AllowedMentions, Colour, Embed, Guild, Message, Role +from disnake.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role +from disnake.utils import escape_markdown from bot import constants from bot.api import ResponseCodeError @@ -466,7 +466,7 @@ class Information(Cog): async def send_raw_content(self, ctx: Context, message: Message, json: bool = False) -> None: """ - Send information about the raw API response for a `discord.Message`. + Send information about the raw API response for a `disnake.Message`. If `json` is True, send the information in a copy-pasteable Python format. """ diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 67866620b..08c693581 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -3,8 +3,8 @@ from email.parser import HeaderParser from io import StringIO from typing import Dict, Optional, Tuple -from discord import Colour, Embed -from discord.ext.commands import Cog, Context, command +from disnake import Colour, Embed +from disnake.ext.commands import Cog, Context, command from bot.bot import Bot from bot.constants import Keys diff --git a/bot/exts/info/pypi.py b/bot/exts/info/pypi.py index dacf7bc12..0a7705eb0 100644 --- a/bot/exts/info/pypi.py +++ b/bot/exts/info/pypi.py @@ -3,9 +3,9 @@ import random import re from contextlib import suppress -from discord import Embed, NotFound -from discord.ext.commands import Cog, Context, command -from discord.utils import escape_markdown +from disnake import Embed, NotFound +from disnake.ext.commands import Cog, Context, command +from disnake.utils import escape_markdown from bot.bot import Bot from bot.constants import Colours, NEGATIVE_REPLIES, RedirectOutput diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py index 2fad9d2ab..7603b402b 100644 --- a/bot/exts/info/python_news.py +++ b/bot/exts/info/python_news.py @@ -2,11 +2,11 @@ import re import typing as t from datetime import date, datetime -import discord +import disnake import feedparser from bs4 import BeautifulSoup -from discord.ext.commands import Cog -from discord.ext.tasks import loop +from disnake.ext.commands import Cog +from disnake.ext.tasks import loop from bot import constants from bot.bot import Bot @@ -40,7 +40,7 @@ class PythonNews(Cog): def __init__(self, bot: Bot): self.bot = bot self.webhook_names = {} - self.webhook: t.Optional[discord.Webhook] = None + self.webhook: t.Optional[disnake.Webhook] = None scheduling.create_task(self.get_webhook_names(), event_loop=self.bot.loop) scheduling.create_task(self.get_webhook_and_channel(), event_loop=self.bot.loop) @@ -119,7 +119,7 @@ class PythonNews(Cog): continue # Build an embed and send a webhook - embed = discord.Embed( + embed = disnake.Embed( title=self.escape_markdown(new["title"]), description=self.escape_markdown(new["summary"]), timestamp=new_datetime, @@ -189,7 +189,7 @@ class PythonNews(Cog): link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) # Build an embed and send a message to the webhook - embed = discord.Embed( + embed = disnake.Embed( title=self.escape_markdown(thread_information["subject"]), description=content[:1000] + f"... [continue reading]({link})" if len(content) > 1000 else content, timestamp=new_date, diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index e3e7029ca..6305a9842 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -2,8 +2,8 @@ import inspect from pathlib import Path from typing import Optional, Tuple, Union -from discord import Embed -from discord.ext import commands +from disnake import Embed +from disnake.ext import commands from bot.bot import Bot from bot.constants import URLs diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py index 4d8bb645e..08422b38e 100644 --- a/bot/exts/info/stats.py +++ b/bot/exts/info/stats.py @@ -1,8 +1,8 @@ import string -from discord import Member, Message -from discord.ext.commands import Cog, Context -from discord.ext.tasks import loop +from disnake import Member, Message +from disnake.ext.commands import Cog, Context +from disnake.ext.tasks import loop from bot.bot import Bot from bot.constants import Categories, Channels, Guild diff --git a/bot/exts/info/subscribe.py b/bot/exts/info/subscribe.py index eff0c13b8..0f285e0cb 100644 --- a/bot/exts/info/subscribe.py +++ b/bot/exts/info/subscribe.py @@ -4,9 +4,9 @@ import typing as t from dataclasses import dataclass import arrow -import discord -from discord.ext import commands -from discord.interactions import Interaction +import disnake +from disnake.ext import commands +from disnake.interactions import Interaction from bot import constants from bot.bot import Bot @@ -58,10 +58,10 @@ DELETE_MESSAGE_AFTER = 300 # Seconds log = get_logger(__name__) -class RoleButtonView(discord.ui.View): +class RoleButtonView(disnake.ui.View): """A list of SingleRoleButtons to show to the member.""" - def __init__(self, member: discord.Member): + def __init__(self, member: disnake.Member): super().__init__() self.interaction_owner = member @@ -76,12 +76,12 @@ class RoleButtonView(discord.ui.View): return True -class SingleRoleButton(discord.ui.Button): +class SingleRoleButton(disnake.ui.Button): """A button that adds or removes a role from the member depending on it's current state.""" - ADD_STYLE = discord.ButtonStyle.success - REMOVE_STYLE = discord.ButtonStyle.red - UNAVAILABLE_STYLE = discord.ButtonStyle.secondary + ADD_STYLE = disnake.ButtonStyle.success + REMOVE_STYLE = disnake.ButtonStyle.red + UNAVAILABLE_STYLE = disnake.ButtonStyle.secondary LABEL_FORMAT = "{action} role {role_name}." CUSTOM_ID_FORMAT = "subscribe-{role_id}" @@ -104,7 +104,7 @@ class SingleRoleButton(discord.ui.Button): async def callback(self, interaction: Interaction) -> None: """Update the member's role and change button text to reflect current text.""" - if isinstance(interaction.user, discord.User): + if isinstance(interaction.user, disnake.User): log.trace("User %s is not a member", interaction.user) await interaction.message.delete() self.view.stop() @@ -117,7 +117,7 @@ class SingleRoleButton(discord.ui.Button): await members.handle_role_change( interaction.user, interaction.user.remove_roles if self.assigned else interaction.user.add_roles, - discord.Object(self.role.role_id), + disnake.Object(self.role.role_id), ) self.assigned = not self.assigned @@ -133,7 +133,7 @@ class SingleRoleButton(discord.ui.Button): self.label = self.LABEL_FORMAT.format(action="Remove" if self.assigned else "Add", role_name=self.role.name) try: await interaction.message.edit(view=self.view) - except discord.NotFound: + except disnake.NotFound: log.debug("Subscribe message for %s removed before buttons could be updated", interaction.user) self.view.stop() @@ -145,7 +145,7 @@ class Subscribe(commands.Cog): self.bot = bot self.init_task = scheduling.create_task(self.init_cog(), event_loop=self.bot.loop) self.assignable_roles: list[AssignableRole] = [] - self.guild: discord.Guild = None + self.guild: disnake.Guild = None async def init_cog(self) -> None: """Initialise the cog by resolving the role IDs in ASSIGNABLE_ROLES to role names.""" diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index f66237c8e..baeb21adb 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -6,10 +6,10 @@ import time from pathlib import Path from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union -import discord +import disnake import frontmatter -from discord import Embed, Member -from discord.ext.commands import Cog, Context, group +from disnake import Embed, Member +from disnake.ext.commands import Cog, Context, group from bot import constants from bot.bot import Bot @@ -81,7 +81,7 @@ class Tag: self.content = post.content self.metadata = post.metadata self._restricted_to: set[int] = set(self.metadata.get("restricted_to", ())) - self._cooldowns: dict[discord.TextChannel, float] = {} + self._cooldowns: dict[disnake.TextChannel, float] = {} @property def embed(self) -> Embed: @@ -90,18 +90,18 @@ class Tag: embed.description = self.content return embed - def accessible_by(self, member: discord.Member) -> bool: + def accessible_by(self, member: disnake.Member) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to or self._restricted_to & {role.id for role in member.roles} ) - def on_cooldown_in(self, channel: discord.TextChannel) -> bool: + def on_cooldown_in(self, channel: disnake.TextChannel) -> bool: """Check whether the tag is on cooldown in `channel`.""" return self._cooldowns.get(channel, float("-inf")) > time.time() - def set_cooldown_for(self, channel: discord.TextChannel) -> None: + def set_cooldown_for(self, channel: disnake.TextChannel) -> None: """Set the tag to be on cooldown in `channel` for `constants.Cooldowns.tags` seconds.""" self._cooldowns[channel] = time.time() + constants.Cooldowns.tags @@ -344,7 +344,7 @@ class Tags(Cog): return result_lines - def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: + def accessible_tags_in_group(self, group: str, user: disnake.Member) -> list[str]: """Return a formatted list of tags in `group`, that are accessible by `user`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index cb6836258..2e274b23b 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -7,10 +7,10 @@ from datetime import datetime from itertools import takewhile from typing import Callable, Iterable, Literal, Optional, TYPE_CHECKING, Union -from discord import Colour, Message, NotFound, TextChannel, User, errors -from discord.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role -from discord.ext.commands.converter import TextChannelConverter -from discord.ext.commands.errors import BadArgument +from disnake import Colour, Message, NotFound, TextChannel, User, errors +from disnake.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role +from disnake.ext.commands.converter import TextChannelConverter +from disnake.ext.commands.errors import BadArgument from bot.bot import Bot from bot.constants import Channels, CleanMessages, Colours, Emojis, Event, Icons, MODERATION_ROLES @@ -459,7 +459,7 @@ class Clean(Cog): regex: Optional[Regex] = None, bots_only: Optional[bool] = False, *, - channels: CleanChannels = None # "Optional" with discord.py silently ignores incorrect input. + channels: CleanChannels = None # "Optional" with disnake silently ignores incorrect input. ) -> None: """ Commands for cleaning messages in channels. diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 178be734d..58e049d4f 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -8,9 +8,9 @@ import arrow from aioredis import RedisError from async_rediscache import RedisCache from dateutil.relativedelta import relativedelta -from discord import Colour, Embed, Forbidden, Member, TextChannel, User -from discord.ext import tasks -from discord.ext.commands import Cog, Context, group, has_any_role +from disnake import Colour, Embed, Forbidden, Member, TextChannel, User +from disnake.ext import tasks +from disnake.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_ROLES, Roles diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 566422e29..28e131eb4 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -1,5 +1,5 @@ -import discord -from discord.ext.commands import Cog, Context, command, has_any_role +import disnake +from disnake.ext.commands import Cog, Context, command, has_any_role from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES @@ -17,7 +17,7 @@ class DMRelay(Cog): self.bot = bot @command(aliases=("relay", "dr")) - async def dmrelay(self, ctx: Context, user: discord.User, limit: int = 100) -> None: + async def dmrelay(self, ctx: Context, user: disnake.User, limit: int = 100) -> None: """Relays the direct message history between the bot and given user.""" log.trace(f"Relaying DMs with {user.name} ({user.id})") diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index b579416a6..c4c03e546 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -4,9 +4,9 @@ from datetime import datetime from enum import Enum from typing import Optional -import discord +import disnake from async_rediscache import RedisCache -from discord.ext.commands import Cog, Context, MessageConverter, MessageNotFound +from disnake.ext.commands import Cog, Context, MessageConverter, MessageNotFound from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks @@ -52,10 +52,10 @@ ALL_SIGNALS: set[str] = {signal.value for signal in Signal} # An embed coupled with an optional file to be dispatched # If the file is not None, the embed attempts to show it in its body -FileEmbed = tuple[discord.Embed, Optional[discord.File]] +FileEmbed = tuple[disnake.Embed, Optional[disnake.File]] -async def download_file(attachment: discord.Attachment) -> Optional[discord.File]: +async def download_file(attachment: disnake.Attachment) -> Optional[disnake.File]: """ Download & return `attachment` file. @@ -65,13 +65,13 @@ async def download_file(attachment: discord.Attachment) -> Optional[discord.File log.debug(f"Attempting to download attachment: {attachment.filename}") try: return await attachment.to_file() - except (discord.NotFound, discord.Forbidden) as exc: + except (disnake.NotFound, disnake.Forbidden) as exc: log.debug(f"Failed to download attachment: {exc}") except Exception: log.exception("Failed to download attachment") -async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: +async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: disnake.Member) -> FileEmbed: """ Create an embed representation of `incident` for the #incidents-archive channel. @@ -97,7 +97,7 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di colour = Colours.soft_red footer = f"Rejected by {actioned_by}" - embed = discord.Embed( + embed = disnake.Embed( description=incident.content, timestamp=datetime.utcnow(), colour=colour, @@ -113,12 +113,12 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di else: embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file else: - file = discord.utils.MISSING + file = disnake.utils.MISSING return embed, file -def is_incident(message: discord.Message) -> bool: +def is_incident(message: disnake.Message) -> bool: """True if `message` qualifies as an incident, False otherwise.""" conditions = ( message.channel.id == Channels.incidents, # Message sent in #incidents @@ -129,12 +129,12 @@ def is_incident(message: discord.Message) -> bool: return all(conditions) -def own_reactions(message: discord.Message) -> set[str]: +def own_reactions(message: disnake.Message) -> set[str]: """Get the set of reactions placed on `message` by the bot itself.""" return {str(reaction.emoji) for reaction in message.reactions if reaction.me} -def has_signals(message: discord.Message) -> bool: +def has_signals(message: disnake.Message) -> bool: """True if `message` already has all `Signal` reactions, False otherwise.""" return ALL_SIGNALS.issubset(own_reactions(message)) @@ -167,9 +167,9 @@ def shorten_text(text: str) -> str: return text -async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[discord.Embed]: +async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[disnake.Embed]: """ - Create an embedded representation of the discord message link contained in the incident report. + Create an embedded representation of the Discord message link contained in the incident report. The Embed would contain the following information --> Author: @Jason Terror ♦ (736234578745884682) @@ -179,23 +179,23 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d embed = None try: - message: discord.Message = await MessageConverter().convert(ctx, message_link) + message: disnake.Message = await MessageConverter().convert(ctx, message_link) except MessageNotFound: mod_logs_channel = ctx.bot.get_channel(Channels.mod_log) - last_100_logs: list[discord.Message] = await mod_logs_channel.history(limit=100).flatten() + last_100_logs: list[disnake.Message] = await mod_logs_channel.history(limit=100).flatten() for log_entry in last_100_logs: if not log_entry.embeds: continue - log_embed: discord.Embed = log_entry.embeds[0] + log_embed: disnake.Embed = log_entry.embeds[0] if ( log_embed.author.name == "Message deleted" and f"[Jump to message]({message_link})" in log_embed.description ): - embed = discord.Embed( - colour=discord.Colour.dark_gold(), + embed = disnake.Embed( + colour=disnake.Colour.dark_gold(), title="Deleted Message Link", description=( f"Found <#{Channels.mod_log}> entry for deleted message: " @@ -203,12 +203,12 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d ) ) if not embed: - embed = discord.Embed( - colour=discord.Colour.red(), + embed = disnake.Embed( + colour=disnake.Colour.red(), title="Bad Message Link", description=f"Message {message_link} not found." ) - except discord.DiscordException as e: + except disnake.DiscordException as e: log.exception(f"Failed to make message link embed for '{message_link}', raised exception: {e}") else: channel = message.channel @@ -219,12 +219,12 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d ) return - embed = discord.Embed( - colour=discord.Colour.gold(), + embed = disnake.Embed( + colour=disnake.Colour.gold(), description=( f"**Author:** {format_user(message.author)}\n" f"**Channel:** {channel.mention} ({channel.category}" - f"{f'/#{channel.parent.name} - ' if isinstance(channel, discord.Thread) else '/#'}" + f"{f'/#{channel.parent.name} - ' if isinstance(channel, disnake.Thread) else '/#'}" f"{channel.name})\n" ), timestamp=message.created_at @@ -242,7 +242,7 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d return embed -async def add_signals(incident: discord.Message) -> None: +async def add_signals(incident: disnake.Message) -> None: """ Add `Signal` member emoji to `incident` as reactions. @@ -257,7 +257,7 @@ async def add_signals(incident: discord.Message) -> None: log.trace(f"Adding reaction: {signal_emoji}") try: await incident.add_reaction(signal_emoji.value) - except discord.NotFound as e: + except disnake.NotFound as e: if e.code != 10008: raise @@ -300,7 +300,7 @@ class Incidents(Cog): """ # This dictionary maps an incident report message to the message link embed's ID - # RedisCache[discord.Message.id, discord.Message.id] + # RedisCache[disnake.Message.id, disnake.Message.id] message_link_embeds_cache = RedisCache() def __init__(self, bot: Bot) -> None: @@ -319,7 +319,7 @@ class Incidents(Cog): try: self.incidents_webhook = await self.bot.fetch_webhook(Webhooks.incidents) - except discord.HTTPException: + except disnake.HTTPException: log.error(f"Failed to fetch incidents webhook with id `{Webhooks.incidents}`.") async def crawl_incidents(self) -> None: @@ -335,7 +335,7 @@ class Incidents(Cog): Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. """ await self.bot.wait_until_guild_available() - incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) + incidents: disnake.TextChannel = self.bot.get_channel(Channels.incidents) log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") async for message in incidents.history(limit=CRAWL_LIMIT): @@ -353,7 +353,7 @@ class Incidents(Cog): log.debug("Crawl task finished!") - async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: + async def archive(self, incident: disnake.Message, outcome: Signal, actioned_by: disnake.Member) -> bool: """ Relay an embed representation of `incident` to the #incidents-archive channel. @@ -392,7 +392,7 @@ class Incidents(Cog): log.trace("Message archived successfully!") return True - def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: + def make_confirmation_task(self, incident: disnake.Message, timeout: int = 5) -> asyncio.Task: """ Create a task to wait `timeout` seconds for `incident` to be deleted. @@ -401,13 +401,13 @@ class Incidents(Cog): """ log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") - def check(payload: discord.RawReactionActionEvent) -> bool: + def check(payload: disnake.RawReactionActionEvent) -> bool: return payload.message_id == incident.id coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) return scheduling.create_task(coroutine, event_loop=self.bot.loop) - async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: + async def process_event(self, reaction: str, incident: disnake.Message, member: disnake.Member) -> None: """ Process a `reaction_add` event in #incidents. @@ -430,7 +430,7 @@ class Incidents(Cog): log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") try: await incident.remove_reaction(reaction, member) - except discord.NotFound: + except disnake.NotFound: log.trace("Couldn't remove reaction because the reaction or its message was deleted") return @@ -440,7 +440,7 @@ class Incidents(Cog): log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") try: await incident.remove_reaction(reaction, member) - except discord.NotFound: + except disnake.NotFound: log.trace("Couldn't remove reaction because the reaction or its message was deleted") return @@ -461,7 +461,7 @@ class Incidents(Cog): log.trace("Deleting original message") try: await incident.delete() - except discord.NotFound: + except disnake.NotFound: log.trace("Couldn't delete message because it was already deleted") log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") @@ -476,9 +476,9 @@ class Incidents(Cog): # Deletes the message link embeds found in cache from the channel and cache. await self.delete_msg_link_embed(incident.id) - async def resolve_message(self, message_id: int) -> Optional[discord.Message]: + async def resolve_message(self, message_id: int) -> Optional[disnake.Message]: """ - Get `discord.Message` for `message_id` from cache, or API. + Get `disnake.Message` for `message_id` from cache, or API. We first look into the local cache to see if the message is present. @@ -491,7 +491,7 @@ class Incidents(Cog): """ await self.bot.wait_until_guild_available() # First make sure that the cache is ready log.trace(f"Resolving message for: {message_id=}") - message: Optional[discord.Message] = self.bot._connection._get_message(message_id) + message: Optional[disnake.Message] = self.bot._connection._get_message(message_id) if message is not None: log.trace("Message was found in cache") @@ -500,7 +500,7 @@ class Incidents(Cog): log.trace("Message not found, attempting to fetch") try: message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) - except discord.NotFound: + except disnake.NotFound: log.trace("Message doesn't exist, it was likely already relayed") except Exception: log.exception(f"Failed to fetch message {message_id}!") @@ -509,7 +509,7 @@ class Incidents(Cog): return message @Cog.listener() - async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: + async def on_raw_reaction_add(self, payload: disnake.RawReactionActionEvent) -> None: """ Pre-process `payload` and pass it to `process_event` if appropriate. @@ -521,11 +521,11 @@ class Incidents(Cog): Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. - Once we have the lock, the `discord.Message` object for this event must be resolved. + Once we have the lock, the `disnake.Message` object for this event must be resolved. If the lock was previously held by an event which successfully relayed the incident, this will fail and we abort the current event. - Finally, with both the lock and the `discord.Message` instance in our hands, we delegate + Finally, with both the lock and the `disnake.Message` instance in our hands, we delegate to `process_event` to handle the event. The justification for using a raw listener is the need to receive events for messages @@ -554,7 +554,7 @@ class Incidents(Cog): log.trace("Releasing event lock") @Cog.listener() - async def on_message(self, message: discord.Message) -> None: + async def on_message(self, message: disnake.Message) -> None: """ Pass `message` to `add_signals` and `extract_message_links` if it satisfies `is_incident`. @@ -575,7 +575,7 @@ class Incidents(Cog): await self.send_message_link_embeds(embed_list, message, self.incidents_webhook) @Cog.listener() - async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent) -> None: + async def on_raw_message_delete(self, payload: disnake.RawMessageDeleteEvent) -> None: """ Delete message link embeds for `payload.message_id`. @@ -584,7 +584,7 @@ class Incidents(Cog): if self.incidents_webhook: await self.delete_msg_link_embed(payload.message_id) - async def extract_message_links(self, message: discord.Message) -> Optional[list[discord.Embed]]: + async def extract_message_links(self, message: disnake.Message) -> Optional[list[disnake.Embed]]: """ Check if there's any message links in the text content. @@ -615,8 +615,8 @@ class Incidents(Cog): async def send_message_link_embeds( self, webhook_embed_list: list, - message: discord.Message, - webhook: discord.Webhook, + message: disnake.Message, + webhook: disnake.Webhook, ) -> Optional[int]: """ Send message link embeds to #incidents channel. @@ -634,7 +634,7 @@ class Incidents(Cog): avatar_url=message.author.display_avatar.url, wait=True, ) - except discord.DiscordException: + except disnake.DiscordException: log.exception( f"Failed to send message link embed {message.id} to #incidents." ) @@ -651,7 +651,7 @@ class Incidents(Cog): if webhook_msg_id: try: await self.incidents_webhook.delete_message(webhook_msg_id) - except discord.errors.NotFound: + except disnake.errors.NotFound: log.trace(f"Incidents message link embed (`{webhook_msg_id}`) has already been deleted, skipping.") await self.message_link_embeds_cache.delete(message_id) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 47b639421..d51009358 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -5,8 +5,8 @@ from gettext import ngettext import arrow import dateutil.parser -import discord -from discord.ext.commands import Context +import disnake +from disnake.ext.commands import Context from bot import constants from bot.api import ResponseCodeError @@ -101,7 +101,7 @@ class InfractionScheduler: # Allowing mod log since this is a passive action that should be logged. try: await apply_coro - except discord.HTTPException as e: + except disnake.HTTPException as e: # When user joined and then right after this left again before action completed, this can't apply roles if e.code == 10007 or e.status == 404: log.info( @@ -203,7 +203,7 @@ class InfractionScheduler: if expiry: # Schedule the expiration of the infraction. self.schedule_expiration(infraction) - except discord.HTTPException as e: + except disnake.HTTPException as e: # Accordingly display that applying the infraction failed. # Don't use ctx.message.author; antispam only patches ctx.author. confirm_msg = ":x: failed to apply" @@ -212,7 +212,7 @@ class InfractionScheduler: log_title = "failed to apply" log_msg = f"Failed to apply {' '.join(infr_type.split('_'))} infraction #{id_} to {user}" - if isinstance(e, discord.Forbidden): + if isinstance(e, disnake.Forbidden): log.warning(f"{log_msg}: bot lacks permissions.") elif e.code == 10007 or e.status == 404: log.info( @@ -402,11 +402,11 @@ class InfractionScheduler: raise ValueError( f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) - except discord.Forbidden: + except disnake.Forbidden: log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention - except discord.HTTPException as e: + except disnake.HTTPException as e: if e.code == 10007 or e.status == 404: log.info( f"Can't pardon {infraction['type']} for user {infraction['user']} because user left the guild." diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 4df833ffb..a464f7c87 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -1,8 +1,8 @@ import typing as t from datetime import datetime -import discord -from discord.ext.commands import Context +import disnake +from disnake.ext.commands import Context from bot.api import ResponseCodeError from bot.bot import Bot @@ -83,7 +83,7 @@ async def post_infraction( dm_sent: bool = False, ) -> t.Optional[dict]: """Posts an infraction to the API.""" - if isinstance(user, (discord.Member, discord.User)) and user.bot: + if isinstance(user, (disnake.Member, disnake.User)) and user.bot: log.trace(f"Posting of {infr_type} infraction for {user} to the API aborted. User is a bot.") raise InvalidInfractedUserError(user) @@ -182,7 +182,7 @@ async def notify_infraction( text += INFRACTION_APPEAL_SERVER_FOOTER if infr_type.lower() == 'ban' else INFRACTION_APPEAL_MODMAIL_FOOTER - embed = discord.Embed( + embed = disnake.Embed( description=text, colour=Colours.soft_red ) @@ -211,7 +211,7 @@ async def notify_pardon( """DM a user about their pardoned infraction and return True if the DM is successful.""" log.trace(f"Sending {user} a DM about their pardoned infraction.") - embed = discord.Embed( + embed = disnake.Embed( description=content, colour=Colours.soft_green ) @@ -221,7 +221,7 @@ async def notify_pardon( return await send_private_embed(user, embed) -async def send_private_embed(user: MemberOrUser, embed: discord.Embed) -> bool: +async def send_private_embed(user: MemberOrUser, embed: disnake.Embed) -> bool: """ A helper method for sending an embed to a user's DMs. @@ -230,7 +230,7 @@ async def send_private_embed(user: MemberOrUser, embed: discord.Embed) -> bool: try: await user.send(embed=embed) return True - except (discord.HTTPException, discord.Forbidden, discord.NotFound): + except (disnake.HTTPException, disnake.Forbidden, disnake.NotFound): log.debug( f"Infraction-related information could not be sent to user {user} ({user.id}). " "The user either could not be retrieved or probably disabled their DMs." diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index af42ab1b8..5ff56abde 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -1,10 +1,10 @@ import textwrap import typing as t -import discord -from discord import Member -from discord.ext import commands -from discord.ext.commands import Context, command +import disnake +from disnake import Member +from disnake.ext import commands +from disnake.ext.commands import Context, command from bot import constants from bot.bot import Bot @@ -35,8 +35,8 @@ class Infractions(InfractionScheduler, commands.Cog): super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning", "voice_mute"}) self.category = "Moderation" - self._muted_role = discord.Object(constants.Roles.muted) - self._voice_verified_role = discord.Object(constants.Roles.voice_verified) + self._muted_role = disnake.Object(constants.Roles.muted) + self._voice_verified_role = disnake.Object(constants.Roles.voice_verified) @commands.Cog.listener() async def on_member_join(self, member: Member) -> None: @@ -123,7 +123,7 @@ class Infractions(InfractionScheduler, commands.Cog): log.error("Failed to apply ban to user %d", user.id) return - # Calling commands directly skips Discord.py's convertors, so we need to convert args manually. + # Calling commands directly skips disnake's convertors, so we need to convert args manually. clean_time = await Age().convert(ctx, "1h") log_url = await clean_cog._clean_messages( @@ -494,7 +494,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def pardon_mute( self, user_id: int, - guild: discord.Guild, + guild: disnake.Guild, reason: t.Optional[str], *, notify: bool = True @@ -525,16 +525,16 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + async def pardon_ban(self, user_id: int, guild: disnake.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: """Remove a user's ban on the Discord guild and return a log dict.""" - user = discord.Object(user_id) + user = disnake.Object(user_id) log_text = {} self.mod_log.ignore(Event.member_unban, user_id) try: await guild.unban(user, reason=reason) - except discord.NotFound: + except disnake.NotFound: log.info(f"Failed to unban user {user_id}: no active ban found on Discord") log_text["Note"] = "No active ban found on Discord." @@ -543,7 +543,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def pardon_voice_mute( self, user_id: int, - guild: discord.Guild, + guild: disnake.Guild, *, notify: bool = True ) -> t.Dict[str, str]: @@ -597,7 +597,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def cog_command_error(self, ctx: Context, error: Exception) -> None: """Send a notification to the invoking context on a Union failure.""" if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters or Member in error.converters: + if disnake.User in error.converters or Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index dda3fadae..875e8ef34 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -1,10 +1,10 @@ import textwrap import typing as t -import discord -from discord.ext import commands -from discord.ext.commands import Context -from discord.utils import escape_markdown +import disnake +from disnake.ext import commands +from disnake.ext.commands import Context +from disnake.utils import escape_markdown from bot import constants from bot.bot import Bot @@ -53,9 +53,9 @@ class ModManagement(commands.Cog): await ctx.send_help(ctx.command) return - embed = discord.Embed( + embed = disnake.Embed( title=f"Infraction #{infraction['id']}", - colour=discord.Colour.orange() + colour=disnake.Colour.orange() ) await self.send_infraction_list(ctx, embed, [infraction]) @@ -199,7 +199,7 @@ class ModManagement(commands.Cog): await self.mod_log.send_log_message( icon_url=constants.Icons.pencil, - colour=discord.Colour.og_blurple(), + colour=disnake.Colour.og_blurple(), title="Infraction edited", thumbnail=thumbnail, text=textwrap.dedent(f""" @@ -217,21 +217,21 @@ class ModManagement(commands.Cog): async def infraction_search_group(self, ctx: Context, query: t.Union[UnambiguousUser, Snowflake, str]) -> None: """Searches for infractions in the database.""" if isinstance(query, int): - await self.search_user(ctx, discord.Object(query)) + await self.search_user(ctx, disnake.Object(query)) elif isinstance(query, str): await self.search_reason(ctx, query) else: await self.search_user(ctx, query) @infraction_search_group.command(name="user", aliases=("member", "userid")) - async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, discord.Object]) -> None: + async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, disnake.Object]) -> None: """Search for infractions by member.""" infraction_list = await self.bot.api_client.get( 'bot/infractions/expanded', params={'user__id': str(user.id)} ) - if isinstance(user, (discord.Member, discord.User)): + if isinstance(user, (disnake.Member, disnake.User)): user_str = escape_markdown(str(user)) else: if infraction_list: @@ -241,9 +241,9 @@ class ModManagement(commands.Cog): user_str = str(user.id) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = discord.Embed( + embed = disnake.Embed( title=f"Infractions for {user_str} ({formatted_infraction_count} total)", - colour=discord.Colour.orange() + colour=disnake.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -256,9 +256,9 @@ class ModManagement(commands.Cog): ) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = discord.Embed( + embed = disnake.Embed( title=f"Infractions matching `{reason}` ({formatted_infraction_count} total)", - colour=discord.Colour.orange() + colour=disnake.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -296,9 +296,9 @@ class ModManagement(commands.Cog): ) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = discord.Embed( + embed = disnake.Embed( title=f"Infractions by {actor} ({formatted_infraction_count} total)", - colour=discord.Colour.orange() + colour=disnake.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -321,7 +321,7 @@ class ModManagement(commands.Cog): async def send_infraction_list( self, ctx: Context, - embed: discord.Embed, + embed: disnake.Embed, infractions: t.Iterable[t.Dict[str, t.Any]] ) -> None: """Send a paginated embed of infractions for the specified user.""" @@ -410,7 +410,7 @@ class ModManagement(commands.Cog): async def cog_command_error(self, ctx: Context, error: commands.CommandError) -> None: """Handles errors for commands within this cog.""" if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: + if disnake.User in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 3f1bffd76..1d357d441 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -4,9 +4,9 @@ import textwrap import typing as t from pathlib import Path -from discord import Embed, Member -from discord.ext.commands import Cog, Context, command, has_any_role -from discord.utils import escape_markdown +from disnake import Embed, Member +from disnake.ext.commands import Cog, Context, command, has_any_role +from disnake.utils import escape_markdown from bot import constants from bot.bot import Bot diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index ce9c220b3..482d49b83 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -8,7 +8,7 @@ import arrow from aiohttp.client_exceptions import ClientResponseError from arrow import Arrow from async_rediscache import RedisCache -from discord.ext.commands import Cog, Context, group, has_any_role +from disnake.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Metabase as MetabaseConfig, Roles diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index 32ea0dc6a..a96638e53 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -5,13 +5,13 @@ import typing as t from datetime import datetime, timezone from itertools import zip_longest -import discord +import disnake from dateutil.relativedelta import relativedelta from deepdiff import DeepDiff -from discord import Colour, Message, Thread -from discord.abc import GuildChannel -from discord.ext.commands import Cog, Context -from discord.utils import escape_markdown +from disnake import Colour, Message, Thread +from disnake.abc import GuildChannel +from disnake.ext.commands import Cog, Context +from disnake.utils import escape_markdown from bot.bot import Bot from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs @@ -21,7 +21,7 @@ from bot.utils.messages import format_user log = get_logger(__name__) -GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] +GUILD_CHANNEL = t.Union[disnake.CategoryChannel, disnake.TextChannel, disnake.VoiceChannel] CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") @@ -45,7 +45,7 @@ class ModLog(Cog, name="ModLog"): async def upload_log( self, - messages: t.Iterable[discord.Message], + messages: t.Iterable[disnake.Message], actor_id: int, attachments: t.Iterable[t.List[str]] = None ) -> str: @@ -83,22 +83,22 @@ class ModLog(Cog, name="ModLog"): async def send_log_message( self, icon_url: t.Optional[str], - colour: t.Union[discord.Colour, int], + colour: t.Union[disnake.Colour, int], title: t.Optional[str], text: str, - thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, + thumbnail: t.Optional[t.Union[str, disnake.Asset]] = None, channel_id: int = Channels.mod_log, ping_everyone: bool = False, - files: t.Optional[t.List[discord.File]] = None, + files: t.Optional[t.List[disnake.File]] = None, content: t.Optional[str] = None, - additional_embeds: t.Optional[t.List[discord.Embed]] = None, + additional_embeds: t.Optional[t.List[disnake.Embed]] = None, timestamp_override: t.Optional[datetime] = None, footer: t.Optional[str] = None, ) -> Context: """Generate log embed and send to logging channel.""" await self.bot.wait_until_guild_available() # Truncate string directly here to avoid removing newlines - embed = discord.Embed( + embed = disnake.Embed( description=text[:4093] + "..." if len(text) > 4096 else text ) @@ -143,10 +143,10 @@ class ModLog(Cog, name="ModLog"): if channel.guild.id != GuildConstant.id: return - if isinstance(channel, discord.CategoryChannel): + if isinstance(channel, disnake.CategoryChannel): title = "Category created" message = f"{channel.name} (`{channel.id}`)" - elif isinstance(channel, discord.VoiceChannel): + elif isinstance(channel, disnake.VoiceChannel): title = "Voice channel created" if channel.category: @@ -169,14 +169,14 @@ class ModLog(Cog, name="ModLog"): if channel.guild.id != GuildConstant.id: return - if isinstance(channel, discord.CategoryChannel): + if isinstance(channel, disnake.CategoryChannel): title = "Category deleted" - elif isinstance(channel, discord.VoiceChannel): + elif isinstance(channel, disnake.VoiceChannel): title = "Voice channel deleted" else: title = "Text channel deleted" - if channel.category and not isinstance(channel, discord.CategoryChannel): + if channel.category and not isinstance(channel, disnake.CategoryChannel): message = f"{channel.category}/{channel.name} (`{channel.id}`)" else: message = f"{channel.name} (`{channel.id}`)" @@ -256,7 +256,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_create(self, role: discord.Role) -> None: + async def on_guild_role_create(self, role: disnake.Role) -> None: """Log role create event to mod log.""" if role.guild.id != GuildConstant.id: return @@ -267,7 +267,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_delete(self, role: discord.Role) -> None: + async def on_guild_role_delete(self, role: disnake.Role) -> None: """Log role delete event to mod log.""" if role.guild.id != GuildConstant.id: return @@ -278,7 +278,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: + async def on_guild_role_update(self, before: disnake.Role, after: disnake.Role) -> None: """Log role update event to mod log.""" if before.guild.id != GuildConstant.id: return @@ -331,7 +331,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: + async def on_guild_update(self, before: disnake.Guild, after: disnake.Guild) -> None: """Log guild update event to mod log.""" if before.id != GuildConstant.id: return @@ -382,7 +382,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: + async def on_member_ban(self, guild: disnake.Guild, member: disnake.Member) -> None: """Log ban event to user log.""" if guild.id != GuildConstant.id: return @@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_join(self, member: discord.Member) -> None: + async def on_member_join(self, member: disnake.Member) -> None: """Log member join event to user log.""" if member.guild.id != GuildConstant.id: return @@ -420,7 +420,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_remove(self, member: discord.Member) -> None: + async def on_member_remove(self, member: disnake.Member) -> None: """Log member leave event to user log.""" if member.guild.id != GuildConstant.id: return @@ -437,7 +437,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: + async def on_member_unban(self, guild: disnake.Guild, member: disnake.User) -> None: """Log member unban event to mod log.""" if guild.id != GuildConstant.id: return @@ -454,7 +454,7 @@ class ModLog(Cog, name="ModLog"): ) @staticmethod - def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: + def get_role_diff(before: t.List[disnake.Role], after: t.List[disnake.Role]) -> t.List[str]: """Return a list of strings describing the roles added and removed.""" changes = [] before_roles = set(before) @@ -469,7 +469,7 @@ class ModLog(Cog, name="ModLog"): return changes @Cog.listener() - async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: + async def on_member_update(self, before: disnake.Member, after: disnake.Member) -> None: """Log member update event to user log.""" if before.guild.id != GuildConstant.id: return @@ -552,7 +552,7 @@ class ModLog(Cog, name="ModLog"): return channel.id in GuildConstant.modlog_blacklist - async def log_cached_deleted_message(self, message: discord.Message) -> None: + async def log_cached_deleted_message(self, message: disnake.Message) -> None: """ Log the message's details to message change log. @@ -608,7 +608,7 @@ class ModLog(Cog, name="ModLog"): channel_id=Channels.message_log ) - async def log_uncached_deleted_message(self, event: discord.RawMessageDeleteEvent) -> None: + async def log_uncached_deleted_message(self, event: disnake.RawMessageDeleteEvent) -> None: """ Log the message's details to message change log. @@ -648,7 +648,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: + async def on_raw_message_delete(self, event: disnake.RawMessageDeleteEvent) -> None: """Log message deletions to message change log.""" if event.cached_message is not None: await self.log_cached_deleted_message(event.cached_message) @@ -656,7 +656,7 @@ class ModLog(Cog, name="ModLog"): await self.log_uncached_deleted_message(event) @Cog.listener() - async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: + async def on_message_edit(self, msg_before: disnake.Message, msg_after: disnake.Message) -> None: """Log message edit event to message change log.""" if self.is_message_blacklisted(msg_before): return @@ -727,7 +727,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: + async def on_raw_message_edit(self, event: disnake.RawMessageUpdateEvent) -> None: """Log raw message edit event to message change log.""" if event.guild_id is None: return # ignore DM edits @@ -736,7 +736,7 @@ class ModLog(Cog, name="ModLog"): try: channel = self.bot.get_channel(int(event.data["channel_id"])) message = await channel.fetch_message(event.message_id) - except discord.NotFound: # Was deleted before we got the event + except disnake.NotFound: # Was deleted before we got the event return if self.is_message_blacklisted(message): @@ -860,9 +860,9 @@ class ModLog(Cog, name="ModLog"): @Cog.listener() async def on_voice_state_update( self, - member: discord.Member, - before: discord.VoiceState, - after: discord.VoiceState + member: disnake.Member, + before: disnake.VoiceState, + after: disnake.VoiceState ) -> None: """Log member voice state changes to the voice log channel.""" if ( diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index b5cd29b12..51d161d84 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -4,8 +4,8 @@ import datetime import arrow from async_rediscache import RedisCache from dateutil.parser import isoparse, parse as dateutil_parse -from discord import Embed, Member -from discord.ext.commands import Cog, Context, group, has_any_role +from disnake import Embed, Member +from disnake.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles @@ -22,12 +22,12 @@ MAXIMUM_WORK_LIMIT = 16 class ModPings(Cog): """Commands for a moderator to turn moderator pings on and off.""" - # RedisCache[discord.Member.id, 'Naïve ISO 8601 string'] + # RedisCache[disnake.Member.id, 'Naïve ISO 8601 string'] # The cache's keys are mods who have pings off. # The cache's values are the times when the role should be re-applied to them, stored in ISO format. pings_off_mods = RedisCache() - # RedisCache[discord.Member.id, 'start timestamp|total worktime in seconds'] + # RedisCache[disnake.Member.id, 'start timestamp|total worktime in seconds'] # The cache's keys are mod's ID # The cache's values are their pings on schedule timestamp and the total seconds (work time) until pings off modpings_schedule = RedisCache() diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 511520252..0b677dddb 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -5,10 +5,10 @@ from datetime import datetime, timedelta, timezone from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache -from discord import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel -from discord.ext import commands, tasks -from discord.ext.commands import Context -from discord.utils import MISSING +from disnake import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel +from disnake.ext import commands, tasks +from disnake.ext.commands import Context +from disnake.utils import MISSING from bot import constants from bot.bot import Bot diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py index b6a771441..7fcafc01c 100644 --- a/bot/exts/moderation/slowmode.py +++ b/bot/exts/moderation/slowmode.py @@ -1,8 +1,8 @@ from typing import Optional from dateutil.relativedelta import relativedelta -from discord import TextChannel -from discord.ext.commands import Cog, Context, group, has_any_role +from disnake import TextChannel +from disnake.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, Emojis, MODERATION_ROLES diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 985cc6eb1..7afd9f71d 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -2,10 +2,10 @@ from datetime import timedelta, timezone from operator import itemgetter import arrow -import discord +import disnake from arrow import Arrow from async_rediscache import RedisCache -from discord.ext import commands +from disnake.ext import commands from bot.bot import Bot from bot.constants import ( @@ -24,7 +24,7 @@ class Stream(commands.Cog): """Grant and revoke streaming permissions from members.""" # Stores tasks to remove streaming permission - # RedisCache[discord.Member.id, UtcPosixTimestamp] + # RedisCache[disnake.Member.id, UtcPosixTimestamp] task_cache = RedisCache() def __init__(self, bot: Bot): @@ -37,10 +37,10 @@ class Stream(commands.Cog): self.reload_task.cancel() self.reload_task.add_done_callback(lambda _: self.scheduler.cancel_all()) - async def _revoke_streaming_permission(self, member: discord.Member) -> None: + async def _revoke_streaming_permission(self, member: disnake.Member) -> None: """Remove the streaming permission from the given Member.""" await self.task_cache.delete(member.id) - await member.remove_roles(discord.Object(Roles.video), reason="Streaming access revoked") + await member.remove_roles(disnake.Object(Roles.video), reason="Streaming access revoked") async def _reload_tasks_from_redis(self) -> None: """Reload outstanding tasks from redis on startup, delete the task if the member has since left the server.""" @@ -66,7 +66,7 @@ class Stream(commands.Cog): self._revoke_streaming_permission(member) ) - async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None: + async def _suspend_stream(self, ctx: commands.Context, member: disnake.Member) -> None: """Suspend a member's stream.""" await self.bot.wait_until_guild_available() voice_state = member.voice @@ -90,7 +90,7 @@ class Stream(commands.Cog): @commands.command(aliases=("streaming",)) @commands.has_any_role(*MODERATION_ROLES) - async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None: + async def stream(self, ctx: commands.Context, member: disnake.Member, duration: Expiry = None) -> None: """ Temporarily grant streaming permissions to a member for a given duration. @@ -128,7 +128,7 @@ class Stream(commands.Cog): self.scheduler.schedule_at(duration, member.id, self._revoke_streaming_permission(member)) await self.task_cache.set(member.id, duration.timestamp()) - await member.add_roles(discord.Object(Roles.video), reason="Temporary streaming access granted") + await member.add_roles(disnake.Object(Roles.video), reason="Temporary streaming access granted") await ctx.send(f"{Emojis.check_mark} {member.mention} can now stream until {time.discord_timestamp(duration)}.") @@ -142,7 +142,7 @@ class Stream(commands.Cog): @commands.command(aliases=("pstream",)) @commands.has_any_role(*MODERATION_ROLES) - async def permanentstream(self, ctx: commands.Context, member: discord.Member) -> None: + async def permanentstream(self, ctx: commands.Context, member: disnake.Member) -> None: """Permanently grants the given member the permission to stream.""" log.trace(f"Attempting to give permanent streaming permission to {member} ({member.id}).") @@ -163,13 +163,13 @@ class Stream(commands.Cog): log.debug(f"{member} ({member.id}) already had permanent streaming permission.") return - await member.add_roles(discord.Object(Roles.video), reason="Permanent streaming access granted") + await member.add_roles(disnake.Object(Roles.video), reason="Permanent streaming access granted") await ctx.send(f"{Emojis.check_mark} Permanently granted {member.mention} the permission to stream.") log.debug(f"Successfully gave {member} ({member.id}) permanent streaming permission.") @commands.command(aliases=("unstream", "rstream")) @commands.has_any_role(*MODERATION_ROLES) - async def revokestream(self, ctx: commands.Context, member: discord.Member) -> None: + async def revokestream(self, ctx: commands.Context, member: disnake.Member) -> None: """Revoke the permission to stream from the given member.""" log.trace(f"Attempting to remove streaming permission from {member} ({member.id}).") @@ -222,7 +222,7 @@ class Stream(commands.Cog): # Only output the message in the pagination lines = [line[1] for line in streamer_info] - embed = discord.Embed( + embed = disnake.Embed( title=f"Members with streaming permission (`{len(lines)}` total)", colour=Colours.soft_green ) diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index 37338d19c..c958aa160 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -1,7 +1,7 @@ import typing as t -import discord -from discord.ext.commands import Cog, Context, command, has_any_role +import disnake +from disnake.ext.commands import Cog, Context, command, has_any_role from bot import constants from bot.bot import Bot @@ -51,7 +51,7 @@ async def safe_dm(coro: t.Coroutine) -> None: """ try: await coro - except discord.HTTPException as discord_exc: + except disnake.HTTPException as discord_exc: log.trace(f"DM dispatch failed on status {discord_exc.status} with code: {discord_exc.code}") if discord_exc.code != 50_007: # If any reason other than disabled DMs raise @@ -72,7 +72,7 @@ class Verification(Cog): # region: listeners @Cog.listener() - async def on_member_join(self, member: discord.Member) -> None: + async def on_member_join(self, member: disnake.Member) -> None: """Attempt to send initial direct message to each new member.""" if member.guild.id != constants.Guild.id: return # Only listen for PyDis events @@ -87,11 +87,11 @@ class Verification(Cog): log.trace(f"Sending on join message to new member: {member.id}") try: await safe_dm(member.send(ON_JOIN_MESSAGE)) - except discord.HTTPException: + except disnake.HTTPException: log.exception("DM dispatch failed on unexpected error code") @Cog.listener() - async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: + async def on_member_update(self, before: disnake.Member, after: disnake.Member) -> None: """Check if we need to send a verification DM to a gated user.""" if before.pending is True and after.pending is False: try: @@ -100,7 +100,7 @@ class Verification(Cog): # our alternate welcome DM which includes info such as our welcome # video. await safe_dm(after.send(VERIFIED_MESSAGE)) - except discord.HTTPException: + except disnake.HTTPException: log.exception("DM dispatch failed on unexpected error code") # endregion @@ -108,7 +108,7 @@ class Verification(Cog): @command(name='verify') @has_any_role(*constants.MODERATION_ROLES) - async def perform_manual_verification(self, ctx: Context, user: discord.Member) -> None: + async def perform_manual_verification(self, ctx: Context, user: disnake.Member) -> None: """Command for moderators to verify any user.""" log.trace(f'verify command called by {ctx.author} for {user.id}.') diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index fa66b00dd..24ae86bdd 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -3,10 +3,10 @@ from contextlib import suppress from datetime import timedelta import arrow -import discord +import disnake from async_rediscache import RedisCache -from discord import Colour, Member, VoiceState -from discord.ext.commands import Cog, Context, command +from disnake import Colour, Member, VoiceState +from disnake.ext.commands import Cog, Context, command from bot.api import ResponseCodeError from bot.bot import Bot @@ -51,7 +51,7 @@ VOICE_PING_DM = ( class VoiceGate(Cog): """Voice channels verification management.""" - # RedisCache[t.Union[discord.User.id, discord.Member.id], t.Union[discord.Message.id, int]] + # RedisCache[t.Union[disnake.User.id, disnake.Member.id], t.Union[disnake.Message.id, int]] # The cache's keys are the IDs of members who are verified or have joined a voice channel # The cache's values are either the message ID of the ping message or 0 (NO_MSG) if no message is present redis_cache = RedisCache() @@ -75,14 +75,14 @@ class VoiceGate(Cog): """ if message_id := await self.redis_cache.get(member_id): log.trace(f"Removing voice gate reminder message for user: {member_id}") - with suppress(discord.NotFound): + with suppress(disnake.NotFound): await self.bot.http.delete_message(Channels.voice_gate, message_id) await self.redis_cache.set(member_id, NO_MSG) else: log.trace(f"Voice gate reminder message for user {member_id} was already removed") @redis_cache.atomic_transaction - async def _ping_newcomer(self, member: discord.Member) -> tuple: + async def _ping_newcomer(self, member: disnake.Member) -> tuple: """ See if `member` should be sent a voice verification notification, and send it if so. @@ -91,7 +91,7 @@ class VoiceGate(Cog): * The `member` is already voice-verified Otherwise, the notification message ID is stored in `redis_cache` and return (True, channel). - channel is either [discord.TextChannel, discord.DMChannel]. + channel is either [disnake.TextChannel, disnake.DMChannel]. """ if await self.redis_cache.contains(member.id): log.trace("User already in cache. Ignore.") @@ -111,7 +111,7 @@ class VoiceGate(Cog): try: message = await member.send(VOICE_PING_DM.format(channel_mention=voice_verification_channel.mention)) - except discord.Forbidden: + except disnake.Forbidden: log.trace("DM failed for Voice ping message. Sending in channel.") message = await voice_verification_channel.send(f"Hello, {member.mention}! {VOICE_PING}") @@ -137,7 +137,7 @@ class VoiceGate(Cog): data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data") except ResponseCodeError as e: if e.status == 404: - embed = discord.Embed( + embed = disnake.Embed( title="Not found", description=( "We were unable to find user data for you. " @@ -148,7 +148,7 @@ class VoiceGate(Cog): ) log.info(f"Unable to find Metricity data about {ctx.author} ({ctx.author.id})") else: - embed = discord.Embed( + embed = disnake.Embed( title="Unexpected response", description=( "We encountered an error while attempting to find data for your user. " @@ -159,7 +159,7 @@ class VoiceGate(Cog): log.warning(f"Got response code {e.status} while trying to get {ctx.author.id} Metricity data.") try: await ctx.author.send(embed=embed) - except discord.Forbidden: + except disnake.Forbidden: log.info("Could not send user DM. Sending in voice-verify channel and scheduling delete.") await ctx.send(embed=embed) @@ -179,7 +179,7 @@ class VoiceGate(Cog): [self.bot.stats.incr(f"voice_gate.failed.{key}") for key, value in checks.items() if value is True] if failed: - embed = discord.Embed( + embed = disnake.Embed( title="Voice Gate failed", description=FAILED_MESSAGE.format(reasons="\n".join(f'• You {reason}.' for reason in failed_reasons)), color=Colour.red() @@ -187,12 +187,12 @@ class VoiceGate(Cog): try: await ctx.author.send(embed=embed) await ctx.send(f"{ctx.author}, please check your DMs.") - except discord.Forbidden: + except disnake.Forbidden: await ctx.channel.send(ctx.author.mention, embed=embed) return self.mod_log.ignore(Event.member_update, ctx.author.id) - embed = discord.Embed( + embed = disnake.Embed( title="Voice gate passed", description="You have been granted permission to use voice channels in Python Discord.", color=Colour.green() @@ -204,17 +204,17 @@ class VoiceGate(Cog): try: await ctx.author.send(embed=embed) await ctx.send(f"{ctx.author}, please check your DMs.") - except discord.Forbidden: + except disnake.Forbidden: await ctx.channel.send(ctx.author.mention, embed=embed) # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it. await asyncio.sleep(3) - await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed") + await ctx.author.add_roles(disnake.Object(Roles.voice_verified), reason="Voice Gate passed") self.bot.stats.incr("voice_gate.passed") @Cog.listener() - async def on_message(self, message: discord.Message) -> None: + async def on_message(self, message: disnake.Message) -> None: """Delete all non-staff messages from voice gate channel that don't invoke voice verify command.""" # Check is channel voice gate if message.channel.id != Channels.voice_gate: @@ -229,7 +229,7 @@ class VoiceGate(Cog): if message.content.endswith(VOICE_PING): log.trace("Message is the voice verification ping. Ignore.") return - with suppress(discord.NotFound): + with suppress(disnake.NotFound): await message.delete(delay=GateConf.bot_message_delete_delay) return @@ -242,7 +242,7 @@ class VoiceGate(Cog): if ctx.command is not None and ctx.command.name == "voice_verify": self.mod_log.ignore(Event.message_delete, message.id) - with suppress(discord.NotFound): + with suppress(disnake.NotFound): await message.delete() @Cog.listener() @@ -257,7 +257,7 @@ class VoiceGate(Cog): log.trace("User not in a voice channel. Ignore.") return - if isinstance(after.channel, discord.StageChannel): + if isinstance(after.channel, disnake.StageChannel): log.trace("User joined a stage channel. Ignore.") return @@ -267,7 +267,7 @@ class VoiceGate(Cog): # Schedule the channel ping notification to be deleted after the configured delay, which is # again delegated to an atomic helper - if notification_sent and isinstance(message_channel, discord.TextChannel): + if notification_sent and isinstance(message_channel, disnake.TextChannel): await asyncio.sleep(GateConf.voice_ping_delete_delay) await self._delete_ping(member.id) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index ee9b6ba45..88669ccaa 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -6,9 +6,9 @@ from collections import defaultdict, deque from dataclasses import dataclass from typing import Any, Dict, Optional -import discord -from discord import Color, DMChannel, Embed, HTTPException, Message, errors -from discord.ext.commands import Cog, Context +import disnake +from disnake import Color, DMChannel, Embed, HTTPException, Message, errors +from disnake.ext.commands import Cog, Context from bot.api import ResponseCodeError from bot.bot import Bot @@ -104,7 +104,7 @@ class WatchChannel(metaclass=CogABCMeta): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except discord.HTTPException: + except disnake.HTTPException: self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") if self.channel is None or self.webhook is None: @@ -217,7 +217,7 @@ class WatchChannel(metaclass=CogABCMeta): username = messages.sub_clyde(username) try: await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except discord.HTTPException as exc: + except disnake.HTTPException as exc: self.log.exception( "Failed to send a message to the webhook", exc_info=exc @@ -265,7 +265,7 @@ class WatchChannel(metaclass=CogABCMeta): username=msg.author.display_name, avatar_url=msg.author.display_avatar.url ) - except discord.HTTPException as exc: + except disnake.HTTPException as exc: self.log.exception( "Failed to send an attachment to the webhook", exc_info=exc diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py index ab37b1b80..b0a48ceff 100644 --- a/bot/exts/moderation/watchchannels/bigbrother.py +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -1,7 +1,7 @@ import textwrap from collections import ChainMap -from discord.ext.commands import Cog, Context, group, has_any_role +from disnake.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Webhooks @@ -94,7 +94,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): await ctx.send(f":x: {user.mention} is already being watched.") return - # discord.User instances don't have a roles attribute + # disnake.User instances don't have a roles attribute if hasattr(user, "roles") and any(role.id in MODERATION_ROLES for role in user.roles): await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I must be kind to my masters.") return diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 0554bf37a..3d784ef77 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -3,10 +3,10 @@ from collections import ChainMap, defaultdict from io import StringIO from typing import Optional, Union -import discord +import disnake from async_rediscache import RedisCache -from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User -from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role +from disnake import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User +from disnake.ext.commands import BadArgument, Cog, Context, group, has_any_role from bot.api import ResponseCodeError from bot.bot import Bot @@ -483,7 +483,7 @@ class TalentPool(Cog, name="Talentpool"): async def get_review(self, ctx: Context, user_id: int) -> None: """Get the user's review as a markdown file.""" review, _, _ = await self.reviewer.make_review(user_id) - file = discord.File(StringIO(review), f"{user_id}_review.md") + file = disnake.File(StringIO(review), f"{user_id}_review.md") await ctx.send(file=file) @nomination_group.command(aliases=('review',)) diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index b4d177622..d496d0eb2 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -10,8 +10,8 @@ from typing import List, Optional, Union import arrow from dateutil.parser import isoparse -from discord import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel -from discord.ext.commands import Context +from disnake import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel +from disnake.ext.commands import Context from bot.api import ResponseCodeError from bot.bot import Bot diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py index 8f0094bc9..7d18c0ed3 100644 --- a/bot/exts/utils/bot.py +++ b/bot/exts/utils/bot.py @@ -1,7 +1,7 @@ from typing import Optional -from discord import Embed, TextChannel -from discord.ext.commands import Cog, Context, command, group, has_any_role +from disnake import Embed, TextChannel +from disnake.ext.commands import Cog, Context, command, group, has_any_role from bot.bot import Bot from bot.constants import Guild, MODERATION_ROLES, URLs diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index fda1e49e2..3d12ae848 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -2,9 +2,9 @@ import functools import typing as t from enum import Enum -from discord import Colour, Embed -from discord.ext import commands -from discord.ext.commands import Context, group +from disnake import Colour, Embed +from disnake.ext import commands +from disnake.ext.commands import Context, group from bot import exts from bot.bot import Bot diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index e7113c09c..28c1867ad 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -9,8 +9,8 @@ from io import StringIO from typing import Any, Optional, Tuple import arrow -import discord -from discord.ext.commands import Cog, Context, group, has_any_role, is_owner +import disnake +from disnake.ext.commands import Cog, Context, group, has_any_role, is_owner from bot.bot import Bot from bot.constants import DEBUG_MODE, Roles @@ -42,7 +42,7 @@ class Internal(Cog): self.socket_event_total += 1 self.socket_events[event_type] += 1 - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[disnake.Embed]]: """Format the eval output into a string & attempt to format it into an Embed.""" self._ = out @@ -103,7 +103,7 @@ class Internal(Cog): res += f"Out[{self.ln}]: " - if isinstance(out, discord.Embed): + if isinstance(out, disnake.Embed): # We made an embed? Send that as embed res += "" res = (res, out) @@ -136,7 +136,7 @@ class Internal(Cog): return res # Return (text, embed) - async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: + async def _eval(self, ctx: Context, code: str) -> Optional[disnake.Message]: """Eval the input code string & send an embed to the invoking context.""" self.ln += 1 @@ -154,7 +154,8 @@ class Internal(Cog): "self": self, "bot": self.bot, "inspect": inspect, - "discord": discord, + "discord": disnake, + "disnake": disnake, "contextlib": contextlib } @@ -240,10 +241,10 @@ async def func(): # (None,) -> Any per_s = self.socket_event_total / running_s - stats_embed = discord.Embed( + stats_embed = disnake.Embed( title="WebSocket statistics", description=f"Receiving {per_s:0.2f} events per second.", - color=discord.Color.og_blurple() + color=disnake.Color.og_blurple() ) for event_type, count in self.socket_events.most_common(25): diff --git a/bot/exts/utils/ping.py b/bot/exts/utils/ping.py index 9fb5b7b8f..eeb1d5ff5 100644 --- a/bot/exts/utils/ping.py +++ b/bot/exts/utils/ping.py @@ -1,7 +1,7 @@ import arrow from aiohttp import client_exceptions -from discord import Embed -from discord.ext import commands +from disnake import Embed +from disnake.ext import commands from bot.bot import Bot from bot.constants import Channels, STAFF_PARTNERS_COMMUNITY_ROLES, URLs diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index ad82d49c9..bf0e9d2ac 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -4,9 +4,9 @@ import typing as t from datetime import datetime, timezone from operator import itemgetter -import discord +import disnake from dateutil.parser import isoparse -from discord.ext.commands import Cog, Context, Greedy, group +from disnake.ext.commands import Cog, Context, Greedy, group from bot.bot import Bot from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES @@ -26,8 +26,8 @@ LOCK_NAMESPACE = "reminder" WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 -Mentionable = t.Union[discord.Member, discord.Role] -ReminderMention = t.Union[UnambiguousUser, discord.Role] +Mentionable = t.Union[disnake.Member, disnake.Role] +ReminderMention = t.Union[UnambiguousUser, disnake.Role] class Reminders(Cog): @@ -66,7 +66,7 @@ class Reminders(Cog): else: self.schedule_reminder(reminder) - def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.TextChannel]: + def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, disnake.TextChannel]: """Ensure reminder channel can be fetched otherwise delete the reminder.""" channel = self.bot.get_channel(reminder['channel_id']) is_valid = True @@ -87,9 +87,9 @@ class Reminders(Cog): reminder_id: t.Union[str, int] ) -> None: """Send an embed confirming the reminder change was made successfully.""" - embed = discord.Embed( + embed = disnake.Embed( description=on_success, - colour=discord.Colour.green(), + colour=disnake.Colour.green(), title=random.choice(POSITIVE_REPLIES) ) @@ -113,7 +113,7 @@ class Reminders(Cog): if await has_no_roles_check(ctx, *STAFF_PARTNERS_COMMUNITY_ROLES): return False, "members/roles" elif await has_no_roles_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, (discord.User, discord.Member)) for mention in mentions), "roles" + return all(isinstance(mention, (disnake.User, disnake.Member)) for mention in mentions), "roles" else: return True, "" @@ -173,15 +173,15 @@ class Reminders(Cog): if not is_valid: # No need to cancel the task too; it'll simply be done once this coroutine returns. return - embed = discord.Embed() + embed = disnake.Embed() if expected_time: - embed.colour = discord.Colour.red() + embed.colour = disnake.Colour.red() embed.set_author( icon_url=Icons.remind_red, name="Sorry, your reminder should have arrived earlier!" ) else: - embed.colour = discord.Colour.og_blurple() + embed.colour = disnake.Colour.og_blurple() embed.set_author( icon_url=Icons.remind_blurple, name="It has arrived!" @@ -200,7 +200,7 @@ class Reminders(Cog): partial_message = channel.get_partial_message(int(jump_url.split("/")[-1])) try: await partial_message.reply(content=f"{additional_mentions}", embed=embed) - except discord.HTTPException as e: + except disnake.HTTPException as e: log.info( f"There was an error when trying to reply to a reminder invocation message, {e}, " "fall back to using jump_url" @@ -284,7 +284,7 @@ class Reminders(Cog): # If `content` isn't provided then we try to get message content of a replied message if not content: if reference := ctx.message.reference: - if isinstance((resolved_message := reference.resolved), discord.Message): + if isinstance((resolved_message := reference.resolved), disnake.Message): content = resolved_message.content # If we weren't able to get the content of a replied message if content is None: @@ -361,8 +361,8 @@ class Reminders(Cog): lines.append(text) - embed = discord.Embed() - embed.colour = discord.Colour.og_blurple() + embed = disnake.Embed() + embed.colour = disnake.Colour.og_blurple() embed.title = f"Reminders for {ctx.author}" # Remind the user that they have no reminders :^) @@ -372,7 +372,7 @@ class Reminders(Cog): return # Construct the embed and paginate it. - embed.colour = discord.Colour.og_blurple() + embed.colour = disnake.Colour.og_blurple() await LinePaginator.paginate( lines, diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index cc3a2e1d7..07d824f87 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -8,8 +8,8 @@ from signal import Signals from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX -from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only +from disnake import AllowedMentions, HTTPException, Message, NotFound, Reaction, User +from disnake.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs diff --git a/bot/exts/utils/thread_bumper.py b/bot/exts/utils/thread_bumper.py index 35057f1fe..d37b3b51c 100644 --- a/bot/exts/utils/thread_bumper.py +++ b/bot/exts/utils/thread_bumper.py @@ -1,8 +1,8 @@ import typing as t -import discord +import disnake from async_rediscache import RedisCache -from discord.ext import commands +from disnake.ext import commands from bot import constants from bot.bot import Bot @@ -16,14 +16,14 @@ log = get_logger(__name__) class ThreadBumper(commands.Cog): """Cog that allow users to add the current thread to a list that get reopened on archive.""" - # RedisCache[discord.Thread.id, "sentinel"] + # RedisCache[disnake.Thread.id, "sentinel"] threads_to_bump = RedisCache() def __init__(self, bot: Bot): self.bot = bot self.init_task = scheduling.create_task(self.ensure_bumped_threads_are_active(), event_loop=self.bot.loop) - async def unarchive_threads_not_manually_archived(self, threads: list[discord.Thread]) -> None: + async def unarchive_threads_not_manually_archived(self, threads: list[disnake.Thread]) -> None: """ Iterate through and unarchive any threads that weren't manually archived recently. @@ -35,7 +35,7 @@ class ThreadBumper(commands.Cog): guild = self.bot.get_guild(constants.Guild.id) recent_manually_archived_thread_ids = [] - async for thread_update in guild.audit_logs(limit=200, action=discord.AuditLogAction.thread_update): + async for thread_update in guild.audit_logs(limit=200, action=disnake.AuditLogAction.thread_update): if getattr(thread_update.after, "archived", False): recent_manually_archived_thread_ids.append(thread_update.target.id) @@ -58,7 +58,7 @@ class ThreadBumper(commands.Cog): for thread_id, _ in await self.threads_to_bump.items(): try: thread = await channel.get_or_fetch_channel(thread_id) - except discord.NotFound: + except disnake.NotFound: log.info("Thread %d has been deleted, removing from bumped threads.", thread_id) await self.threads_to_bump.delete(thread_id) continue @@ -75,12 +75,12 @@ class ThreadBumper(commands.Cog): await ctx.send_help(ctx.command) @thread_bump_group.command(name="add", aliases=("a",)) - async def add_thread_to_bump_list(self, ctx: commands.Context, thread: t.Optional[discord.Thread]) -> None: + async def add_thread_to_bump_list(self, ctx: commands.Context, thread: t.Optional[disnake.Thread]) -> None: """Add a thread to the bump list.""" await self.init_task if not thread: - if isinstance(ctx.channel, discord.Thread): + if isinstance(ctx.channel, disnake.Thread): thread = ctx.channel else: raise commands.BadArgument("You must provide a thread, or run this command within a thread.") @@ -92,12 +92,12 @@ class ThreadBumper(commands.Cog): await ctx.send(f":ok_hand:{thread.mention} has been added to the bump list.") @thread_bump_group.command(name="remove", aliases=("r", "rem", "d", "del", "delete")) - async def remove_thread_from_bump_list(self, ctx: commands.Context, thread: t.Optional[discord.Thread]) -> None: + async def remove_thread_from_bump_list(self, ctx: commands.Context, thread: t.Optional[disnake.Thread]) -> None: """Remove a thread from the bump list.""" await self.init_task if not thread: - if isinstance(ctx.channel, discord.Thread): + if isinstance(ctx.channel, disnake.Thread): thread = ctx.channel else: raise commands.BadArgument("You must provide a thread, or run this command within a thread.") @@ -114,14 +114,14 @@ class ThreadBumper(commands.Cog): await self.init_task lines = [f"<#{k}>" for k, _ in await self.threads_to_bump.items()] - embed = discord.Embed( + embed = disnake.Embed( title="Threads in the bump list", colour=constants.Colours.blue ) await LinePaginator.paginate(lines, ctx, embed) @commands.Cog.listener() - async def on_thread_update(self, _: discord.Thread, after: discord.Thread) -> None: + async def on_thread_update(self, _: disnake.Thread, after: disnake.Thread) -> None: """ Listen for thread updates and check if the thread has been archived. diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index 2a074788e..77be3315c 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -3,9 +3,9 @@ import re import unicodedata from typing import Tuple, Union -from discord import Colour, Embed, utils -from discord.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role -from discord.utils import snowflake_time +from disnake import Colour, Embed, utils +from disnake.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role +from disnake.utils import snowflake_time from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES diff --git a/bot/log.py b/bot/log.py index 100cd06f6..0b1d1aca6 100644 --- a/bot/log.py +++ b/bot/log.py @@ -74,7 +74,7 @@ def setup() -> None: coloredlogs.install(level=TRACE_LEVEL, logger=root_log, stream=sys.stdout) root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO) - get_logger("discord").setLevel(logging.WARNING) + get_logger("disnake").setLevel(logging.WARNING) get_logger("websockets").setLevel(logging.WARNING) get_logger("chardet").setLevel(logging.WARNING) get_logger("async_rediscache").setLevel(logging.WARNING) diff --git a/bot/monkey_patches.py b/bot/monkey_patches.py index 4840fa454..590be22a2 100644 --- a/bot/monkey_patches.py +++ b/bot/monkey_patches.py @@ -2,8 +2,8 @@ import re from datetime import timedelta import arrow -from discord import Forbidden, http -from discord.ext import commands +from disnake import Forbidden, http +from disnake.ext import commands from bot.log import get_logger @@ -13,7 +13,7 @@ MESSAGE_ID_RE = re.compile(r'(?P[0-9]{15,20})$') class Command(commands.Command): """ - A `discord.ext.commands.Command` subclass which supports root aliases. + A `disnake.ext.commands.Command` subclass which supports root aliases. A `root_aliases` keyword argument is added, which is a sequence of alias names that will act as top-level commands rather than being aliases of the command's group. It's stored as an attribute diff --git a/bot/pagination.py b/bot/pagination.py index 8f4353eb1..1a014daa1 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -3,9 +3,9 @@ import typing as t from contextlib import suppress from functools import partial -import discord -from discord.abc import User -from discord.ext.commands import Context, Paginator +import disnake +from disnake.abc import User +from disnake.ext.commands import Context, Paginator from bot import constants from bot.log import get_logger @@ -55,7 +55,7 @@ class LinePaginator(Paginator): linesep: str = "\n" ) -> None: """ - This function overrides the Paginator.__init__ from inside discord.ext.commands. + This function overrides the Paginator.__init__ from inside disnake.ext.commands. It overrides in order to allow us to configure the maximum number of lines per page. """ @@ -99,7 +99,7 @@ class LinePaginator(Paginator): effort to avoid breaking up single lines across pages, while keeping the total length of the page at a reasonable size. - This function overrides the `Paginator.add_line` from inside `discord.ext.commands`. + This function overrides the `Paginator.add_line` from inside `disnake.ext.commands`. It overrides in order to allow us to configure the maximum number of lines per page. """ @@ -192,7 +192,7 @@ class LinePaginator(Paginator): cls, lines: t.List[str], ctx: Context, - embed: discord.Embed, + embed: disnake.Embed, prefix: str = "", suffix: str = "", max_lines: t.Optional[int] = None, @@ -204,7 +204,7 @@ class LinePaginator(Paginator): footer_text: str = None, url: str = None, exception_on_empty_embed: bool = False, - ) -> t.Optional[discord.Message]: + ) -> t.Optional[disnake.Message]: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -219,7 +219,7 @@ class LinePaginator(Paginator): to any user with a moderation role. Example: - >>> embed = discord.Embed() + >>> embed = disnake.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) >>> await LinePaginator.paginate([line for line in lines], ctx, embed) """ @@ -367,5 +367,5 @@ class LinePaginator(Paginator): await message.edit(embed=embed) log.debug("Ending pagination and clearing reactions.") - with suppress(discord.NotFound): + with suppress(disnake.NotFound): await message.clear_reactions() diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 8903c385c..9c890e569 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/burst.py b/bot/rules/burst.py index 25c5a2f33..a943cfdeb 100644 --- a/bot/rules/burst.py +++ b/bot/rules/burst.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/burst_shared.py b/bot/rules/burst_shared.py index bbe9271b3..dee857e18 100644 --- a/bot/rules/burst_shared.py +++ b/bot/rules/burst_shared.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/chars.py b/bot/rules/chars.py index 1f587422c..6d2f6eb83 100644 --- a/bot/rules/chars.py +++ b/bot/rules/chars.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/discord_emojis.py b/bot/rules/discord_emojis.py index d979ac5e7..4fe4e88f9 100644 --- a/bot/rules/discord_emojis.py +++ b/bot/rules/discord_emojis.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message from emoji import demojize DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:") diff --git a/bot/rules/duplicates.py b/bot/rules/duplicates.py index 8e4fbc12d..77e393db0 100644 --- a/bot/rules/duplicates.py +++ b/bot/rules/duplicates.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/links.py b/bot/rules/links.py index c46b783c5..92c13b3f4 100644 --- a/bot/rules/links.py +++ b/bot/rules/links.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message LINK_RE = re.compile(r"(https?://[^\s]+)") diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 6f5addad1..7ee66be31 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/newlines.py b/bot/rules/newlines.py index 4e66e1359..45266648e 100644 --- a/bot/rules/newlines.py +++ b/bot/rules/newlines.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/rules/role_mentions.py b/bot/rules/role_mentions.py index 0649540b6..1f7a6a74d 100644 --- a/bot/rules/role_mentions.py +++ b/bot/rules/role_mentions.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from disnake import Member, Message async def apply( diff --git a/bot/utils/channel.py b/bot/utils/channel.py index 954a10e56..ee0c87311 100644 --- a/bot/utils/channel.py +++ b/bot/utils/channel.py @@ -1,6 +1,6 @@ from typing import Union -import discord +import disnake import bot from bot import constants @@ -10,7 +10,7 @@ from bot.log import get_logger log = get_logger(__name__) -def is_help_channel(channel: discord.TextChannel) -> bool: +def is_help_channel(channel: disnake.TextChannel) -> bool: """Return True if `channel` is in one of the help categories (excluding dormant).""" log.trace(f"Checking if #{channel} is a help channel.") categories = (Categories.help_available, Categories.help_in_use) @@ -18,9 +18,9 @@ def is_help_channel(channel: discord.TextChannel) -> bool: return any(is_in_category(channel, category) for category in categories) -def is_mod_channel(channel: Union[discord.TextChannel, discord.Thread]) -> bool: +def is_mod_channel(channel: Union[disnake.TextChannel, disnake.Thread]) -> bool: """True if channel, or channel.parent for threads, is considered a mod channel.""" - if isinstance(channel, discord.Thread): + if isinstance(channel, disnake.Thread): channel = channel.parent if channel.id in constants.MODERATION_CHANNELS: @@ -36,11 +36,11 @@ def is_mod_channel(channel: Union[discord.TextChannel, discord.Thread]) -> bool: return False -def is_staff_channel(channel: discord.TextChannel) -> bool: +def is_staff_channel(channel: disnake.TextChannel) -> bool: """True if `channel` is considered a staff channel.""" guild = bot.instance.get_guild(constants.Guild.id) - if channel.type is discord.ChannelType.category: + if channel.type is disnake.ChannelType.category: return False # Channel is staff-only if staff have explicit read allow perms @@ -52,12 +52,12 @@ def is_staff_channel(channel: discord.TextChannel) -> bool: ) -def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: +def is_in_category(channel: disnake.TextChannel, category_id: int) -> bool: """Return True if `channel` is within a category with `category_id`.""" return getattr(channel, "category_id", None) == category_id -async def get_or_fetch_channel(channel_id: int) -> discord.abc.GuildChannel: +async def get_or_fetch_channel(channel_id: int) -> disnake.abc.GuildChannel: """Attempt to get or fetch a channel and return it.""" log.trace(f"Getting the channel {channel_id}.") diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 188285684..9aa9bdc14 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,6 +1,6 @@ from typing import Callable, Container, Iterable, Optional, Union -from discord.ext.commands import ( +from disnake.ext.commands import ( BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping, NoPrivateMessage, has_any_role ) @@ -135,7 +135,7 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy if any(role.id in bypass for role in ctx.author.roles): return - # cooldown logic, taken from discord.py internals + # cooldown logic, taken from disnake's internals current = ctx.message.created_at.timestamp() bucket = buckets.get_bucket(ctx.message) retry_after = bucket.update_rate_limit(current) diff --git a/bot/utils/function.py b/bot/utils/function.py index 55115d7d3..bb6d8afe3 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -94,7 +94,7 @@ def update_wrapper_globals( """ Update globals of `wrapper` with the globals from `wrapped`. - For forwardrefs in command annotations discordpy uses the __global__ attribute of the function + For forwardrefs in command annotations disnake uses the __global__ attribute of the function to resolve their values, with decorators that replace the function this breaks because they have their own globals. @@ -103,7 +103,7 @@ def update_wrapper_globals( An exception will be raised in case `wrapper` and `wrapped` share a global name that is used by `wrapped`'s typehints and is not in `ignored_conflict_names`, - as this can cause incorrect objects being used by discordpy's converters. + as this can cause incorrect objects being used by disnake's converters. """ annotation_global_names = ( ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) @@ -136,7 +136,7 @@ def command_wraps( *, ignored_conflict_names: t.Set[str] = frozenset(), ) -> t.Callable[[types.FunctionType], types.FunctionType]: - """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation.""" + """Update the decorated function to look like `wrapped` and update globals for disnake forwardref evaluation.""" def decorator(wrapper: types.FunctionType) -> types.FunctionType: return functools.update_wrapper( update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py index 3501a3933..859f53fdb 100644 --- a/bot/utils/helpers.py +++ b/bot/utils/helpers.py @@ -1,7 +1,7 @@ from abc import ABCMeta from typing import Optional -from discord.ext.commands import CogMeta +from disnake.ext.commands import CogMeta class CogABCMeta(CogMeta, ABCMeta): diff --git a/bot/utils/members.py b/bot/utils/members.py index 693286045..d46baae5b 100644 --- a/bot/utils/members.py +++ b/bot/utils/members.py @@ -1,13 +1,13 @@ import typing as t -import discord +import disnake from bot.log import get_logger log = get_logger(__name__) -async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optional[discord.Member]: +async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optional[disnake.Member]: """ Attempt to get a member from cache; on failure fetch from the API. @@ -18,7 +18,7 @@ async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optiona else: try: member = await guild.fetch_member(member_id) - except discord.errors.NotFound: + except disnake.errors.NotFound: log.trace("Failed to fetch %d from API.", member_id) return None log.trace("%s fetched from API.", member) @@ -26,23 +26,23 @@ async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optiona async def handle_role_change( - member: discord.Member, + member: disnake.Member, coro: t.Callable[..., t.Coroutine], - role: discord.Role + role: disnake.Role ) -> None: """ Change `member`'s cooldown role via awaiting `coro` and handle errors. - `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. + `coro` is intended to be `disnake.Member.add_roles` or `disnake.Member.remove_roles`. """ try: await coro(role) - except discord.NotFound: + except disnake.NotFound: log.debug(f"Failed to change role for {member} ({member.id}): member not found") - except discord.Forbidden: + except disnake.Forbidden: log.debug( f"Forbidden to change role for {member} ({member.id}); " f"possibly due to role hierarchy" ) - except discord.HTTPException as e: + except disnake.HTTPException as e: log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index f68d280c9..edf2111e9 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -1,7 +1,7 @@ import typing as t from math import ceil -from discord import Message +from disnake import Message class MessageCache: diff --git a/bot/utils/messages.py b/bot/utils/messages.py index e55c07062..0bdb00a29 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -5,8 +5,8 @@ from functools import partial from io import BytesIO from typing import Callable, List, Optional, Sequence, Union -import discord -from discord.ext.commands import Context +import disnake +from disnake.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES @@ -17,8 +17,8 @@ log = get_logger(__name__) def reaction_check( - reaction: discord.Reaction, - user: discord.abc.User, + reaction: disnake.Reaction, + user: disnake.abc.User, *, message_id: int, allowed_emoji: Sequence[str], @@ -51,14 +51,14 @@ def reaction_check( log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.") scheduling.create_task( reaction.message.remove_reaction(reaction.emoji, user), - suppressed_exceptions=(discord.HTTPException,), + suppressed_exceptions=(disnake.HTTPException,), name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}" ) return False async def wait_for_deletion( - message: discord.Message, + message: disnake.Message, user_ids: Sequence[int], deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, @@ -82,7 +82,7 @@ async def wait_for_deletion( for emoji in deletion_emojis: try: await message.add_reaction(emoji) - except discord.NotFound: + except disnake.NotFound: log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.") return @@ -101,13 +101,13 @@ async def wait_for_deletion( await message.clear_reactions() else: await message.delete() - except discord.NotFound: + except disnake.NotFound: log.trace(f"wait_for_deletion: message {message.id} deleted prematurely.") async def send_attachments( - message: discord.Message, - destination: Union[discord.TextChannel, discord.Webhook], + message: disnake.Message, + destination: Union[disnake.TextChannel, disnake.Webhook], link_large: bool = True, use_cached: bool = False, **kwargs @@ -140,9 +140,9 @@ async def send_attachments( if attachment.size <= destination.guild.filesize_limit - 512: with BytesIO() as file: await attachment.save(file, use_cached=use_cached) - attachment_file = discord.File(file, filename=attachment.filename) + attachment_file = disnake.File(file, filename=attachment.filename) - if isinstance(destination, discord.TextChannel): + if isinstance(destination, disnake.TextChannel): msg = await destination.send(file=attachment_file, **kwargs) urls.append(msg.attachments[0].url) else: @@ -151,7 +151,7 @@ async def send_attachments( large.append(attachment) else: log.info(f"{failure_msg} because it's too large.") - except discord.HTTPException as e: + except disnake.HTTPException as e: if link_large and e.status == 413: large.append(attachment) else: @@ -159,10 +159,10 @@ async def send_attachments( if link_large and large: desc = "\n".join(f"[{attachment.filename}]({attachment.url})" for attachment in large) - embed = discord.Embed(description=desc) + embed = disnake.Embed(description=desc) embed.set_footer(text="Attachments exceed upload size limit.") - if isinstance(destination, discord.TextChannel): + if isinstance(destination, disnake.TextChannel): await destination.send(embed=embed, **kwargs) else: await destination.send(embed=embed, **webhook_send_kwargs) @@ -171,9 +171,9 @@ async def send_attachments( async def count_unique_users_reaction( - message: discord.Message, - reaction_predicate: Callable[[discord.Reaction], bool] = lambda _: True, - user_predicate: Callable[[discord.User], bool] = lambda _: True, + message: disnake.Message, + reaction_predicate: Callable[[disnake.Reaction], bool] = lambda _: True, + user_predicate: Callable[[disnake.User], bool] = lambda _: True, count_bots: bool = True ) -> int: """ @@ -193,7 +193,7 @@ async def count_unique_users_reaction( return len(unique_users) -async def pin_no_system_message(message: discord.Message) -> bool: +async def pin_no_system_message(message: disnake.Message) -> bool: """Pin the given message, wait a couple of seconds and try to delete the system message.""" await message.pin() @@ -201,7 +201,7 @@ async def pin_no_system_message(message: discord.Message) -> bool: await asyncio.sleep(2) # Search for the system message in the last 10 messages async for historical_message in message.channel.history(limit=10): - if historical_message.type == discord.MessageType.pins_add: + if historical_message.type == disnake.MessageType.pins_add: await historical_message.delete() return True @@ -225,16 +225,16 @@ def sub_clyde(username: Optional[str]) -> Optional[str]: return username # Empty string or None -async def send_denial(ctx: Context, reason: str) -> discord.Message: +async def send_denial(ctx: Context, reason: str) -> disnake.Message: """Send an embed denying the user with the given reason.""" - embed = discord.Embed() - embed.colour = discord.Colour.red() + embed = disnake.Embed() + embed.colour = disnake.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = reason return await ctx.send(embed=embed) -def format_user(user: discord.abc.User) -> str: +def format_user(user: disnake.abc.User) -> str: """Return a string for `user` which has their mention and ID.""" return f"{user.mention} (`{user.id}`)" diff --git a/bot/utils/webhooks.py b/bot/utils/webhooks.py index 9c916b63a..8ef929b79 100644 --- a/bot/utils/webhooks.py +++ b/bot/utils/webhooks.py @@ -1,7 +1,7 @@ from typing import Optional -import discord -from discord import Embed +import disnake +from disnake import Embed from bot.log import get_logger from bot.utils.messages import sub_clyde @@ -10,13 +10,13 @@ log = get_logger(__name__) async def send_webhook( - webhook: discord.Webhook, + webhook: disnake.Webhook, content: Optional[str] = None, username: Optional[str] = None, avatar_url: Optional[str] = None, embed: Optional[Embed] = None, wait: Optional[bool] = False -) -> discord.Message: +) -> disnake.Message: """ Send a message using the provided webhook. @@ -30,5 +30,5 @@ async def send_webhook( embed=embed, wait=wait, ) - except discord.HTTPException: + except disnake.HTTPException: log.exception("Failed to send a message to the webhook!") diff --git a/poetry.lock b/poetry.lock index a8ee6ef5c..087fd739c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -143,7 +143,7 @@ lxml = ["lxml"] [[package]] name = "bot-core" -version = "1.2.0" +version = "2.1.0" description = "Bot-Core provides the core functionality and utilities for the bots of the Python Discord community." category = "main" optional = false @@ -154,7 +154,7 @@ python-versions = "3.9.*" [package.source] type = "url" -url = "https://github.com/python-discord/bot-core/archive/511bcba1b0196cd498c707a525ea56921bd971db.zip" +url = "https://github.com/python-discord/bot-core/archive/refs/tags/v2.1.0.zip" [[package]] name = "certifi" version = "2021.10.8" @@ -1074,7 +1074,7 @@ toml = ">=0.10.0,<0.11.0" [[package]] name = "testfixtures" -version = "6.18.3" +version = "6.18.5" description = "A collection of helpers and mock objects for unit tests and doc tests." category = "dev" optional = false @@ -1169,7 +1169,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "3.9.*" -content-hash = "538a4809b9fc6fa93ee1baccf4016515ae311a886f1b7ec9b3d544bb87c830a3" +content-hash = "b8b28311c13f7a66f028041bae889131d3916ca7f667c9a7539871d21bbcd077" [metadata.files] aio-pika = [ @@ -1990,8 +1990,8 @@ taskipy = [ {file = "taskipy-1.7.0.tar.gz", hash = "sha256:960e480b1004971e76454ecd1a0484e640744a30073a1069894a311467f85ed8"}, ] testfixtures = [ - {file = "testfixtures-6.18.3-py2.py3-none-any.whl", hash = "sha256:6ddb7f56a123e1a9339f130a200359092bd0a6455e31838d6c477e8729bb7763"}, - {file = "testfixtures-6.18.3.tar.gz", hash = "sha256:2600100ae96ffd082334b378e355550fef8b4a529a6fa4c34f47130905c7426d"}, + {file = "testfixtures-6.18.5-py2.py3-none-any.whl", hash = "sha256:7de200e24f50a4a5d6da7019fb1197aaf5abd475efb2ec2422fdcf2f2eb98c1d"}, + {file = "testfixtures-6.18.5.tar.gz", hash = "sha256:02dae883f567f5b70fd3ad3c9eefb95912e78ac90be6c7444b5e2f46bf572c84"}, ] tldextract = [ {file = "tldextract-3.1.2-py2.py3-none-any.whl", hash = "sha256:f55e05f6bf4cc952a87d13594386d32ad2dd265630a8bdfc3df03bd60425c6b0"}, diff --git a/pyproject.toml b/pyproject.toml index 90b38ce66..1f02818ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" python = "3.9.*" disnake = "~=2.4" # See https://bot-core.pythondiscord.com/ for docs. -bot-core = {url = "https://github.com/python-discord/bot-core/archive/511bcba1b0196cd498c707a525ea56921bd971db.zip"} +bot-core = {url = "https://github.com/python-discord/bot-core/archive/refs/tags/v2.1.0.zip"} aio-pika = "~=6.1" aiodns = "~=2.0" aiohttp = "~=3.7" diff --git a/tests/README.md b/tests/README.md index b7fddfaa2..fc03b3d43 100644 --- a/tests/README.md +++ b/tests/README.md @@ -121,9 +121,9 @@ As we are trying to test our "units" of code independently, we want to make sure However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". -To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). +To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `disnake` types (see the section on the below.). -An example of mocking is when we provide a command with a mocked version of `discord.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): +An example of mocking is when we provide a command with a mocked version of `disnake.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): ```py import asyncio @@ -152,15 +152,15 @@ class BotCogTests(unittest.TestCase): By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected. -### Special mocks for some `discord.py` types +### Special mocks for some `disnake` types To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. -In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual disnake types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `disnake` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. -**Note:** These mock types only "know" the attributes that are set by default when these `discord.py` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: +**Note:** These mock types only "know" the attributes that are set by default when these `disnake` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: ```py import unittest.mock @@ -245,7 +245,7 @@ All in all, it's not only important to consider if all statements or branches we ### Unit Testing vs Integration Testing -Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. +Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `disnake` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written. diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..dea7dd678 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,8 +3,8 @@ import unittest from contextlib import contextmanager from typing import Dict -import discord -from discord.ext import commands +import disnake +from disnake.ext import commands from bot.log import get_logger from tests import helpers @@ -80,7 +80,7 @@ class LoggingTestsMixin: class CommandTestCase(unittest.IsolatedAsyncioTestCase): - """TestCase with additional assertions that are useful for testing Discord commands.""" + """TestCase with additional assertions that are useful for testing disnake commands.""" async def assertHasPermissionsCheck( # noqa: N802 self, @@ -98,7 +98,7 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): permissions = {k: not v for k, v in permissions.items()} ctx = helpers.MockContext() - ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + ctx.channel.permissions_for.return_value = disnake.Permissions(**permissions) with self.assertRaises(commands.MissingPermissions) as cm: await cmd.can_run(ctx) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index fdd0ab74a..4ed7de64d 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import discord +import disnake from bot import constants from bot.api import ResponseCodeError @@ -257,9 +257,9 @@ class SyncCogListenerTests(SyncCogTestCase): self.assertTrue(self.cog.on_member_update.__cog_listener__) subtests = ( - ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("activities", disnake.Game("Pong"), disnake.Game("Frogger")), ("nick", "old nick", "new nick"), - ("status", discord.Status.online, discord.Status.offline), + ("status", disnake.Status.online, disnake.Status.offline), ) for attribute, old_value, new_value in subtests: diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 541074336..9ecb8fae0 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import discord +import disnake from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers @@ -34,8 +34,8 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): for role in roles: mock_role = helpers.MockRole(**role) - mock_role.colour = discord.Colour(role["colour"]) - mock_role.permissions = discord.Permissions(role["permissions"]) + mock_role.colour = disnake.Colour(role["colour"]) + mock_role.permissions = disnake.Permissions(role["permissions"]) guild.roles.append(mock_role) return guild diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 2fc97af2d..f55f5360f 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from discord.errors import NotFound +from disnake.errors import NotFound from bot.exts.backend.sync._syncers import UserSyncer, _Diff from tests import helpers diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 35fa0ee59..83b5f2749 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch -from discord.ext.commands import errors +from disnake.ext.commands import errors from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index 0856546af..fdff36b61 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import AsyncMock, MagicMock, create_autospec, patch -from discord import CategoryChannel -from discord.ext.commands import BadArgument +from disnake import CategoryChannel +from disnake.ext.commands import BadArgument from bot.constants import Roles from bot.exts.events import code_jams diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 06d78de9d..0cab405d0 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, Mock -from discord import NotFound +from disnake import NotFound from bot.constants import Channels, STAFF_ROLES from bot.exts.filters import antimalware diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index c0c3baa42..46fa82fd7 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from discord.ext.commands import NoPrivateMessage +from disnake.ext.commands import NoPrivateMessage from bot.exts.filters import security from tests.helpers import MockBot, MockContext diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index 4db27269a..dd56c10dd 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -3,7 +3,7 @@ from re import Match from unittest import mock from unittest.mock import MagicMock -from discord import Colour, NotFound +from disnake import Colour, NotFound from bot import constants from bot.exts.filters import token_remover diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index d896b7652..9a35de7a9 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -3,7 +3,7 @@ import unittest import unittest.mock from datetime import datetime -import discord +import disnake from bot import constants from bot.exts.info import information @@ -43,7 +43,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): embed = kwargs.pop('embed') self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, discord.Colour.og_blurple()) + self.assertEqual(embed.colour, disnake.Colour.og_blurple()) self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") async def test_role_info_command(self): @@ -51,19 +51,19 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): dummy_role = helpers.MockRole( name="Dummy", id=112233445566778899, - colour=discord.Colour.og_blurple(), + colour=disnake.Colour.og_blurple(), position=10, members=[self.ctx.author], - permissions=discord.Permissions(0) + permissions=disnake.Permissions(0) ) admin_role = helpers.MockRole( name="Admins", id=998877665544332211, - colour=discord.Colour.red(), + colour=disnake.Colour.red(), position=3, members=[self.ctx.author], - permissions=discord.Permissions(0), + permissions=disnake.Permissions(0), ) self.ctx.guild.roles.extend([dummy_role, admin_role]) @@ -81,7 +81,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): admin_embed = admin_kwargs["embed"] self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, discord.Colour.og_blurple()) + self.assertEqual(dummy_embed.colour, disnake.Colour.og_blurple()) self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") @@ -91,7 +91,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(dummy_embed.fields[5].value, "0") self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, discord.Colour.red()) + self.assertEqual(admin_embed.colour, disnake.Colour.red()) class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase): @@ -449,7 +449,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, discord.Colour(100)) + self.assertEqual(embed.colour, disnake.Colour(100)) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", @@ -463,11 +463,11 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """The embed should be created with the og blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() - user = helpers.MockMember(id=217, colour=discord.Colour.default()) + user = helpers.MockMember(id=217, colour=disnake.Colour.default()) user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, discord.Colour.og_blurple()) + self.assertEqual(embed.colour, disnake.Colour.og_blurple()) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 052048053..b85d086c9 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -3,7 +3,7 @@ import textwrap import unittest from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch -from discord.errors import NotFound +from disnake.errors import NotFound from bot.constants import Event from bot.exts.moderation.clean import Clean diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 350274ecd..6601b9d25 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,7 +3,7 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch -from discord import Embed, Forbidden, HTTPException, NotFound +from disnake import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError from bot.constants import Colours, Icons diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..725455bbe 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -7,7 +7,7 @@ from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp -import discord +import disnake from async_rediscache import RedisSession from bot.constants import Colours @@ -24,7 +24,7 @@ class MockAsyncIterable: Helper for mocking asynchronous for loops. It does not appear that the `unittest` library currently provides anything that would - allow us to simply mock an async iterator, such as `discord.TextChannel.history`. + allow us to simply mock an async iterator, such as `disnake.TextChannel.history`. We therefore write our own helper to wrap a regular synchronous iterable, and feed its values via `__anext__` rather than `__next__`. @@ -60,7 +60,7 @@ class MockSignal(enum.Enum): B = "B" -mock_404 = discord.NotFound( +mock_404 = disnake.NotFound( response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response message="Not found", ) @@ -70,8 +70,8 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): """Collection of tests for the `download_file` helper function.""" async def test_download_file_success(self): - """If `to_file` succeeds, function returns the acquired `discord.File`.""" - file = MagicMock(discord.File, filename="bigbadlemon.jpg") + """If `to_file` succeeds, function returns the acquired `disnake.File`.""" + file = MagicMock(disnake.File, filename="bigbadlemon.jpg") attachment = MockAttachment(to_file=AsyncMock(return_value=file)) acquired_file = await incidents.download_file(attachment) @@ -86,7 +86,7 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): async def test_download_file_fail(self): """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" - arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + arbitrary_error = disnake.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) with self.assertLogs(logger=incidents.log, level=logging.ERROR): @@ -121,7 +121,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" - file = MagicMock(discord.File, filename="bigbadjoe.jpg") + file = MagicMock(disnake.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") incident = MockMessage(content="this is an incident", attachments=[attachment]) @@ -394,7 +394,7 @@ class TestArchive(TestIncidents): author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) - built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this + built_embed = MagicMock(disnake.Embed, id=123) # We patch `make_embed` to return this with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) @@ -616,7 +616,7 @@ class TestResolveMessage(TestIncidents): """ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - arbitrary_error = discord.HTTPException( + arbitrary_error = disnake.HTTPException( response=MagicMock(aiohttp.ClientResponse), message="Arbitrary error", ) @@ -649,7 +649,7 @@ class TestOnRawReactionAdd(TestIncidents): super().setUp() # Ensure `cog_instance` is assigned self.payload = MagicMock( - discord.RawReactionActionEvent, + disnake.RawReactionActionEvent, channel_id=123, # Patched at class level message_id=456, member=MockMember(bot=False), diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index 79e04837d..6c9ebed95 100644 --- a/tests/bot/exts/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -1,6 +1,6 @@ import unittest -import discord +import disnake from bot.exts.moderation.modlog import ModLog from tests.helpers import MockBot, MockTextChannel @@ -19,7 +19,7 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): self.bot.get_channel.return_value = self.channel await self.cog.send_log_message( icon_url="foo", - colour=discord.Colour.blue(), + colour=disnake.Colour.blue(), title="bar", text="foo bar" * 3000 ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 92ce3418a..539651d6c 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -7,7 +7,7 @@ from unittest import mock from unittest.mock import AsyncMock, Mock from async_rediscache import RedisSession -from discord import PermissionOverwrite +from disnake import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence @@ -152,7 +152,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. self.assertTrue(self.cog._init_task.cancelled()) - @autospec("discord.ext.commands", "has_any_role") + @autospec("disnake.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py index f8e120262..5cb071d58 100644 --- a/tests/bot/exts/test_cogs.py +++ b/tests/bot/exts/test_cogs.py @@ -8,7 +8,7 @@ from collections import defaultdict from types import ModuleType from unittest import mock -from discord.ext import commands +from disnake.ext import commands from bot import exts @@ -34,7 +34,7 @@ class CommandNameTests(unittest.TestCase): raise ImportError(name=name) # pragma: no cover # The mock prevents asyncio.get_event_loop() from being called. - with mock.patch("discord.ext.tasks.loop"): + with mock.patch("disnake.ext.tasks.loop"): prefix = f"{exts.__name__}." for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): if not module.ispkg: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 8bdeedd27..bec7574fb 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -2,8 +2,8 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch -from discord import AllowedMentions -from discord.ext import commands +from disnake import AllowedMentions +from disnake.ext import commands from bot import constants from bot.exts.utils import snekbox diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 1bb678db2..afb8a973d 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -4,7 +4,7 @@ from datetime import MAXYEAR, datetime, timezone from unittest.mock import MagicMock, patch from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument +from disnake.ext.commands import BadArgument from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 4ae11d5d3..5675e10ec 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from discord import DMChannel +from disnake import DMChannel from bot.utils import checks from bot.utils.checks import InWhitelistCheckFailure diff --git a/tests/helpers.py b/tests/helpers.py index 9d4988d23..bd1418ab9 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,9 +7,9 @@ import unittest.mock from asyncio import AbstractEventLoop from typing import Iterable, Optional -import discord +import disnake from aiohttp import ClientSession -from discord.ext.commands import Context +from disnake.ext.commands import Context from bot.api import APIClient from bot.async_stats import AsyncStatsClient @@ -26,11 +26,11 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -class HashableMixin(discord.mixins.EqualityComparable): +class HashableMixin(disnake.mixins.EqualityComparable): """ - Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. + Mixin that provides similar hashing and equality functionality as disnake's `Hashable` mixin. - Note: discord.py`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions + Note: disnake`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions for the relative small `id` integers we generally use in tests, this bit-shift is omitted. """ @@ -39,22 +39,22 @@ class HashableMixin(discord.mixins.EqualityComparable): class ColourMixin: - """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like discord.py does.""" + """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like disnake does.""" @property - def color(self) -> discord.Colour: + def color(self) -> disnake.Colour: return self.colour @color.setter - def color(self, color: discord.Colour) -> None: + def color(self, color: disnake.Colour) -> None: self.colour = color @property - def accent_color(self) -> discord.Colour: + def accent_color(self) -> disnake.Colour: return self.accent_colour @accent_color.setter - def accent_color(self, color: discord.Colour) -> None: + def accent_color(self, color: disnake.Colour) -> None: self.accent_colour = color @@ -63,7 +63,7 @@ class CustomMockMixin: Provides common functionality for our custom Mock types. The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock - object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the + object. As disnake also uses synchronous methods that nonetheless return coroutine objects, the class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The class method `spec_set` can be overwritten with the object that should be uses as the specification @@ -119,7 +119,7 @@ class CustomMockMixin: return klass(**kw) -# Create a guild instance to get a realistic Mock of `discord.Guild` +# Create a guild instance to get a realistic Mock of `disnake.Guild` guild_data = { 'id': 1, 'name': 'guild', @@ -139,20 +139,20 @@ guild_data = { 'owner_id': 1, 'afk_channel_id': 464033278631084042, } -guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock()) +guild_instance = disnake.Guild(data=guild_data, state=unittest.mock.MagicMock()) class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ - A `Mock` subclass to mock `discord.Guild` objects. + A `Mock` subclass to mock `disnake.Guild` objects. - A MockGuild instance will follow the specifications of a `discord.Guild` instance. This means + A MockGuild instance will follow the specifications of a `disnake.Guild` instance. This means that if the code you're testing tries to access an attribute or method that normally does not - exist for a `discord.Guild` object this will raise an `AttributeError`. This is to make sure our - tests fail if the code we're testing uses a `discord.Guild` object in the wrong way. + exist for a `disnake.Guild` object this will raise an `AttributeError`. This is to make sure our + tests fail if the code we're testing uses a `disnake.Guild` object in the wrong way. One restriction of that is that if the code tries to access an attribute that normally does not - exist for `discord.Guild` instance but was added dynamically, this will raise an exception with + exist for `disnake.Guild` instance but was added dynamically, this will raise an exception with the mocked object. To get around that, you can set the non-standard attribute explicitly for the instance of `MockGuild`: @@ -160,10 +160,10 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): >>> guild.attribute_that_normally_does_not_exist = unittest.mock.MagicMock() In addition to attribute simulation, mocked guild object will pass an `isinstance` check against - `discord.Guild`: + `disnake.Guild`: >>> guild = MockGuild() - >>> isinstance(guild, discord.Guild) + >>> isinstance(guild, disnake.Guild) True For more info, see the `Mocking` section in `tests/README.md`. @@ -179,16 +179,16 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): self.roles.extend(roles) -# Create a Role instance to get a realistic Mock of `discord.Role` +# Create a Role instance to get a realistic Mock of `disnake.Role` role_data = {'name': 'role', 'id': 1} -role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) +role_instance = disnake.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ - A Mock subclass to mock `discord.Role` objects. + A Mock subclass to mock `disnake.Role` objects. - Instances of this class will follow the specifications of `discord.Role` instances. For more + Instances of this class will follow the specifications of `disnake.Role` instances. For more information, see the `MockGuild` docstring. """ spec_set = role_instance @@ -198,40 +198,40 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): 'id': next(self.discord_id), 'name': 'role', 'position': 1, - 'colour': discord.Colour(0xdeadbf), - 'permissions': discord.Permissions(), + 'colour': disnake.Colour(0xdeadbf), + 'permissions': disnake.Permissions(), } super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if isinstance(self.colour, int): - self.colour = discord.Colour(self.colour) + self.colour = disnake.Colour(self.colour) if isinstance(self.permissions, int): - self.permissions = discord.Permissions(self.permissions) + self.permissions = disnake.Permissions(self.permissions) if 'mention' not in kwargs: self.mention = f'&{self.name}' def __lt__(self, other): - """Simplified position-based comparisons similar to those of `discord.Role`.""" + """Simplified position-based comparisons similar to those of `disnake.Role`.""" return self.position < other.position def __ge__(self, other): - """Simplified position-based comparisons similar to those of `discord.Role`.""" + """Simplified position-based comparisons similar to those of `disnake.Role`.""" return self.position >= other.position -# Create a Member instance to get a realistic Mock of `discord.Member` +# Create a Member instance to get a realistic Mock of `disnake.Member` member_data = {'user': 'lemon', 'roles': [1]} state_mock = unittest.mock.MagicMock() -member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) +member_instance = disnake.Member(data=member_data, guild=guild_instance, state=state_mock) class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock Member objects. - Instances of this class will follow the specifications of `discord.Member` instances. For more + Instances of this class will follow the specifications of `disnake.Member` instances. For more information, see the `MockGuild` docstring. """ spec_set = member_instance @@ -249,11 +249,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" -# Create a User instance to get a realistic Mock of `discord.User` +# Create a User instance to get a realistic Mock of `disnake.User` _user_data_mock = collections.defaultdict(unittest.mock.MagicMock, { "accent_color": 0 }) -user_instance = discord.User( +user_instance = disnake.User( data=unittest.mock.MagicMock(get=unittest.mock.Mock(side_effect=_user_data_mock.get)), state=unittest.mock.MagicMock() ) @@ -263,7 +263,7 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock User objects. - Instances of this class will follow the specifications of `discord.User` instances. For more + Instances of this class will follow the specifications of `disnake.User` instances. For more information, see the `MockGuild` docstring. """ spec_set = user_instance @@ -305,7 +305,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Bot objects. - Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. + Instances of this class will follow the specifications of `disnake.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ spec_set = Bot( @@ -324,7 +324,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) -# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` +# Create a TextChannel instance to get a realistic MagicMock of `disnake.TextChannel` channel_data = { 'id': 1, 'type': 'TextChannel', @@ -337,17 +337,17 @@ channel_data = { } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = disnake.TextChannel(state=state, guild=guild, data=channel_data) channel_data["type"] = "VoiceChannel" -voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) +voice_channel_instance = disnake.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `discord.TextChannel` instances. For + Instances of this class will follow the specifications of `disnake.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = text_channel_instance @@ -364,7 +364,7 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock VoiceChannel objects. - Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For + Instances of this class will follow the specifications of `disnake.VoiceChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = voice_channel_instance @@ -381,14 +381,14 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): state = unittest.mock.MagicMock() me = unittest.mock.MagicMock() dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} -dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) +dm_channel_instance = disnake.DMChannel(me=me, state=state, data=dm_channel_data) class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `discord.TextChannel` instances. For + Instances of this class will follow the specifications of `disnake.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = dm_channel_instance @@ -398,17 +398,17 @@ class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(kwargs, default_kwargs)) -# Create CategoryChannel instance to get a realistic MagicMock of `discord.CategoryChannel` +# Create CategoryChannel instance to get a realistic MagicMock of `disnake.CategoryChannel` category_channel_data = { 'id': 1, - 'type': discord.ChannelType.category, + 'type': disnake.ChannelType.category, 'name': 'category', 'position': 1, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -category_channel_instance = discord.CategoryChannel( +category_channel_instance = disnake.CategoryChannel( state=state, guild=guild, data=category_channel_data ) @@ -419,7 +419,7 @@ class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(default_kwargs, kwargs)) -# Create a Message instance to get a realistic MagicMock of `discord.Message` +# Create a Message instance to get a realistic MagicMock of `disnake.Message` message_data = { 'id': 1, 'webhook_id': 431341013479718912, @@ -438,10 +438,10 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() -message_instance = discord.Message(state=state, channel=channel, data=message_data) +message_instance = disnake.Message(state=state, channel=channel, data=message_data) -# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` +# Create a Context instance to get a realistic MagicMock of `disnake.ext.commands.Context` context_instance = Context( message=unittest.mock.MagicMock(), prefix="$", @@ -455,7 +455,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Context objects. - Instances of this class will follow the specifications of `discord.ext.commands.Context` + Instances of this class will follow the specifications of `disnake.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ spec_set = context_instance @@ -471,14 +471,14 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) -attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) +attachment_instance = disnake.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Attachment objects. - Instances of this class will follow the specifications of `discord.Attachment` instances. For + Instances of this class will follow the specifications of `disnake.Attachment` instances. For more information, see the `MockGuild` docstring. """ spec_set = attachment_instance @@ -488,7 +488,7 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. - Instances of this class will follow the specifications of `discord.Message` instances. For more + Instances of this class will follow the specifications of `disnake.Message` instances. For more information, see the `MockGuild` docstring. """ spec_set = message_instance @@ -501,14 +501,14 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} -emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) +emoji_instance = disnake.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Emoji objects. - Instances of this class will follow the specifications of `discord.Emoji` instances. For more + Instances of this class will follow the specifications of `disnake.Emoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = emoji_instance @@ -518,27 +518,27 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): self.guild = kwargs.get('guild', MockGuild()) -partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') +partial_emoji_instance = disnake.PartialEmoji(animated=False, name='guido') class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock PartialEmoji objects. - Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For + Instances of this class will follow the specifications of `disnake.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = partial_emoji_instance -reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) +reaction_instance = disnake.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) class MockReaction(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Reaction objects. - Instances of this class will follow the specifications of `discord.Reaction` instances. For + Instances of this class will follow the specifications of `disnake.Reaction` instances. For more information, see the `MockGuild` docstring. """ spec_set = reaction_instance @@ -556,14 +556,14 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.__str__.return_value = str(self.emoji) -webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) +webhook_instance = disnake.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. - Instances of this class will follow the specifications of `discord.Webhook` instances. For + Instances of this class will follow the specifications of `disnake.Webhook` instances. For more information, see the `MockGuild` docstring. """ spec_set = webhook_instance diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 81285e009..c5e799a85 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,20 +2,20 @@ import asyncio import unittest import unittest.mock -import discord +import disnake from tests import helpers class DiscordMocksTests(unittest.TestCase): - """Tests for our specialized discord.py mocks.""" + """Tests for our specialized disnake mocks.""" def test_mock_role_default_initialization(self): """Test if the default initialization of MockRole results in the correct object.""" role = helpers.MockRole() - # The `spec` argument makes sure `isistance` checks with `discord.Role` pass - self.assertIsInstance(role, discord.Role) + # The `spec` argument makes sure `isistance` checks with `disnake.Role` pass + self.assertIsInstance(role, disnake.Role) self.assertEqual(role.name, "role") self.assertEqual(role.position, 1) @@ -61,8 +61,8 @@ class DiscordMocksTests(unittest.TestCase): """Test if the default initialization of Mockmember results in the correct object.""" member = helpers.MockMember() - # The `spec` argument makes sure `isistance` checks with `discord.Member` pass - self.assertIsInstance(member, discord.Member) + # The `spec` argument makes sure `isistance` checks with `disnake.Member` pass + self.assertIsInstance(member, disnake.Member) self.assertEqual(member.name, "member") self.assertListEqual(member.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) @@ -86,18 +86,18 @@ class DiscordMocksTests(unittest.TestCase): """Test if MockMember accepts and sets abitrary keyword arguments.""" member = helpers.MockMember( nick="Dino Man", - colour=discord.Colour.default(), + colour=disnake.Colour.default(), ) self.assertEqual(member.nick, "Dino Man") - self.assertEqual(member.colour, discord.Colour.default()) + self.assertEqual(member.colour, disnake.Colour.default()) def test_mock_guild_default_initialization(self): """Test if the default initialization of Mockguild results in the correct object.""" guild = helpers.MockGuild() - # The `spec` argument makes sure `isistance` checks with `discord.Guild` pass - self.assertIsInstance(guild, discord.Guild) + # The `spec` argument makes sure `isistance` checks with `disnake.Guild` pass + self.assertIsInstance(guild, disnake.Guild) self.assertListEqual(guild.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) self.assertListEqual(guild.members, []) @@ -127,15 +127,15 @@ class DiscordMocksTests(unittest.TestCase): """Tests if MockBot initializes with the correct values.""" bot = helpers.MockBot() - # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Bot` pass - self.assertIsInstance(bot, discord.ext.commands.Bot) + # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Bot` pass + self.assertIsInstance(bot, disnake.ext.commands.Bot) def test_mock_context_default_initialization(self): """Tests if MockContext initializes with the correct values.""" context = helpers.MockContext() - # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Context` pass - self.assertIsInstance(context, discord.ext.commands.Context) + # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Context` pass + self.assertIsInstance(context, disnake.ext.commands.Context) self.assertIsInstance(context.bot, helpers.MockBot) self.assertIsInstance(context.guild, helpers.MockGuild) -- cgit v1.2.3 From 753e101df0ba4f51c85a426a7a9d9678f96e0fd7 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 14 Mar 2022 19:01:56 +0000 Subject: Revert "Update all references of discord.py to disnake" This reverts commit 960619c23300c56c8aaa454edc7241e2badf80ad. --- bot/__init__.py | 2 +- bot/bot.py | 20 ++-- bot/converters.py | 24 ++-- bot/decorators.py | 10 +- bot/errors.py | 2 +- bot/exts/backend/branding/_cog.py | 18 +-- bot/exts/backend/config_verifier.py | 2 +- bot/exts/backend/error_handler.py | 4 +- bot/exts/backend/logging.py | 4 +- bot/exts/backend/sync/_cog.py | 6 +- bot/exts/backend/sync/_syncers.py | 4 +- bot/exts/events/code_jams/_channels.py | 42 +++---- bot/exts/events/code_jams/_cog.py | 24 ++-- bot/exts/filters/antimalware.py | 4 +- bot/exts/filters/antispam.py | 8 +- bot/exts/filters/filter_lists.py | 4 +- bot/exts/filters/filtering.py | 26 ++--- bot/exts/filters/security.py | 2 +- bot/exts/filters/token_remover.py | 6 +- bot/exts/filters/webhook_remover.py | 4 +- bot/exts/fun/duck_pond.py | 22 ++-- bot/exts/fun/off_topic_names.py | 6 +- bot/exts/help_channels/_caches.py | 14 +-- bot/exts/help_channels/_channel.py | 18 +-- bot/exts/help_channels/_cog.py | 64 +++++------ bot/exts/help_channels/_message.py | 36 +++--- bot/exts/help_channels/_name.py | 6 +- bot/exts/info/code_snippets.py | 8 +- bot/exts/info/codeblock/_cog.py | 22 ++-- bot/exts/info/doc/_batch_parser.py | 4 +- bot/exts/info/doc/_cog.py | 18 +-- bot/exts/info/help.py | 4 +- bot/exts/info/information.py | 8 +- bot/exts/info/pep.py | 4 +- bot/exts/info/pypi.py | 6 +- bot/exts/info/python_news.py | 12 +- bot/exts/info/source.py | 4 +- bot/exts/info/stats.py | 6 +- bot/exts/info/subscribe.py | 26 ++--- bot/exts/info/tags.py | 16 +-- bot/exts/moderation/clean.py | 10 +- bot/exts/moderation/defcon.py | 6 +- bot/exts/moderation/dm_relay.py | 6 +- bot/exts/moderation/incidents.py | 100 ++++++++--------- bot/exts/moderation/infraction/_scheduler.py | 14 +-- bot/exts/moderation/infraction/_utils.py | 14 +-- bot/exts/moderation/infraction/infractions.py | 26 ++--- bot/exts/moderation/infraction/management.py | 36 +++--- bot/exts/moderation/infraction/superstarify.py | 6 +- bot/exts/moderation/metabase.py | 2 +- bot/exts/moderation/modlog.py | 72 ++++++------ bot/exts/moderation/modpings.py | 8 +- bot/exts/moderation/silence.py | 8 +- bot/exts/moderation/slowmode.py | 4 +- bot/exts/moderation/stream.py | 24 ++-- bot/exts/moderation/verification.py | 16 +-- bot/exts/moderation/voice_gate.py | 42 +++---- bot/exts/moderation/watchchannels/_watchchannel.py | 12 +- bot/exts/moderation/watchchannels/bigbrother.py | 4 +- bot/exts/recruitment/talentpool/_cog.py | 8 +- bot/exts/recruitment/talentpool/_review.py | 4 +- bot/exts/utils/bot.py | 4 +- bot/exts/utils/extensions.py | 6 +- bot/exts/utils/internal.py | 17 ++- bot/exts/utils/ping.py | 4 +- bot/exts/utils/reminders.py | 32 +++--- bot/exts/utils/snekbox.py | 4 +- bot/exts/utils/thread_bumper.py | 24 ++-- bot/exts/utils/utils.py | 6 +- bot/log.py | 2 +- bot/monkey_patches.py | 6 +- bot/pagination.py | 18 +-- bot/rules/attachments.py | 2 +- bot/rules/burst.py | 2 +- bot/rules/burst_shared.py | 2 +- bot/rules/chars.py | 2 +- bot/rules/discord_emojis.py | 2 +- bot/rules/duplicates.py | 2 +- bot/rules/links.py | 2 +- bot/rules/mentions.py | 2 +- bot/rules/newlines.py | 2 +- bot/rules/role_mentions.py | 2 +- bot/utils/channel.py | 16 +-- bot/utils/checks.py | 4 +- bot/utils/function.py | 6 +- bot/utils/helpers.py | 2 +- bot/utils/members.py | 18 +-- bot/utils/message_cache.py | 2 +- bot/utils/messages.py | 48 ++++---- bot/utils/webhooks.py | 10 +- poetry.lock | 12 +- pyproject.toml | 2 +- tests/README.md | 12 +- tests/base.py | 8 +- tests/bot/exts/backend/sync/test_cog.py | 6 +- tests/bot/exts/backend/sync/test_roles.py | 6 +- tests/bot/exts/backend/sync/test_users.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 2 +- tests/bot/exts/events/test_code_jams.py | 4 +- tests/bot/exts/filters/test_antimalware.py | 2 +- tests/bot/exts/filters/test_security.py | 2 +- tests/bot/exts/filters/test_token_remover.py | 2 +- tests/bot/exts/info/test_information.py | 22 ++-- .../exts/moderation/infraction/test_infractions.py | 2 +- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- tests/bot/exts/moderation/test_incidents.py | 20 ++-- tests/bot/exts/moderation/test_modlog.py | 4 +- tests/bot/exts/moderation/test_silence.py | 4 +- tests/bot/exts/test_cogs.py | 4 +- tests/bot/exts/utils/test_snekbox.py | 4 +- tests/bot/test_converters.py | 2 +- tests/bot/utils/test_checks.py | 2 +- tests/helpers.py | 124 ++++++++++----------- tests/test_helpers.py | 28 ++--- 114 files changed, 734 insertions(+), 735 deletions(-) (limited to 'tests') diff --git a/bot/__init__.py b/bot/__init__.py index b28513bff..17d99105a 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -3,7 +3,7 @@ import os from functools import partial, partialmethod from typing import TYPE_CHECKING -from disnake.ext import commands +from discord.ext import commands from bot import log, monkey_patches diff --git a/bot/bot.py b/bot/bot.py index 2769b7dda..94783a466 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -6,9 +6,9 @@ from contextlib import suppress from typing import Dict, List, Optional import aiohttp -import disnake +import discord from async_rediscache import RedisSession -from disnake.ext import commands +from discord.ext import commands from sentry_sdk import push_scope from bot import api, constants @@ -28,7 +28,7 @@ class StartupError(Exception): class Bot(commands.Bot): - """A subclass of `disnake.ext.commands.Bot` with an aiohttp session and an API client.""" + """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, redis_session: RedisSession, **kwargs): if "connector" in kwargs: @@ -109,9 +109,9 @@ class Bot(commands.Bot): def create(cls) -> "Bot": """Create and return an instance of a Bot.""" loop = asyncio.get_event_loop() - allowed_roles = list({disnake.Object(id_) for id_ in constants.MODERATION_ROLES}) + allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}) - intents = disnake.Intents.all() + intents = discord.Intents.all() intents.presences = False intents.dm_typing = False intents.dm_reactions = False @@ -123,10 +123,10 @@ class Bot(commands.Bot): redis_session=_create_redis_session(loop), loop=loop, command_prefix=commands.when_mentioned_or(constants.Bot.prefix), - activity=disnake.Game(name=f"Commands: {constants.Bot.prefix}help"), + activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), case_insensitive=True, max_messages=10_000, - allowed_mentions=disnake.AllowedMentions(everyone=False, roles=allowed_roles), + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), intents=intents, ) @@ -258,7 +258,7 @@ class Bot(commands.Bot): await self.stats.create_socket() await super().login(*args, **kwargs) - async def on_guild_available(self, guild: disnake.Guild) -> None: + async def on_guild_available(self, guild: discord.Guild) -> None: """ Set the internal guild available event when constants.Guild.id becomes available. @@ -274,7 +274,7 @@ class Bot(commands.Bot): try: webhook = await self.fetch_webhook(constants.Webhooks.dev_log) - except disnake.HTTPException as e: + except discord.HTTPException as e: log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") else: await webhook.send(f"<@&{constants.Roles.admin}> {msg}") @@ -283,7 +283,7 @@ class Bot(commands.Bot): self._guild_available.set() - async def on_guild_unavailable(self, guild: disnake.Guild) -> None: + async def on_guild_unavailable(self, guild: discord.Guild) -> None: """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" if guild.id != constants.Guild.id: return diff --git a/bot/converters.py b/bot/converters.py index 9d93428ca..3522a32aa 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -6,12 +6,12 @@ from datetime import datetime, timezone from ssl import CertificateError import dateutil.parser -import disnake +import discord from aiohttp import ClientConnectorError from botcore.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from disnake.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter -from disnake.utils import escape_markdown, snowflake_time +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.utils import escape_markdown, snowflake_time from bot import exts from bot.api import ResponseCodeError @@ -505,14 +505,14 @@ AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Us class UnambiguousUser(UserConverter): """ - Converts to a `disnake.User`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `UserConverter`, it doesn't allow conversion from a name. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> disnake.User: - """Convert the `argument` to a `disnake.User`.""" + async def convert(self, ctx: Context, argument: str) -> discord.User: + """Convert the `argument` to a `discord.User`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -521,14 +521,14 @@ class UnambiguousUser(UserConverter): class UnambiguousMember(MemberConverter): """ - Converts to a `disnake.Member`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> disnake.Member: - """Convert the `argument` to a `disnake.Member`.""" + async def convert(self, ctx: Context, argument: str) -> discord.Member: + """Convert the `argument` to a `discord.Member`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -588,10 +588,10 @@ if t.TYPE_CHECKING: OffTopicName = str # noqa: F811 ISODateTime = datetime # noqa: F811 HushDurationConverter = int # noqa: F811 - UnambiguousUser = disnake.User # noqa: F811 - UnambiguousMember = disnake.Member # noqa: F811 + UnambiguousUser = discord.User # noqa: F811 + UnambiguousMember = discord.Member # noqa: F811 Infraction = t.Optional[dict] # noqa: F811 Expiry = t.Union[Duration, ISODateTime] -MemberOrUser = t.Union[disnake.Member, disnake.User] +MemberOrUser = t.Union[discord.Member, discord.User] UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] diff --git a/bot/decorators.py b/bot/decorators.py index 9ae98442c..f4331264f 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -4,9 +4,9 @@ import types import typing as t from contextlib import suppress -from disnake import Member, NotFound -from disnake.ext import commands -from disnake.ext.commands import Cog, Context +from discord import Member, NotFound +from discord.ext import commands +from discord.ext.commands import Cog, Context from bot.constants import Channels, DEBUG_MODE, RedirectOutput from bot.log import get_logger @@ -179,7 +179,7 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: Ensure the highest role of the invoking member is greater than that of the target member. If the condition fails, a warning is sent to the invoking context. A target which is not an - instance of disnake.Member will always pass. + instance of discord.Member will always pass. `member_arg` is the keyword name or position index of the parameter of the decorated command whose value is the target member. @@ -195,7 +195,7 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: target = function.get_arg_value(member_arg, bound_args) if not isinstance(target, Member): - log.trace("The target is not a disnake.Member; skipping role hierarchy check.") + log.trace("The target is not a discord.Member; skipping role hierarchy check.") return await func(*args, **kwargs) ctx = function.get_arg_value(1, bound_args) diff --git a/bot/errors.py b/bot/errors.py index 298e7ac2d..078b645f1 100644 --- a/bot/errors.py +++ b/bot/errors.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Hashable, TYPE_CHECKING, Union -from disnake.ext.commands import ConversionError, Converter +from discord.ext.commands import ConversionError, Converter if TYPE_CHECKING: from bot.converters import MemberOrUser diff --git a/bot/exts/backend/branding/_cog.py b/bot/exts/backend/branding/_cog.py index a07e70d58..0c5839a7a 100644 --- a/bot/exts/backend/branding/_cog.py +++ b/bot/exts/backend/branding/_cog.py @@ -7,10 +7,10 @@ from enum import Enum from operator import attrgetter import async_timeout -import disnake +import discord from arrow import Arrow from async_rediscache import RedisCache -from disnake.ext import commands, tasks +from discord.ext import commands, tasks from bot.bot import Bot from bot.constants import Branding as BrandingConfig, Channels, Colours, Guild, MODERATION_ROLES @@ -42,7 +42,7 @@ def compound_hash(objects: t.Iterable[RemoteObject]) -> str: return "-".join(item.sha for item in objects) -def make_embed(title: str, description: str, *, success: bool) -> disnake.Embed: +def make_embed(title: str, description: str, *, success: bool) -> discord.Embed: """ Construct simple response embed. @@ -51,7 +51,7 @@ def make_embed(title: str, description: str, *, success: bool) -> disnake.Embed: For both `title` and `description`, empty string are valid values ~ fields will be empty. """ colour = Colours.soft_green if success else Colours.soft_red - return disnake.Embed(title=title[:256], description=description[:4096], colour=colour) + return discord.Embed(title=title[:256], description=description[:4096], colour=colour) def extract_event_duration(event: Event) -> str: @@ -147,13 +147,13 @@ class Branding(commands.Cog): return False await self.bot.wait_until_guild_available() - pydis: disnake.Guild = self.bot.get_guild(Guild.id) + pydis: discord.Guild = self.bot.get_guild(Guild.id) timeout = 10 # Seconds. try: with async_timeout.timeout(timeout): # Raise after `timeout` seconds. await pydis.edit(**{asset_type.value: file}) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Asset upload to Discord failed.") return False except asyncio.TimeoutError: @@ -277,7 +277,7 @@ class Branding(commands.Cog): log.debug(f"Sending event information event to channel: {channel_id} ({is_notification=}).") await self.bot.wait_until_guild_available() - channel: t.Optional[disnake.TextChannel] = self.bot.get_channel(channel_id) + channel: t.Optional[discord.TextChannel] = self.bot.get_channel(channel_id) if channel is None: log.warning(f"Cannot send event information: channel {channel_id} not found!") @@ -294,7 +294,7 @@ class Branding(commands.Cog): else: content = "Python Discord is entering a new event!" if is_notification else None - embed = disnake.Embed(description=description[:4096], colour=disnake.Colour.og_blurple()) + embed = discord.Embed(description=description[:4096], colour=discord.Colour.og_blurple()) embed.set_footer(text=duration[:4096]) await channel.send(content=content, embed=embed) @@ -573,7 +573,7 @@ class Branding(commands.Cog): await ctx.send(embed=resp) return - embed = disnake.Embed(title="Current event calendar", colour=disnake.Colour.og_blurple()) + embed = discord.Embed(title="Current event calendar", colour=discord.Colour.og_blurple()) # Because Discord embeds can only contain up to 25 fields, we only show the first 25. first_25 = list(available_events.items())[:25] diff --git a/bot/exts/backend/config_verifier.py b/bot/exts/backend/config_verifier.py index 1ade2bce7..dc85a65a2 100644 --- a/bot/exts/backend/config_verifier.py +++ b/bot/exts/backend/config_verifier.py @@ -1,4 +1,4 @@ -from disnake.ext.commands import Cog +from discord.ext.commands import Cog from bot import constants from bot.bot import Bot diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 953843a77..c79c7b2a7 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,7 +1,7 @@ import difflib -from disnake import Embed -from disnake.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors +from discord import Embed +from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from sentry_sdk import push_scope from bot.api import ResponseCodeError diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py index 040fb5d37..2d03cd580 100644 --- a/bot/exts/backend/logging.py +++ b/bot/exts/backend/logging.py @@ -1,5 +1,5 @@ -from disnake import Embed -from disnake.ext.commands import Cog +from discord import Embed +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index d08e56077..80f5750bc 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,8 +1,8 @@ from typing import Any, Dict -from disnake import Member, Role, User -from disnake.ext import commands -from disnake.ext.commands import Cog, Context +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context from bot import constants from bot.api import ResponseCodeError diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 48ee3c842..45301b098 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,8 +2,8 @@ import abc import typing as t from collections import namedtuple -from disnake import Guild -from disnake.ext.commands import Context +from discord import Guild +from discord.ext.commands import Context from more_itertools import chunked import bot diff --git a/bot/exts/events/code_jams/_channels.py b/bot/exts/events/code_jams/_channels.py index fc4693bd4..e8cf5f7bf 100644 --- a/bot/exts/events/code_jams/_channels.py +++ b/bot/exts/events/code_jams/_channels.py @@ -1,6 +1,6 @@ import typing as t -import disnake +import discord from bot.constants import Categories, Channels, Roles from bot.log import get_logger @@ -11,7 +11,7 @@ MAX_CHANNELS = 50 CATEGORY_NAME = "Code Jam" -async def _get_category(guild: disnake.Guild) -> disnake.CategoryChannel: +async def _get_category(guild: discord.Guild) -> discord.CategoryChannel: """ Return a code jam category. @@ -24,13 +24,13 @@ async def _get_category(guild: disnake.Guild) -> disnake.CategoryChannel: return await _create_category(guild) -async def _create_category(guild: disnake.Guild) -> disnake.CategoryChannel: +async def _create_category(guild: discord.Guild) -> discord.CategoryChannel: """Create a new code jam category and return it.""" log.info("Creating a new code jam category.") category_overwrites = { - guild.default_role: disnake.PermissionOverwrite(read_messages=False), - guild.me: disnake.PermissionOverwrite(read_messages=True) + guild.default_role: discord.PermissionOverwrite(read_messages=False), + guild.me: discord.PermissionOverwrite(read_messages=True) } category = await guild.create_category_channel( @@ -47,17 +47,17 @@ async def _create_category(guild: disnake.Guild) -> disnake.CategoryChannel: def _get_overwrites( - members: list[tuple[disnake.Member, bool]], - guild: disnake.Guild, -) -> dict[t.Union[disnake.Member, disnake.Role], disnake.PermissionOverwrite]: + members: list[tuple[discord.Member, bool]], + guild: discord.Guild, +) -> dict[t.Union[discord.Member, discord.Role], discord.PermissionOverwrite]: """Get code jam team channels permission overwrites.""" team_channel_overwrites = { - guild.default_role: disnake.PermissionOverwrite(read_messages=False), - guild.get_role(Roles.code_jam_event_team): disnake.PermissionOverwrite(read_messages=True) + guild.default_role: discord.PermissionOverwrite(read_messages=False), + guild.get_role(Roles.code_jam_event_team): discord.PermissionOverwrite(read_messages=True) } for member, _ in members: - team_channel_overwrites[member] = disnake.PermissionOverwrite( + team_channel_overwrites[member] = discord.PermissionOverwrite( read_messages=True ) @@ -65,10 +65,10 @@ def _get_overwrites( async def create_team_channel( - guild: disnake.Guild, + guild: discord.Guild, team_name: str, - members: list[tuple[disnake.Member, bool]], - team_leaders: disnake.Role + members: list[tuple[discord.Member, bool]], + team_leaders: discord.Role ) -> None: """Create the team's text channel.""" await _add_team_leader_roles(members, team_leaders) @@ -84,29 +84,29 @@ async def create_team_channel( ) -async def create_team_leader_channel(guild: disnake.Guild, team_leaders: disnake.Role) -> None: +async def create_team_leader_channel(guild: discord.Guild, team_leaders: discord.Role) -> None: """Create the Team Leader Chat channel for the Code Jam team leaders.""" - category: disnake.CategoryChannel = guild.get_channel(Categories.summer_code_jam) + category: discord.CategoryChannel = guild.get_channel(Categories.summer_code_jam) team_leaders_chat = await category.create_text_channel( name="team-leaders-chat", overwrites={ - guild.default_role: disnake.PermissionOverwrite(read_messages=False), - team_leaders: disnake.PermissionOverwrite(read_messages=True) + guild.default_role: discord.PermissionOverwrite(read_messages=False), + team_leaders: discord.PermissionOverwrite(read_messages=True) } ) await _send_status_update(guild, f"Created {team_leaders_chat.mention} in the {category} category.") -async def _send_status_update(guild: disnake.Guild, message: str) -> None: +async def _send_status_update(guild: discord.Guild, message: str) -> None: """Inform the events lead with a status update when the command is ran.""" - channel: disnake.TextChannel = guild.get_channel(Channels.code_jam_planning) + channel: discord.TextChannel = guild.get_channel(Channels.code_jam_planning) await channel.send(f"<@&{Roles.events_lead}>\n\n{message}") -async def _add_team_leader_roles(members: list[tuple[disnake.Member, bool]], team_leaders: disnake.Role) -> None: +async def _add_team_leader_roles(members: list[tuple[discord.Member, bool]], team_leaders: discord.Role) -> None: """Assign the team leader role to the team leaders.""" for member, is_leader in members: if is_leader: diff --git a/bot/exts/events/code_jams/_cog.py b/bot/exts/events/code_jams/_cog.py index 5cb11826d..452199f5f 100644 --- a/bot/exts/events/code_jams/_cog.py +++ b/bot/exts/events/code_jams/_cog.py @@ -3,9 +3,9 @@ import csv import typing as t from collections import defaultdict -import disnake -from disnake import Colour, Embed, Guild, Member -from disnake.ext import commands +import discord +from discord import Colour, Embed, Guild, Member +from discord.ext import commands from bot.bot import Bot from bot.constants import Emojis, Roles @@ -85,7 +85,7 @@ class CodeJams(commands.Cog): A confirmation message is displayed with the categories and channels to be deleted.. Pressing the added reaction deletes those channels. """ - def predicate_deletion_emoji_reaction(reaction: disnake.Reaction, user: disnake.User) -> bool: + def predicate_deletion_emoji_reaction(reaction: discord.Reaction, user: discord.User) -> bool: """Return True if the reaction :boom: was added by the context message author on this message.""" return ( reaction.message.id == message.id @@ -124,14 +124,14 @@ class CodeJams(commands.Cog): @staticmethod async def _build_confirmation_message( - categories: dict[disnake.CategoryChannel, list[disnake.abc.GuildChannel]] + categories: dict[discord.CategoryChannel, list[discord.abc.GuildChannel]] ) -> str: """Sends details of the channels to be deleted to the pasting service, and formats the confirmation message.""" - def channel_repr(channel: disnake.abc.GuildChannel) -> str: + def channel_repr(channel: discord.abc.GuildChannel) -> str: """Formats the channel name and ID and a readable format.""" return f"{channel.name} ({channel.id})" - def format_category_info(category: disnake.CategoryChannel, channels: list[disnake.abc.GuildChannel]) -> str: + def format_category_info(category: discord.CategoryChannel, channels: list[discord.abc.GuildChannel]) -> str: """Displays the category and the channels within it in a readable format.""" return f"{channel_repr(category)}:\n" + "\n".join(" - " + channel_repr(channel) for channel in channels) @@ -187,7 +187,7 @@ class CodeJams(commands.Cog): await old_team_channel.set_permissions(member, overwrite=None, reason=f"Participant moved to {new_team_name}") await new_team_channel.set_permissions( member, - overwrite=disnake.PermissionOverwrite(read_messages=True), + overwrite=discord.PermissionOverwrite(read_messages=True), reason=f"Participant moved from {old_team_channel.name}" ) @@ -212,16 +212,16 @@ class CodeJams(commands.Cog): await ctx.send(f"Removed the participant from `{self.team_name(channel)}`.") @staticmethod - def jam_categories(guild: Guild) -> list[disnake.CategoryChannel]: + def jam_categories(guild: Guild) -> list[discord.CategoryChannel]: """Get all the code jam team categories.""" return [category for category in guild.categories if category.name == _channels.CATEGORY_NAME] @staticmethod - def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[disnake.TextChannel]: + def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[discord.TextChannel]: """Get a team channel through either a participant or the team name.""" for category in CodeJams.jam_categories(guild): for channel in category.channels: - if isinstance(channel, disnake.TextChannel): + if isinstance(channel, discord.TextChannel): if ( # If it's a string. criterion == channel.name or criterion == CodeJams.team_name(channel) @@ -231,6 +231,6 @@ class CodeJams(commands.Cog): return channel @staticmethod - def team_name(channel: disnake.TextChannel) -> str: + def team_name(channel: discord.TextChannel) -> str: """Retrieves the team name from the given channel.""" return channel.name.replace("-", " ").title() diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py index e55ece910..6cccf3680 100644 --- a/bot/exts/filters/antimalware.py +++ b/bot/exts/filters/antimalware.py @@ -1,8 +1,8 @@ import typing as t from os.path import splitext -from disnake import Embed, Message, NotFound -from disnake.ext.commands import Cog +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, Filter, URLs diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index c887cf5fc..bcd845a43 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -8,8 +8,8 @@ from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set import arrow -from disnake import Colour, Member, Message, NotFound, Object, TextChannel -from disnake.ext.commands import Cog +from discord import Colour, Member, Message, NotFound, Object, TextChannel +from discord.ext.commands import Cog from bot import rules from bot.bot import Bot @@ -195,7 +195,7 @@ class AntiSpam(Cog): result = await rule_function(message, messages_for_rule, rule_config) # If the rule returns `None`, that means the message didn't violate it. - # If it doesn't, it returns a tuple in the form `(str, Iterable[disnake.Member])` + # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` # which contains the reason for why the message violated the rule and # an iterable of all members that violated the rule. if result is not None: @@ -265,7 +265,7 @@ class AntiSpam(Cog): # In the rare case where we found messages matching the # spam filter across multiple channels, it is possible # that a single channel will only contain a single message - # to delete. If that should be the case, disnake will + # to delete. If that should be the case, discord.py will # use the "delete single message" endpoint instead of the # bulk delete endpoint, and the single message deletion # endpoint will complain if you give it that does not exist. diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py index 05910973a..a883ddf54 100644 --- a/bot/exts/filters/filter_lists.py +++ b/bot/exts/filters/filter_lists.py @@ -1,8 +1,8 @@ import re from typing import Optional -from disnake import Colour, Embed -from disnake.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role from bot import constants from bot.api import ResponseCodeError diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index e8c9bab62..f44b28125 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -6,15 +6,15 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union import arrow import dateutil.parser -import disnake.errors +import discord.errors import regex import tldextract from async_rediscache import RedisCache from botcore.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from disnake import Colour, HTTPException, Member, Message, NotFound, TextChannel -from disnake.ext.commands import Cog -from disnake.utils import escape_markdown +from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel +from discord.ext.commands import Cog +from discord.utils import escape_markdown from bot.api import ResponseCodeError from bot.bot import Bot @@ -63,14 +63,14 @@ AUTO_BAN_REASON = ( ) AUTO_BAN_DURATION = timedelta(days=4) -FilterMatch = Union[re.Match, dict, bool, List[disnake.Embed]] +FilterMatch = Union[re.Match, dict, bool, List[discord.Embed]] class Stats(NamedTuple): """Additional stats on a triggered filter to append to a mod log.""" message_content: str - additional_embeds: Optional[List[disnake.Embed]] + additional_embeds: Optional[List[discord.Embed]] class Filtering(Cog): @@ -339,7 +339,7 @@ class Filtering(Cog): match = result if match: - is_private = msg.channel.type is disnake.ChannelType.private + is_private = msg.channel.type is discord.ChannelType.private # If this is a filter (not a watchlist) and not in a DM, delete the message. if _filter["type"] == "filter" and not is_private: @@ -354,7 +354,7 @@ class Filtering(Cog): # In addition, to avoid sending two notifications to the user, the # logs, and mod_alert, we return if the message no longer exists. await msg.delete() - except disnake.errors.NotFound: + except discord.errors.NotFound: return # Notify the user if the filter specifies @@ -409,14 +409,14 @@ class Filtering(Cog): self, filter_name: str, _filter: Dict[str, Any], - msg: disnake.Message, + msg: discord.Message, stats: Stats, reason: Optional[str] = None, *, is_eval: bool = False, ) -> None: """Send a mod log for a triggered filter.""" - if msg.channel.type is disnake.ChannelType.private: + if msg.channel.type is discord.ChannelType.private: channel_str = "via DM" ping_everyone = False else: @@ -478,7 +478,7 @@ class Filtering(Cog): additional_embeds = [] for _, data in match.items(): reason = f"Reason: {data['reason']} | " if data.get('reason') else "" - embed = disnake.Embed(description=( + embed = discord.Embed(description=( f"**Members:**\n{data['members']}\n" f"**Active:**\n{data['active']}" )) @@ -626,7 +626,7 @@ class Filtering(Cog): return invite_data if invite_data else False @staticmethod - async def _has_rich_embed(msg: Message) -> Union[bool, List[disnake.Embed]]: + async def _has_rich_embed(msg: Message) -> Union[bool, List[discord.Embed]]: """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" if msg.embeds: for embed in msg.embeds: @@ -662,7 +662,7 @@ class Filtering(Cog): """ try: await filtered_member.send(reason) - except disnake.errors.Forbidden: + except discord.errors.Forbidden: await channel.send(f"{filtered_member.mention} {reason}") def schedule_msg_delete(self, msg: dict) -> None: diff --git a/bot/exts/filters/security.py b/bot/exts/filters/security.py index bbb15542f..fe3918423 100644 --- a/bot/exts/filters/security.py +++ b/bot/exts/filters/security.py @@ -1,4 +1,4 @@ -from disnake.ext.commands import Cog, Context, NoPrivateMessage +from discord.ext.commands import Cog, Context, NoPrivateMessage from bot.bot import Bot from bot.log import get_logger diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py index da42bb0aa..520283ba3 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filters/token_remover.py @@ -3,8 +3,8 @@ import binascii import re import typing as t -from disnake import Colour, Message, NotFound -from disnake.ext.commands import Cog +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog from bot import utils from bot.bot import Bot @@ -53,7 +53,7 @@ class Token(t.NamedTuple): class TokenRemover(Cog): - """Scans messages for potential Discord bot tokens and removes them.""" + """Scans messages for potential discord.py bot tokens and removes them.""" def __init__(self, bot: Bot): self.bot = bot diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py index a5d51700c..96334317c 100644 --- a/bot/exts/filters/webhook_remover.py +++ b/bot/exts/filters/webhook_remover.py @@ -1,7 +1,7 @@ import re -from disnake import Colour, Message, NotFound -from disnake.ext.commands import Cog +from discord import Colour, Message, NotFound +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, Colours, Event, Icons diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index 55196cd65..c51656343 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -1,9 +1,9 @@ import asyncio from typing import Union -import disnake -from disnake import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors -from disnake.ext.commands import Cog, Context, command +import discord +from discord import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors +from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot @@ -34,7 +34,7 @@ class DuckPond(Cog): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except disnake.HTTPException: + except discord.HTTPException: log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") @staticmethod @@ -67,7 +67,7 @@ class DuckPond(Cog): return False @staticmethod - def _is_duck_emoji(emoji: Union[str, disnake.PartialEmoji, disnake.Emoji]) -> bool: + def _is_duck_emoji(emoji: Union[str, discord.PartialEmoji, discord.Emoji]) -> bool: """Check if the emoji is a valid duck emoji.""" if isinstance(emoji, str): return emoji == "🦆" @@ -111,7 +111,7 @@ class DuckPond(Cog): username=message.author.display_name, avatar_url=message.author.display_avatar.url ) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Failed to send an attachment to the webhook") async def locked_relay(self, message: Message) -> bool: @@ -133,7 +133,7 @@ class DuckPond(Cog): await message.add_reaction("✅") return True - def _payload_has_duckpond_emoji(self, emoji: disnake.PartialEmoji) -> bool: + def _payload_has_duckpond_emoji(self, emoji: discord.PartialEmoji) -> bool: """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" if emoji.is_unicode_emoji(): # For unicode PartialEmojis, the `name` attribute is just the string @@ -165,7 +165,7 @@ class DuckPond(Cog): if not self._payload_has_duckpond_emoji(payload.emoji): return - channel = disnake.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) if channel is None: return @@ -175,10 +175,10 @@ class DuckPond(Cog): try: message = await channel.fetch_message(payload.message_id) - except disnake.NotFound: + except discord.NotFound: return # Message was deleted. - member = disnake.utils.get(message.guild.members, id=payload.user_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) if not member: return # Member left or wasn't in the cache. @@ -205,7 +205,7 @@ class DuckPond(Cog): if payload.guild_id != constants.Guild.id: return - channel = disnake.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) if channel is None: return diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index d49f71320..7df1d172d 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,9 +2,9 @@ import difflib from datetime import timedelta import arrow -from disnake import Colour, Embed -from disnake.ext.commands import Cog, Context, group, has_any_role -from disnake.utils import sleep_until +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, group, has_any_role +from discord.utils import sleep_until from bot.api import ResponseCodeError from bot.bot import Bot diff --git a/bot/exts/help_channels/_caches.py b/bot/exts/help_channels/_caches.py index f4eaf3291..8d45c2466 100644 --- a/bot/exts/help_channels/_caches.py +++ b/bot/exts/help_channels/_caches.py @@ -1,24 +1,24 @@ from async_rediscache import RedisCache # This dictionary maps a help channel to the time it was claimed -# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] +# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] claim_times = RedisCache(namespace="HelpChannels.claim_times") # This cache tracks which channels are claimed by which members. -# RedisCache[disnake.TextChannel.id, t.Union[disnake.User.id, disnake.Member.id]] +# RedisCache[discord.TextChannel.id, t.Union[discord.User.id, discord.Member.id]] claimants = RedisCache(namespace="HelpChannels.help_channel_claimants") # Stores the timestamp of the last message from the claimant of a help channel -# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] +# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] claimant_last_message_times = RedisCache(namespace="HelpChannels.claimant_last_message_times") # This cache maps a help channel to the timestamp of the last non-claimant message. # This cache being empty for a given help channel indicates the question is unanswered. -# RedisCache[disnake.TextChannel.id, UtcPosixTimestamp] +# RedisCache[discord.TextChannel.id, UtcPosixTimestamp] non_claimant_last_message_times = RedisCache(namespace="HelpChannels.non_claimant_last_message_times") # This cache maps a help channel to original question message in same channel. -# RedisCache[disnake.TextChannel.id, disnake.Message.id] +# RedisCache[discord.TextChannel.id, discord.Message.id] question_messages = RedisCache(namespace="HelpChannels.question_messages") # This cache keeps track of the dynamic message ID for @@ -26,10 +26,10 @@ question_messages = RedisCache(namespace="HelpChannels.question_messages") dynamic_message = RedisCache(namespace="HelpChannels.dynamic_message") # This cache keeps track of who has help-dms on. -# RedisCache[disnake.User.id, bool] +# RedisCache[discord.User.id, bool] help_dm = RedisCache(namespace="HelpChannels.help_dm") # This cache tracks member who are participating and opted in to help channel dms. # serialise the set as a comma separated string to allow usage with redis -# RedisCache[disnake.TextChannel.id, str[set[disnake.User.id]]] +# RedisCache[discord.TextChannel.id, str[set[discord.User.id]]] session_participants = RedisCache(namespace="HelpChannels.session_participants") diff --git a/bot/exts/help_channels/_channel.py b/bot/exts/help_channels/_channel.py index 3c4eaa2b2..d9cebf215 100644 --- a/bot/exts/help_channels/_channel.py +++ b/bot/exts/help_channels/_channel.py @@ -4,7 +4,7 @@ from datetime import timedelta from enum import Enum import arrow -import disnake +import discord from arrow import Arrow import bot @@ -31,7 +31,7 @@ class ClosingReason(Enum): CLEANUP = "auto.cleanup" -def get_category_channels(category: disnake.CategoryChannel) -> t.Iterable[disnake.TextChannel]: +def get_category_channels(category: discord.CategoryChannel) -> t.Iterable[discord.TextChannel]: """Yield the text channels of the `category` in an unsorted manner.""" log.trace(f"Getting text channels in the category '{category}' ({category.id}).") @@ -41,7 +41,7 @@ def get_category_channels(category: disnake.CategoryChannel) -> t.Iterable[disna yield channel -async def get_closing_time(channel: disnake.TextChannel, init_done: bool) -> t.Tuple[Arrow, ClosingReason]: +async def get_closing_time(channel: discord.TextChannel, init_done: bool) -> t.Tuple[Arrow, ClosingReason]: """ Return the time at which the given help `channel` should be closed along with the reason. @@ -116,12 +116,12 @@ async def get_in_use_time(channel_id: int) -> t.Optional[timedelta]: return arrow.utcnow() - claimed -def is_excluded_channel(channel: disnake.abc.GuildChannel) -> bool: +def is_excluded_channel(channel: discord.abc.GuildChannel) -> bool: """Check if a channel should be excluded from the help channel system.""" - return not isinstance(channel, disnake.TextChannel) or channel.id in EXCLUDED_CHANNELS + return not isinstance(channel, discord.TextChannel) or channel.id in EXCLUDED_CHANNELS -async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **options) -> None: +async def move_to_bottom(channel: discord.TextChannel, category_id: int, **options) -> None: """ Move the `channel` to the bottom position of `category` and edit channel attributes. @@ -130,8 +130,8 @@ async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **optio really ends up at the bottom of the category. If `options` are provided, the channel will be edited after the move is completed. This is the - same order of operations that `disnake.TextChannel.edit` uses. For information on available - options, see the documentation on `disnake.TextChannel.edit`. While possible, position-related + same order of operations that `discord.TextChannel.edit` uses. For information on available + options, see the documentation on `discord.TextChannel.edit`. While possible, position-related options should be avoided, as it may interfere with the category move we perform. """ # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. @@ -161,7 +161,7 @@ async def move_to_bottom(channel: disnake.TextChannel, category_id: int, **optio await channel.edit(**options) -async def ensure_cached_claimant(channel: disnake.TextChannel) -> None: +async def ensure_cached_claimant(channel: discord.TextChannel) -> None: """ Ensure there is a claimant cached for each help channel. diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index fc55fa1df..d3d70e252 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -5,9 +5,9 @@ from datetime import timedelta from operator import attrgetter import arrow -import disnake -import disnake.abc -from disnake.ext import commands +import discord +import discord.abc +from discord.ext import commands from bot import constants from bot.bot import Bot @@ -66,16 +66,16 @@ class HelpChannels(commands.Cog): self.bot = bot self.scheduler = scheduling.Scheduler(self.__class__.__name__) - self.guild: disnake.Guild = None - self.cooldown_role: disnake.Role = None + self.guild: discord.Guild = None + self.cooldown_role: discord.Role = None # Categories - self.available_category: disnake.CategoryChannel = None - self.in_use_category: disnake.CategoryChannel = None - self.dormant_category: disnake.CategoryChannel = None + self.available_category: discord.CategoryChannel = None + self.in_use_category: discord.CategoryChannel = None + self.dormant_category: discord.CategoryChannel = None # Queues - self.channel_queue: asyncio.Queue[disnake.TextChannel] = None + self.channel_queue: asyncio.Queue[discord.TextChannel] = None self.name_queue: t.Deque[str] = None # Notifications @@ -84,7 +84,7 @@ class HelpChannels(commands.Cog): self.last_running_low_notification = arrow.get('1815-12-10T18:00:00.00000+00:00') self.dynamic_message: t.Optional[int] = None - self.available_help_channels: t.Set[disnake.TextChannel] = set() + self.available_help_channels: t.Set[discord.TextChannel] = set() # Asyncio stuff self.queue_tasks: t.List[asyncio.Task] = [] @@ -104,7 +104,7 @@ class HelpChannels(commands.Cog): @lock.lock_arg(NAMESPACE, "message", attrgetter("channel.id")) @lock.lock_arg(NAMESPACE, "message", attrgetter("author.id")) @lock.lock_arg(f"{NAMESPACE}.unclaim", "message", attrgetter("author.id"), wait=True) - async def claim_channel(self, message: disnake.Message) -> None: + async def claim_channel(self, message: discord.Message) -> None: """ Claim the channel in which the question `message` was sent. @@ -116,7 +116,7 @@ class HelpChannels(commands.Cog): try: await self.move_to_in_use(message.channel) - except disnake.DiscordServerError: + except discord.DiscordServerError: try: await message.channel.send( "The bot encountered a Discord API error while trying to move this channel, please try again later." @@ -133,14 +133,14 @@ class HelpChannels(commands.Cog): self.bot.stats.incr("help.failed_claims.500_on_move") return - embed = disnake.Embed( + embed = discord.Embed( description=f"Channel claimed by {message.author.mention}.", color=constants.Colours.bright_green, ) await message.channel.send(embed=embed) - # Handle odd edge case of `message.author` not being a `disnake.Member` (see bot#1839) - if not isinstance(message.author, disnake.Member): + # Handle odd edge case of `message.author` not being a `discord.Member` (see bot#1839) + if not isinstance(message.author, discord.Member): log.debug(f"{message.author} ({message.author.id}) isn't a member. Not giving cooldown role or sending DM.") else: await members.handle_role_change(message.author, message.author.add_roles, self.cooldown_role) @@ -189,7 +189,7 @@ class HelpChannels(commands.Cog): return queue - async def create_dormant(self) -> t.Optional[disnake.TextChannel]: + async def create_dormant(self) -> t.Optional[discord.TextChannel]: """ Create and return a new channel in the Dormant category. @@ -234,12 +234,12 @@ class HelpChannels(commands.Cog): May only be invoked by the channel's claimant or by staff. """ - # Don't use a disnake check because the check needs to fail silently. + # Don't use a discord.py check because the check needs to fail silently. if await self.close_check(ctx): log.info(f"Close command invoked by {ctx.author} in #{ctx.channel}.") await self.unclaim_channel(ctx.channel, closed_on=_channel.ClosingReason.COMMAND) - async def get_available_candidate(self) -> disnake.TextChannel: + async def get_available_candidate(self) -> discord.TextChannel: """ Return a dormant channel to turn into an available channel. @@ -313,7 +313,7 @@ class HelpChannels(commands.Cog): self.dormant_category = await channel_utils.get_or_fetch_channel( constants.Categories.help_dormant ) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Failed to get a category; cog will be removed") self.bot.remove_cog(self.qualified_name) @@ -355,7 +355,7 @@ class HelpChannels(commands.Cog): log.info("Cog is ready!") - async def move_idle_channel(self, channel: disnake.TextChannel, has_task: bool = True) -> None: + async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None: """ Make the `channel` dormant if idle or schedule the move if still active. @@ -416,7 +416,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() - async def move_to_dormant(self, channel: disnake.TextChannel) -> None: + async def move_to_dormant(self, channel: discord.TextChannel) -> None: """Make the `channel` dormant.""" log.info(f"Moving #{channel} ({channel.id}) to the Dormant category.") await _channel.move_to_bottom( @@ -425,7 +425,7 @@ class HelpChannels(commands.Cog): ) log.trace(f"Sending dormant message for #{channel} ({channel.id}).") - embed = disnake.Embed( + embed = discord.Embed( description=_message.DORMANT_MSG.format( dormant=self.dormant_category.name, available=self.available_category.name, @@ -439,7 +439,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() @lock.lock_arg(f"{NAMESPACE}.unclaim", "channel") - async def unclaim_channel(self, channel: disnake.TextChannel, *, closed_on: _channel.ClosingReason) -> None: + async def unclaim_channel(self, channel: discord.TextChannel, *, closed_on: _channel.ClosingReason) -> None: """ Unclaim an in-use help `channel` to make it dormant. @@ -462,7 +462,7 @@ class HelpChannels(commands.Cog): async def _unclaim_channel( self, - channel: disnake.TextChannel, + channel: discord.TextChannel, claimant_id: t.Optional[int], closed_on: _channel.ClosingReason ) -> None: @@ -488,7 +488,7 @@ class HelpChannels(commands.Cog): if closed_on == _channel.ClosingReason.COMMAND: self.scheduler.cancel(channel.id) - async def move_to_in_use(self, channel: disnake.TextChannel) -> None: + async def move_to_in_use(self, channel: discord.TextChannel) -> None: """Make a channel in-use and schedule it to be made dormant.""" log.info(f"Moving #{channel} ({channel.id}) to the In Use category.") @@ -504,7 +504,7 @@ class HelpChannels(commands.Cog): _stats.report_counts() @commands.Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """Move an available channel to the In Use category and replace it with a dormant one.""" if message.author.bot: return # Ignore messages sent by bots. @@ -520,7 +520,7 @@ class HelpChannels(commands.Cog): await _message.update_message_caches(message) @commands.Cog.listener() - async def on_message_delete(self, msg: disnake.Message) -> None: + async def on_message_delete(self, msg: discord.Message) -> None: """ Reschedule an in-use channel to become dormant sooner if the channel is empty. @@ -542,7 +542,7 @@ class HelpChannels(commands.Cog): delay = constants.HelpChannels.deleted_idle_minutes * 60 self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) - async def wait_for_dormant_channel(self) -> disnake.TextChannel: + async def wait_for_dormant_channel(self) -> discord.TextChannel: """Wait for a dormant channel to become available in the queue and return it.""" log.trace("Waiting for a dormant channel.") @@ -569,7 +569,7 @@ class HelpChannels(commands.Cog): await self.bot.http.edit_message( constants.Channels.how_to_get_help, self.dynamic_message, content=available_channels, files=None ) - except disnake.NotFound: + except discord.NotFound: pass else: return @@ -593,7 +593,7 @@ class HelpChannels(commands.Cog): @lock.lock_arg(NAMESPACE, "message", attrgetter("channel.id")) @lock.lock_arg(NAMESPACE, "message", attrgetter("author.id")) - async def notify_session_participants(self, message: disnake.Message) -> None: + async def notify_session_participants(self, message: discord.Message) -> None: """ Check if the message author meets the requirements to be notified. @@ -615,7 +615,7 @@ class HelpChannels(commands.Cog): if message.author.id not in session_participants: session_participants.add(message.author.id) - embed = disnake.Embed( + embed = discord.Embed( title="Currently Helping", description=f"You're currently helping in {message.channel.mention}", color=constants.Colours.bright_green, @@ -625,7 +625,7 @@ class HelpChannels(commands.Cog): try: await message.author.send(embed=embed) - except disnake.Forbidden: + except discord.Forbidden: log.trace( f"Failed to send helpdm message to {message.author.id}. DMs Closed/Blocked. " "Removing user from helpdm." diff --git a/bot/exts/help_channels/_message.py b/bot/exts/help_channels/_message.py index e08043694..7ceed9b4d 100644 --- a/bot/exts/help_channels/_message.py +++ b/bot/exts/help_channels/_message.py @@ -2,7 +2,7 @@ import textwrap import typing as t import arrow -import disnake +import discord from arrow import Arrow import bot @@ -41,7 +41,7 @@ through our guide for **[asking a good question]({ASKING_GUIDE_URL})**. """ -async def update_message_caches(message: disnake.Message) -> None: +async def update_message_caches(message: discord.Message) -> None: """Checks the source of new content in a help channel and updates the appropriate cache.""" channel = message.channel @@ -62,18 +62,18 @@ async def update_message_caches(message: disnake.Message) -> None: await _caches.non_claimant_last_message_times.set(channel.id, timestamp) -async def get_last_message(channel: disnake.TextChannel) -> t.Optional[disnake.Message]: +async def get_last_message(channel: discord.TextChannel) -> t.Optional[discord.Message]: """Return the last message sent in the channel or None if no messages exist.""" log.trace(f"Getting the last message in #{channel} ({channel.id}).") try: return await channel.history(limit=1).next() # noqa: B305 - except disnake.NoMoreItems: + except discord.NoMoreItems: log.debug(f"No last message available; #{channel} ({channel.id}) has no messages.") return None -async def is_empty(channel: disnake.TextChannel) -> bool: +async def is_empty(channel: discord.TextChannel) -> bool: """Return True if there's an AVAILABLE_MSG and the messages leading up are bot messages.""" log.trace(f"Checking if #{channel} ({channel.id}) is empty.") @@ -92,13 +92,13 @@ async def is_empty(channel: disnake.TextChannel) -> bool: return False -async def dm_on_open(message: disnake.Message) -> None: +async def dm_on_open(message: discord.Message) -> None: """ DM claimant with a link to the claimed channel's first message, with a 100 letter preview of the message. Does nothing if the user has DMs disabled. """ - embed = disnake.Embed( + embed = discord.Embed( title="Help channel opened", description=f"You claimed {message.channel.mention}.", colour=bot.constants.Colours.bright_green, @@ -118,7 +118,7 @@ async def dm_on_open(message: disnake.Message) -> None: try: await message.author.send(embed=embed) log.trace(f"Sent DM to {message.author.id} after claiming help channel.") - except disnake.errors.Forbidden: + except discord.errors.Forbidden: log.trace( f"Ignoring to send DM to {message.author.id} after claiming help channel: DMs disabled." ) @@ -146,7 +146,7 @@ async def notify_none_remaining(last_notification: Arrow) -> t.Optional[Arrow]: log.trace("Notifying about lack of channels.") mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_none_remaining_roles) - allowed_roles = [disnake.Object(id_) for id_ in constants.HelpChannels.notify_none_remaining_roles] + allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_none_remaining_roles] channel = bot.instance.get_channel(constants.HelpChannels.notify_channel) if channel is None: @@ -157,7 +157,7 @@ async def notify_none_remaining(last_notification: Arrow) -> t.Optional[Arrow]: f"{mentions} A new available help channel is needed but there " "are no more dormant ones. Consider freeing up some in-use channels manually by " f"using the `{constants.Bot.prefix}dormant` command within the channels.", - allowed_mentions=disnake.AllowedMentions(everyone=False, roles=allowed_roles) + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) ) except Exception: # Handle it here cause this feature isn't critical for the functionality of the system. @@ -213,18 +213,18 @@ async def notify_running_low(number_of_channels_left: int, last_notification: Ar return arrow.utcnow() -async def pin(message: disnake.Message) -> None: +async def pin(message: discord.Message) -> None: """Pin an initial question `message` and store it in a cache.""" if await pin_wrapper(message.id, message.channel, pin=True): await _caches.question_messages.set(message.channel.id, message.id) -async def send_available_message(channel: disnake.TextChannel) -> None: +async def send_available_message(channel: discord.TextChannel) -> None: """Send the available message by editing a dormant message or sending a new message.""" channel_info = f"#{channel} ({channel.id})" log.trace(f"Sending available message in {channel_info}.") - embed = disnake.Embed( + embed = discord.Embed( color=constants.Colours.bright_green, description=AVAILABLE_MSG, ) @@ -240,7 +240,7 @@ async def send_available_message(channel: disnake.TextChannel) -> None: await channel.send(embed=embed) -async def unpin(channel: disnake.TextChannel) -> None: +async def unpin(channel: discord.TextChannel) -> None: """Unpin the initial question message sent in `channel`.""" msg_id = await _caches.question_messages.pop(channel.id) if msg_id is None: @@ -249,19 +249,19 @@ async def unpin(channel: disnake.TextChannel) -> None: await pin_wrapper(msg_id, channel, pin=False) -def _match_bot_embed(message: t.Optional[disnake.Message], description: str) -> bool: +def _match_bot_embed(message: t.Optional[discord.Message], description: str) -> bool: """Return `True` if the bot's `message`'s embed description matches `description`.""" if not message or not message.embeds: return False bot_msg_desc = message.embeds[0].description - if bot_msg_desc is disnake.Embed.Empty: + if bot_msg_desc is discord.Embed.Empty: log.trace("Last message was a bot embed but it was empty.") return False return message.author == bot.instance.user and bot_msg_desc.strip() == description.strip() -async def pin_wrapper(msg_id: int, channel: disnake.TextChannel, *, pin: bool) -> bool: +async def pin_wrapper(msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool: """ Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. @@ -277,7 +277,7 @@ async def pin_wrapper(msg_id: int, channel: disnake.TextChannel, *, pin: bool) - try: await func(channel.id, msg_id) - except disnake.HTTPException as e: + except discord.HTTPException as e: if e.code == 10008: log.debug(f"Message {msg_id} in {channel_str} doesn't exist; can't {verb}.") else: diff --git a/bot/exts/help_channels/_name.py b/bot/exts/help_channels/_name.py index 50b250cb5..a9d9b2df1 100644 --- a/bot/exts/help_channels/_name.py +++ b/bot/exts/help_channels/_name.py @@ -3,7 +3,7 @@ import typing as t from collections import deque from pathlib import Path -import disnake +import discord from bot import constants from bot.exts.help_channels._channel import MAX_CHANNELS_PER_CATEGORY, get_category_channels @@ -12,7 +12,7 @@ from bot.log import get_logger log = get_logger(__name__) -def create_name_queue(*categories: disnake.CategoryChannel) -> deque: +def create_name_queue(*categories: discord.CategoryChannel) -> deque: """ Return a queue of food names to use for creating new channels. @@ -50,7 +50,7 @@ def _get_names() -> t.List[str]: return all_names[:count] -def _get_used_names(*categories: disnake.CategoryChannel) -> t.Set[str]: +def _get_used_names(*categories: discord.CategoryChannel) -> t.Set[str]: """Return names which are already being used by channels in `categories`.""" log.trace("Getting channel names which are already being used.") diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py index 68eb52a59..f2f29020f 100644 --- a/bot/exts/info/code_snippets.py +++ b/bot/exts/info/code_snippets.py @@ -4,9 +4,9 @@ import textwrap from typing import Any from urllib.parse import quote_plus -import disnake +import discord from aiohttp import ClientResponseError -from disnake.ext.commands import Cog +from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels @@ -241,7 +241,7 @@ class CodeSnippets(Cog): return '\n'.join(map(lambda x: x[1], sorted(all_snippets))) @Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" if message.author.bot: return @@ -255,7 +255,7 @@ class CodeSnippets(Cog): if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: try: await message.edit(suppress=True) - except disnake.NotFound: + except discord.NotFound: # Don't send snippets if the original message was deleted. return diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index cf8c7d0be..a859d8cef 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -1,9 +1,9 @@ import time from typing import Optional -import disnake -from disnake import Message, RawMessageUpdateEvent -from disnake.ext.commands import Cog +import discord +from discord import Message, RawMessageUpdateEvent +from discord.ext.commands import Cog from bot import constants from bot.bot import Bot @@ -62,9 +62,9 @@ class CodeBlockCog(Cog, name="Code Block"): self.codeblock_message_ids = {} @staticmethod - def create_embed(instructions: str) -> disnake.Embed: + def create_embed(instructions: str) -> discord.Embed: """Return an embed which displays code block formatting `instructions`.""" - return disnake.Embed(description=instructions) + return discord.Embed(description=instructions) async def get_sent_instructions(self, payload: RawMessageUpdateEvent) -> Optional[Message]: """ @@ -78,11 +78,11 @@ class CodeBlockCog(Cog, name="Code Block"): try: return await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) - except disnake.NotFound: + except discord.NotFound: log.debug("Could not find instructions message; it was probably deleted.") return None - def is_on_cooldown(self, channel: disnake.TextChannel) -> bool: + def is_on_cooldown(self, channel: discord.TextChannel) -> bool: """ Return True if an embed was sent too recently for `channel`. @@ -93,7 +93,7 @@ class CodeBlockCog(Cog, name="Code Block"): cooldown = constants.CodeBlock.cooldown_seconds return (time.time() - self.channel_cooldowns.get(channel.id, 0)) < cooldown - def is_valid_channel(self, channel: disnake.TextChannel) -> bool: + def is_valid_channel(self, channel: discord.TextChannel) -> bool: """Return True if `channel` is a help channel, may be on a cooldown, or is whitelisted.""" log.trace(f"Checking if #{channel} qualifies for code block detection.") return ( @@ -102,7 +102,7 @@ class CodeBlockCog(Cog, name="Code Block"): or channel.id in constants.CodeBlock.channel_whitelist ) - async def send_instructions(self, message: disnake.Message, instructions: str) -> None: + async def send_instructions(self, message: discord.Message, instructions: str) -> None: """ Send an embed with `instructions` on fixing an incorrect code block in a `message`. @@ -119,7 +119,7 @@ class CodeBlockCog(Cog, name="Code Block"): # Increase amount of codeblock correction in stats self.bot.stats.incr("codeblock_corrections") - def should_parse(self, message: disnake.Message) -> bool: + def should_parse(self, message: discord.Message) -> bool: """ Return True if `message` should be parsed. @@ -185,5 +185,5 @@ class CodeBlockCog(Cog, name="Code Block"): else: log.info("Message edited but still has invalid code blocks; editing instructions.") await bot_message.edit(embed=self.create_embed(instructions)) - except disnake.NotFound: + except discord.NotFound: log.debug("Could not find instructions message; it was probably deleted.") diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py index 487a0fd21..c27f28eac 100644 --- a/bot/exts/info/doc/_batch_parser.py +++ b/bot/exts/info/doc/_batch_parser.py @@ -7,7 +7,7 @@ from contextlib import suppress from operator import attrgetter from typing import Deque, Dict, List, NamedTuple, Optional, Union -import disnake +import discord from bs4 import BeautifulSoup import bot @@ -48,7 +48,7 @@ class StaleInventoryNotifier: if await self.symbol_counter.increment_for(doc_item) < 3: self._warned_urls.add(doc_item.url) await self._init_task - embed = disnake.Embed( + embed = discord.Embed( description=f"Doc item `{doc_item.symbol_id=}` present in loaded documentation inventories " f"not found on [site]({doc_item.url}), inventories may need to be refreshed." ) diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 77fc61389..4dc5276d9 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -9,8 +9,8 @@ from types import SimpleNamespace from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp -import disnake -from disnake.ext import commands +import discord +from discord.ext import commands from bot.api import ResponseCodeError from bot.bot import Bot @@ -275,7 +275,7 @@ class DocCog(commands.Cog): return "Unable to parse the requested symbol." return markdown - async def create_symbol_embed(self, symbol_name: str) -> Optional[disnake.Embed]: + async def create_symbol_embed(self, symbol_name: str) -> Optional[discord.Embed]: """ Attempt to scrape and fetch the data for the given `symbol_name`, and build an embed from its contents. @@ -304,8 +304,8 @@ class DocCog(commands.Cog): else: footer_text = "" - embed = disnake.Embed( - title=disnake.utils.escape_markdown(symbol_name), + embed = discord.Embed( + title=discord.utils.escape_markdown(symbol_name), url=f"{doc_item.url}#{doc_item.symbol_id}", description=await self.get_symbol_markdown(doc_item) ) @@ -331,9 +331,9 @@ class DocCog(commands.Cog): !docs getdoc aiohttp.ClientSession """ if not symbol_name: - inventory_embed = disnake.Embed( + inventory_embed = discord.Embed( title=f"All inventories (`{len(self.base_urls)}` total)", - colour=disnake.Colour.blue() + colour=discord.Colour.blue() ) lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) @@ -355,7 +355,7 @@ class DocCog(commands.Cog): # Make sure that we won't cause a ghost-ping by deleting the message if not (ctx.message.mentions or ctx.message.role_mentions): - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await ctx.message.delete() await error_message.delete() @@ -449,7 +449,7 @@ class DocCog(commands.Cog): if removed := ", ".join(old_inventories - new_inventories): removed = "- " + removed - embed = disnake.Embed( + embed = discord.Embed( title="Inventories refreshed", description=f"```diff\n{added}\n{removed}```" if added or removed else "" ) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py index 29d73c564..864e7edd2 100644 --- a/bot/exts/info/help.py +++ b/bot/exts/info/help.py @@ -6,8 +6,8 @@ from collections import namedtuple from contextlib import suppress from typing import List, Optional, Union -from disnake import ButtonStyle, Colour, Embed, Emoji, Interaction, PartialEmoji, ui -from disnake.ext.commands import Bot, Cog, Command, CommandError, Context, DisabledCommand, Group, HelpCommand +from discord import ButtonStyle, Colour, Embed, Emoji, Interaction, PartialEmoji, ui +from discord.ext.commands import Bot, Cog, Command, CommandError, Context, DisabledCommand, Group, HelpCommand from rapidfuzz import fuzz, process from rapidfuzz.utils import default_process diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 44a9b8f1a..e616b9208 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,9 +6,9 @@ from textwrap import shorten from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union import rapidfuzz -from disnake import AllowedMentions, Colour, Embed, Guild, Message, Role -from disnake.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role -from disnake.utils import escape_markdown +from discord import AllowedMentions, Colour, Embed, Guild, Message, Role +from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role +from discord.utils import escape_markdown from bot import constants from bot.api import ResponseCodeError @@ -466,7 +466,7 @@ class Information(Cog): async def send_raw_content(self, ctx: Context, message: Message, json: bool = False) -> None: """ - Send information about the raw API response for a `disnake.Message`. + Send information about the raw API response for a `discord.Message`. If `json` is True, send the information in a copy-pasteable Python format. """ diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 08c693581..67866620b 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -3,8 +3,8 @@ from email.parser import HeaderParser from io import StringIO from typing import Dict, Optional, Tuple -from disnake import Colour, Embed -from disnake.ext.commands import Cog, Context, command +from discord import Colour, Embed +from discord.ext.commands import Cog, Context, command from bot.bot import Bot from bot.constants import Keys diff --git a/bot/exts/info/pypi.py b/bot/exts/info/pypi.py index 0a7705eb0..dacf7bc12 100644 --- a/bot/exts/info/pypi.py +++ b/bot/exts/info/pypi.py @@ -3,9 +3,9 @@ import random import re from contextlib import suppress -from disnake import Embed, NotFound -from disnake.ext.commands import Cog, Context, command -from disnake.utils import escape_markdown +from discord import Embed, NotFound +from discord.ext.commands import Cog, Context, command +from discord.utils import escape_markdown from bot.bot import Bot from bot.constants import Colours, NEGATIVE_REPLIES, RedirectOutput diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py index 7603b402b..2fad9d2ab 100644 --- a/bot/exts/info/python_news.py +++ b/bot/exts/info/python_news.py @@ -2,11 +2,11 @@ import re import typing as t from datetime import date, datetime -import disnake +import discord import feedparser from bs4 import BeautifulSoup -from disnake.ext.commands import Cog -from disnake.ext.tasks import loop +from discord.ext.commands import Cog +from discord.ext.tasks import loop from bot import constants from bot.bot import Bot @@ -40,7 +40,7 @@ class PythonNews(Cog): def __init__(self, bot: Bot): self.bot = bot self.webhook_names = {} - self.webhook: t.Optional[disnake.Webhook] = None + self.webhook: t.Optional[discord.Webhook] = None scheduling.create_task(self.get_webhook_names(), event_loop=self.bot.loop) scheduling.create_task(self.get_webhook_and_channel(), event_loop=self.bot.loop) @@ -119,7 +119,7 @@ class PythonNews(Cog): continue # Build an embed and send a webhook - embed = disnake.Embed( + embed = discord.Embed( title=self.escape_markdown(new["title"]), description=self.escape_markdown(new["summary"]), timestamp=new_datetime, @@ -189,7 +189,7 @@ class PythonNews(Cog): link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist) # Build an embed and send a message to the webhook - embed = disnake.Embed( + embed = discord.Embed( title=self.escape_markdown(thread_information["subject"]), description=content[:1000] + f"... [continue reading]({link})" if len(content) > 1000 else content, timestamp=new_date, diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 6305a9842..e3e7029ca 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -2,8 +2,8 @@ import inspect from pathlib import Path from typing import Optional, Tuple, Union -from disnake import Embed -from disnake.ext import commands +from discord import Embed +from discord.ext import commands from bot.bot import Bot from bot.constants import URLs diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py index 08422b38e..4d8bb645e 100644 --- a/bot/exts/info/stats.py +++ b/bot/exts/info/stats.py @@ -1,8 +1,8 @@ import string -from disnake import Member, Message -from disnake.ext.commands import Cog, Context -from disnake.ext.tasks import loop +from discord import Member, Message +from discord.ext.commands import Cog, Context +from discord.ext.tasks import loop from bot.bot import Bot from bot.constants import Categories, Channels, Guild diff --git a/bot/exts/info/subscribe.py b/bot/exts/info/subscribe.py index 0f285e0cb..eff0c13b8 100644 --- a/bot/exts/info/subscribe.py +++ b/bot/exts/info/subscribe.py @@ -4,9 +4,9 @@ import typing as t from dataclasses import dataclass import arrow -import disnake -from disnake.ext import commands -from disnake.interactions import Interaction +import discord +from discord.ext import commands +from discord.interactions import Interaction from bot import constants from bot.bot import Bot @@ -58,10 +58,10 @@ DELETE_MESSAGE_AFTER = 300 # Seconds log = get_logger(__name__) -class RoleButtonView(disnake.ui.View): +class RoleButtonView(discord.ui.View): """A list of SingleRoleButtons to show to the member.""" - def __init__(self, member: disnake.Member): + def __init__(self, member: discord.Member): super().__init__() self.interaction_owner = member @@ -76,12 +76,12 @@ class RoleButtonView(disnake.ui.View): return True -class SingleRoleButton(disnake.ui.Button): +class SingleRoleButton(discord.ui.Button): """A button that adds or removes a role from the member depending on it's current state.""" - ADD_STYLE = disnake.ButtonStyle.success - REMOVE_STYLE = disnake.ButtonStyle.red - UNAVAILABLE_STYLE = disnake.ButtonStyle.secondary + ADD_STYLE = discord.ButtonStyle.success + REMOVE_STYLE = discord.ButtonStyle.red + UNAVAILABLE_STYLE = discord.ButtonStyle.secondary LABEL_FORMAT = "{action} role {role_name}." CUSTOM_ID_FORMAT = "subscribe-{role_id}" @@ -104,7 +104,7 @@ class SingleRoleButton(disnake.ui.Button): async def callback(self, interaction: Interaction) -> None: """Update the member's role and change button text to reflect current text.""" - if isinstance(interaction.user, disnake.User): + if isinstance(interaction.user, discord.User): log.trace("User %s is not a member", interaction.user) await interaction.message.delete() self.view.stop() @@ -117,7 +117,7 @@ class SingleRoleButton(disnake.ui.Button): await members.handle_role_change( interaction.user, interaction.user.remove_roles if self.assigned else interaction.user.add_roles, - disnake.Object(self.role.role_id), + discord.Object(self.role.role_id), ) self.assigned = not self.assigned @@ -133,7 +133,7 @@ class SingleRoleButton(disnake.ui.Button): self.label = self.LABEL_FORMAT.format(action="Remove" if self.assigned else "Add", role_name=self.role.name) try: await interaction.message.edit(view=self.view) - except disnake.NotFound: + except discord.NotFound: log.debug("Subscribe message for %s removed before buttons could be updated", interaction.user) self.view.stop() @@ -145,7 +145,7 @@ class Subscribe(commands.Cog): self.bot = bot self.init_task = scheduling.create_task(self.init_cog(), event_loop=self.bot.loop) self.assignable_roles: list[AssignableRole] = [] - self.guild: disnake.Guild = None + self.guild: discord.Guild = None async def init_cog(self) -> None: """Initialise the cog by resolving the role IDs in ASSIGNABLE_ROLES to role names.""" diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index baeb21adb..f66237c8e 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -6,10 +6,10 @@ import time from pathlib import Path from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union -import disnake +import discord import frontmatter -from disnake import Embed, Member -from disnake.ext.commands import Cog, Context, group +from discord import Embed, Member +from discord.ext.commands import Cog, Context, group from bot import constants from bot.bot import Bot @@ -81,7 +81,7 @@ class Tag: self.content = post.content self.metadata = post.metadata self._restricted_to: set[int] = set(self.metadata.get("restricted_to", ())) - self._cooldowns: dict[disnake.TextChannel, float] = {} + self._cooldowns: dict[discord.TextChannel, float] = {} @property def embed(self) -> Embed: @@ -90,18 +90,18 @@ class Tag: embed.description = self.content return embed - def accessible_by(self, member: disnake.Member) -> bool: + def accessible_by(self, member: discord.Member) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to or self._restricted_to & {role.id for role in member.roles} ) - def on_cooldown_in(self, channel: disnake.TextChannel) -> bool: + def on_cooldown_in(self, channel: discord.TextChannel) -> bool: """Check whether the tag is on cooldown in `channel`.""" return self._cooldowns.get(channel, float("-inf")) > time.time() - def set_cooldown_for(self, channel: disnake.TextChannel) -> None: + def set_cooldown_for(self, channel: discord.TextChannel) -> None: """Set the tag to be on cooldown in `channel` for `constants.Cooldowns.tags` seconds.""" self._cooldowns[channel] = time.time() + constants.Cooldowns.tags @@ -344,7 +344,7 @@ class Tags(Cog): return result_lines - def accessible_tags_in_group(self, group: str, user: disnake.Member) -> list[str]: + def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: """Return a formatted list of tags in `group`, that are accessible by `user`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index 2e274b23b..cb6836258 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -7,10 +7,10 @@ from datetime import datetime from itertools import takewhile from typing import Callable, Iterable, Literal, Optional, TYPE_CHECKING, Union -from disnake import Colour, Message, NotFound, TextChannel, User, errors -from disnake.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role -from disnake.ext.commands.converter import TextChannelConverter -from disnake.ext.commands.errors import BadArgument +from discord import Colour, Message, NotFound, TextChannel, User, errors +from discord.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role +from discord.ext.commands.converter import TextChannelConverter +from discord.ext.commands.errors import BadArgument from bot.bot import Bot from bot.constants import Channels, CleanMessages, Colours, Emojis, Event, Icons, MODERATION_ROLES @@ -459,7 +459,7 @@ class Clean(Cog): regex: Optional[Regex] = None, bots_only: Optional[bool] = False, *, - channels: CleanChannels = None # "Optional" with disnake silently ignores incorrect input. + channels: CleanChannels = None # "Optional" with discord.py silently ignores incorrect input. ) -> None: """ Commands for cleaning messages in channels. diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 58e049d4f..178be734d 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -8,9 +8,9 @@ import arrow from aioredis import RedisError from async_rediscache import RedisCache from dateutil.relativedelta import relativedelta -from disnake import Colour, Embed, Forbidden, Member, TextChannel, User -from disnake.ext import tasks -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord import Colour, Embed, Forbidden, Member, TextChannel, User +from discord.ext import tasks +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_ROLES, Roles diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 28e131eb4..566422e29 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -1,5 +1,5 @@ -import disnake -from disnake.ext.commands import Cog, Context, command, has_any_role +import discord +from discord.ext.commands import Cog, Context, command, has_any_role from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES @@ -17,7 +17,7 @@ class DMRelay(Cog): self.bot = bot @command(aliases=("relay", "dr")) - async def dmrelay(self, ctx: Context, user: disnake.User, limit: int = 100) -> None: + async def dmrelay(self, ctx: Context, user: discord.User, limit: int = 100) -> None: """Relays the direct message history between the bot and given user.""" log.trace(f"Relaying DMs with {user.name} ({user.id})") diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index c4c03e546..b579416a6 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -4,9 +4,9 @@ from datetime import datetime from enum import Enum from typing import Optional -import disnake +import discord from async_rediscache import RedisCache -from disnake.ext.commands import Cog, Context, MessageConverter, MessageNotFound +from discord.ext.commands import Cog, Context, MessageConverter, MessageNotFound from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks @@ -52,10 +52,10 @@ ALL_SIGNALS: set[str] = {signal.value for signal in Signal} # An embed coupled with an optional file to be dispatched # If the file is not None, the embed attempts to show it in its body -FileEmbed = tuple[disnake.Embed, Optional[disnake.File]] +FileEmbed = tuple[discord.Embed, Optional[discord.File]] -async def download_file(attachment: disnake.Attachment) -> Optional[disnake.File]: +async def download_file(attachment: discord.Attachment) -> Optional[discord.File]: """ Download & return `attachment` file. @@ -65,13 +65,13 @@ async def download_file(attachment: disnake.Attachment) -> Optional[disnake.File log.debug(f"Attempting to download attachment: {attachment.filename}") try: return await attachment.to_file() - except (disnake.NotFound, disnake.Forbidden) as exc: + except (discord.NotFound, discord.Forbidden) as exc: log.debug(f"Failed to download attachment: {exc}") except Exception: log.exception("Failed to download attachment") -async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: disnake.Member) -> FileEmbed: +async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> FileEmbed: """ Create an embed representation of `incident` for the #incidents-archive channel. @@ -97,7 +97,7 @@ async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: di colour = Colours.soft_red footer = f"Rejected by {actioned_by}" - embed = disnake.Embed( + embed = discord.Embed( description=incident.content, timestamp=datetime.utcnow(), colour=colour, @@ -113,12 +113,12 @@ async def make_embed(incident: disnake.Message, outcome: Signal, actioned_by: di else: embed.set_author(name="[Failed to relay attachment]", url=attachment.proxy_url) # Embed links the file else: - file = disnake.utils.MISSING + file = discord.utils.MISSING return embed, file -def is_incident(message: disnake.Message) -> bool: +def is_incident(message: discord.Message) -> bool: """True if `message` qualifies as an incident, False otherwise.""" conditions = ( message.channel.id == Channels.incidents, # Message sent in #incidents @@ -129,12 +129,12 @@ def is_incident(message: disnake.Message) -> bool: return all(conditions) -def own_reactions(message: disnake.Message) -> set[str]: +def own_reactions(message: discord.Message) -> set[str]: """Get the set of reactions placed on `message` by the bot itself.""" return {str(reaction.emoji) for reaction in message.reactions if reaction.me} -def has_signals(message: disnake.Message) -> bool: +def has_signals(message: discord.Message) -> bool: """True if `message` already has all `Signal` reactions, False otherwise.""" return ALL_SIGNALS.issubset(own_reactions(message)) @@ -167,9 +167,9 @@ def shorten_text(text: str) -> str: return text -async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[disnake.Embed]: +async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[discord.Embed]: """ - Create an embedded representation of the Discord message link contained in the incident report. + Create an embedded representation of the discord message link contained in the incident report. The Embed would contain the following information --> Author: @Jason Terror ♦ (736234578745884682) @@ -179,23 +179,23 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d embed = None try: - message: disnake.Message = await MessageConverter().convert(ctx, message_link) + message: discord.Message = await MessageConverter().convert(ctx, message_link) except MessageNotFound: mod_logs_channel = ctx.bot.get_channel(Channels.mod_log) - last_100_logs: list[disnake.Message] = await mod_logs_channel.history(limit=100).flatten() + last_100_logs: list[discord.Message] = await mod_logs_channel.history(limit=100).flatten() for log_entry in last_100_logs: if not log_entry.embeds: continue - log_embed: disnake.Embed = log_entry.embeds[0] + log_embed: discord.Embed = log_entry.embeds[0] if ( log_embed.author.name == "Message deleted" and f"[Jump to message]({message_link})" in log_embed.description ): - embed = disnake.Embed( - colour=disnake.Colour.dark_gold(), + embed = discord.Embed( + colour=discord.Colour.dark_gold(), title="Deleted Message Link", description=( f"Found <#{Channels.mod_log}> entry for deleted message: " @@ -203,12 +203,12 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d ) ) if not embed: - embed = disnake.Embed( - colour=disnake.Colour.red(), + embed = discord.Embed( + colour=discord.Colour.red(), title="Bad Message Link", description=f"Message {message_link} not found." ) - except disnake.DiscordException as e: + except discord.DiscordException as e: log.exception(f"Failed to make message link embed for '{message_link}', raised exception: {e}") else: channel = message.channel @@ -219,12 +219,12 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d ) return - embed = disnake.Embed( - colour=disnake.Colour.gold(), + embed = discord.Embed( + colour=discord.Colour.gold(), description=( f"**Author:** {format_user(message.author)}\n" f"**Channel:** {channel.mention} ({channel.category}" - f"{f'/#{channel.parent.name} - ' if isinstance(channel, disnake.Thread) else '/#'}" + f"{f'/#{channel.parent.name} - ' if isinstance(channel, discord.Thread) else '/#'}" f"{channel.name})\n" ), timestamp=message.created_at @@ -242,7 +242,7 @@ async def make_message_link_embed(ctx: Context, message_link: str) -> Optional[d return embed -async def add_signals(incident: disnake.Message) -> None: +async def add_signals(incident: discord.Message) -> None: """ Add `Signal` member emoji to `incident` as reactions. @@ -257,7 +257,7 @@ async def add_signals(incident: disnake.Message) -> None: log.trace(f"Adding reaction: {signal_emoji}") try: await incident.add_reaction(signal_emoji.value) - except disnake.NotFound as e: + except discord.NotFound as e: if e.code != 10008: raise @@ -300,7 +300,7 @@ class Incidents(Cog): """ # This dictionary maps an incident report message to the message link embed's ID - # RedisCache[disnake.Message.id, disnake.Message.id] + # RedisCache[discord.Message.id, discord.Message.id] message_link_embeds_cache = RedisCache() def __init__(self, bot: Bot) -> None: @@ -319,7 +319,7 @@ class Incidents(Cog): try: self.incidents_webhook = await self.bot.fetch_webhook(Webhooks.incidents) - except disnake.HTTPException: + except discord.HTTPException: log.error(f"Failed to fetch incidents webhook with id `{Webhooks.incidents}`.") async def crawl_incidents(self) -> None: @@ -335,7 +335,7 @@ class Incidents(Cog): Behaviour is configured by: `CRAWL_LIMIT`, `CRAWL_SLEEP`. """ await self.bot.wait_until_guild_available() - incidents: disnake.TextChannel = self.bot.get_channel(Channels.incidents) + incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents) log.debug(f"Crawling messages in #incidents: {CRAWL_LIMIT=}, {CRAWL_SLEEP=}") async for message in incidents.history(limit=CRAWL_LIMIT): @@ -353,7 +353,7 @@ class Incidents(Cog): log.debug("Crawl task finished!") - async def archive(self, incident: disnake.Message, outcome: Signal, actioned_by: disnake.Member) -> bool: + async def archive(self, incident: discord.Message, outcome: Signal, actioned_by: discord.Member) -> bool: """ Relay an embed representation of `incident` to the #incidents-archive channel. @@ -392,7 +392,7 @@ class Incidents(Cog): log.trace("Message archived successfully!") return True - def make_confirmation_task(self, incident: disnake.Message, timeout: int = 5) -> asyncio.Task: + def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task: """ Create a task to wait `timeout` seconds for `incident` to be deleted. @@ -401,13 +401,13 @@ class Incidents(Cog): """ log.trace(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted") - def check(payload: disnake.RawReactionActionEvent) -> bool: + def check(payload: discord.RawReactionActionEvent) -> bool: return payload.message_id == incident.id coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout) return scheduling.create_task(coroutine, event_loop=self.bot.loop) - async def process_event(self, reaction: str, incident: disnake.Message, member: disnake.Member) -> None: + async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None: """ Process a `reaction_add` event in #incidents. @@ -430,7 +430,7 @@ class Incidents(Cog): log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") try: await incident.remove_reaction(reaction, member) - except disnake.NotFound: + except discord.NotFound: log.trace("Couldn't remove reaction because the reaction or its message was deleted") return @@ -440,7 +440,7 @@ class Incidents(Cog): log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") try: await incident.remove_reaction(reaction, member) - except disnake.NotFound: + except discord.NotFound: log.trace("Couldn't remove reaction because the reaction or its message was deleted") return @@ -461,7 +461,7 @@ class Incidents(Cog): log.trace("Deleting original message") try: await incident.delete() - except disnake.NotFound: + except discord.NotFound: log.trace("Couldn't delete message because it was already deleted") log.trace(f"Awaiting deletion confirmation: {timeout=} seconds") @@ -476,9 +476,9 @@ class Incidents(Cog): # Deletes the message link embeds found in cache from the channel and cache. await self.delete_msg_link_embed(incident.id) - async def resolve_message(self, message_id: int) -> Optional[disnake.Message]: + async def resolve_message(self, message_id: int) -> Optional[discord.Message]: """ - Get `disnake.Message` for `message_id` from cache, or API. + Get `discord.Message` for `message_id` from cache, or API. We first look into the local cache to see if the message is present. @@ -491,7 +491,7 @@ class Incidents(Cog): """ await self.bot.wait_until_guild_available() # First make sure that the cache is ready log.trace(f"Resolving message for: {message_id=}") - message: Optional[disnake.Message] = self.bot._connection._get_message(message_id) + message: Optional[discord.Message] = self.bot._connection._get_message(message_id) if message is not None: log.trace("Message was found in cache") @@ -500,7 +500,7 @@ class Incidents(Cog): log.trace("Message not found, attempting to fetch") try: message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id) - except disnake.NotFound: + except discord.NotFound: log.trace("Message doesn't exist, it was likely already relayed") except Exception: log.exception(f"Failed to fetch message {message_id}!") @@ -509,7 +509,7 @@ class Incidents(Cog): return message @Cog.listener() - async def on_raw_reaction_add(self, payload: disnake.RawReactionActionEvent) -> None: + async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: """ Pre-process `payload` and pass it to `process_event` if appropriate. @@ -521,11 +521,11 @@ class Incidents(Cog): Next, we acquire `event_lock` - to prevent racing, events are processed one at a time. - Once we have the lock, the `disnake.Message` object for this event must be resolved. + Once we have the lock, the `discord.Message` object for this event must be resolved. If the lock was previously held by an event which successfully relayed the incident, this will fail and we abort the current event. - Finally, with both the lock and the `disnake.Message` instance in our hands, we delegate + Finally, with both the lock and the `discord.Message` instance in our hands, we delegate to `process_event` to handle the event. The justification for using a raw listener is the need to receive events for messages @@ -554,7 +554,7 @@ class Incidents(Cog): log.trace("Releasing event lock") @Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """ Pass `message` to `add_signals` and `extract_message_links` if it satisfies `is_incident`. @@ -575,7 +575,7 @@ class Incidents(Cog): await self.send_message_link_embeds(embed_list, message, self.incidents_webhook) @Cog.listener() - async def on_raw_message_delete(self, payload: disnake.RawMessageDeleteEvent) -> None: + async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent) -> None: """ Delete message link embeds for `payload.message_id`. @@ -584,7 +584,7 @@ class Incidents(Cog): if self.incidents_webhook: await self.delete_msg_link_embed(payload.message_id) - async def extract_message_links(self, message: disnake.Message) -> Optional[list[disnake.Embed]]: + async def extract_message_links(self, message: discord.Message) -> Optional[list[discord.Embed]]: """ Check if there's any message links in the text content. @@ -615,8 +615,8 @@ class Incidents(Cog): async def send_message_link_embeds( self, webhook_embed_list: list, - message: disnake.Message, - webhook: disnake.Webhook, + message: discord.Message, + webhook: discord.Webhook, ) -> Optional[int]: """ Send message link embeds to #incidents channel. @@ -634,7 +634,7 @@ class Incidents(Cog): avatar_url=message.author.display_avatar.url, wait=True, ) - except disnake.DiscordException: + except discord.DiscordException: log.exception( f"Failed to send message link embed {message.id} to #incidents." ) @@ -651,7 +651,7 @@ class Incidents(Cog): if webhook_msg_id: try: await self.incidents_webhook.delete_message(webhook_msg_id) - except disnake.errors.NotFound: + except discord.errors.NotFound: log.trace(f"Incidents message link embed (`{webhook_msg_id}`) has already been deleted, skipping.") await self.message_link_embeds_cache.delete(message_id) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 8107b502a..2fc54856f 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -5,8 +5,8 @@ from gettext import ngettext import arrow import dateutil.parser -import disnake -from disnake.ext.commands import Context +import discord +from discord.ext.commands import Context from bot import constants from bot.api import ResponseCodeError @@ -101,7 +101,7 @@ class InfractionScheduler: # Allowing mod log since this is a passive action that should be logged. try: await apply_coro - except disnake.HTTPException as e: + except discord.HTTPException as e: # When user joined and then right after this left again before action completed, this can't apply roles if e.code == 10007 or e.status == 404: log.info( @@ -200,7 +200,7 @@ class InfractionScheduler: if expiry: # Schedule the expiration of the infraction. self.schedule_expiration(infraction) - except disnake.HTTPException as e: + except discord.HTTPException as e: # Accordingly display that applying the infraction failed. # Don't use ctx.message.author; antispam only patches ctx.author. confirm_msg = ":x: failed to apply" @@ -209,7 +209,7 @@ class InfractionScheduler: log_title = "failed to apply" log_msg = f"Failed to apply {' '.join(infr_type.split('_'))} infraction #{id_} to {user}" - if isinstance(e, disnake.Forbidden): + if isinstance(e, discord.Forbidden): log.warning(f"{log_msg}: bot lacks permissions.") elif e.code == 10007 or e.status == 404: log.info( @@ -396,11 +396,11 @@ class InfractionScheduler: raise ValueError( f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) - except disnake.Forbidden: + except discord.Forbidden: log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") log_text["Failure"] = "The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention - except disnake.HTTPException as e: + except discord.HTTPException as e: if e.code == 10007 or e.status == 404: log.info( f"Can't pardon {infraction['type']} for user {infraction['user']} because user left the guild." diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 36e818ec6..c1be18362 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -2,8 +2,8 @@ import typing as t from datetime import datetime import arrow -import disnake -from disnake.ext.commands import Context +import discord +from discord.ext.commands import Context import bot from bot.api import ResponseCodeError @@ -86,7 +86,7 @@ async def post_infraction( dm_sent: bool = False, ) -> t.Optional[dict]: """Posts an infraction to the API.""" - if isinstance(user, (disnake.Member, disnake.User)) and user.bot: + if isinstance(user, (discord.Member, discord.User)) and user.bot: log.trace(f"Posting of {infr_type} infraction for {user} to the API aborted. User is a bot.") raise InvalidInfractedUserError(user) @@ -209,7 +209,7 @@ async def notify_infraction( text += INFRACTION_APPEAL_SERVER_FOOTER if infraction["type"] == 'ban' else INFRACTION_APPEAL_MODMAIL_FOOTER - embed = disnake.Embed( + embed = discord.Embed( description=text, colour=Colours.soft_red ) @@ -238,7 +238,7 @@ async def notify_pardon( """DM a user about their pardoned infraction and return True if the DM is successful.""" log.trace(f"Sending {user} a DM about their pardoned infraction.") - embed = disnake.Embed( + embed = discord.Embed( description=content, colour=Colours.soft_green ) @@ -248,7 +248,7 @@ async def notify_pardon( return await send_private_embed(user, embed) -async def send_private_embed(user: MemberOrUser, embed: disnake.Embed) -> bool: +async def send_private_embed(user: MemberOrUser, embed: discord.Embed) -> bool: """ A helper method for sending an embed to a user's DMs. @@ -257,7 +257,7 @@ async def send_private_embed(user: MemberOrUser, embed: disnake.Embed) -> bool: try: await user.send(embed=embed) return True - except (disnake.HTTPException, disnake.Forbidden, disnake.NotFound): + except (discord.HTTPException, discord.Forbidden, discord.NotFound): log.debug( f"Infraction-related information could not be sent to user {user} ({user.id}). " "The user either could not be retrieved or probably disabled their DMs." diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 5ff56abde..af42ab1b8 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -1,10 +1,10 @@ import textwrap import typing as t -import disnake -from disnake import Member -from disnake.ext import commands -from disnake.ext.commands import Context, command +import discord +from discord import Member +from discord.ext import commands +from discord.ext.commands import Context, command from bot import constants from bot.bot import Bot @@ -35,8 +35,8 @@ class Infractions(InfractionScheduler, commands.Cog): super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning", "voice_mute"}) self.category = "Moderation" - self._muted_role = disnake.Object(constants.Roles.muted) - self._voice_verified_role = disnake.Object(constants.Roles.voice_verified) + self._muted_role = discord.Object(constants.Roles.muted) + self._voice_verified_role = discord.Object(constants.Roles.voice_verified) @commands.Cog.listener() async def on_member_join(self, member: Member) -> None: @@ -123,7 +123,7 @@ class Infractions(InfractionScheduler, commands.Cog): log.error("Failed to apply ban to user %d", user.id) return - # Calling commands directly skips disnake's convertors, so we need to convert args manually. + # Calling commands directly skips Discord.py's convertors, so we need to convert args manually. clean_time = await Age().convert(ctx, "1h") log_url = await clean_cog._clean_messages( @@ -494,7 +494,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def pardon_mute( self, user_id: int, - guild: disnake.Guild, + guild: discord.Guild, reason: t.Optional[str], *, notify: bool = True @@ -525,16 +525,16 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def pardon_ban(self, user_id: int, guild: disnake.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + async def pardon_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: """Remove a user's ban on the Discord guild and return a log dict.""" - user = disnake.Object(user_id) + user = discord.Object(user_id) log_text = {} self.mod_log.ignore(Event.member_unban, user_id) try: await guild.unban(user, reason=reason) - except disnake.NotFound: + except discord.NotFound: log.info(f"Failed to unban user {user_id}: no active ban found on Discord") log_text["Note"] = "No active ban found on Discord." @@ -543,7 +543,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def pardon_voice_mute( self, user_id: int, - guild: disnake.Guild, + guild: discord.Guild, *, notify: bool = True ) -> t.Dict[str, str]: @@ -597,7 +597,7 @@ class Infractions(InfractionScheduler, commands.Cog): async def cog_command_error(self, ctx: Context, error: Exception) -> None: """Send a notification to the invoking context on a Union failure.""" if isinstance(error, commands.BadUnionArgument): - if disnake.User in error.converters or Member in error.converters: + if discord.User in error.converters or Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index 25420cd7a..c12dff928 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -1,10 +1,10 @@ import textwrap import typing as t -import disnake -from disnake.ext import commands -from disnake.ext.commands import Context -from disnake.utils import escape_markdown +import discord +from discord.ext import commands +from discord.ext.commands import Context +from discord.utils import escape_markdown from bot import constants from bot.bot import Bot @@ -52,9 +52,9 @@ class ModManagement(commands.Cog): await ctx.send_help(ctx.command) return - embed = disnake.Embed( + embed = discord.Embed( title=f"Infraction #{infraction['id']}", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, [infraction]) @@ -222,7 +222,7 @@ class ModManagement(commands.Cog): await self.mod_log.send_log_message( icon_url=constants.Icons.pencil, - colour=disnake.Colour.og_blurple(), + colour=discord.Colour.og_blurple(), title="Infraction edited", thumbnail=thumbnail, text=textwrap.dedent(f""" @@ -240,21 +240,21 @@ class ModManagement(commands.Cog): async def infraction_search_group(self, ctx: Context, query: t.Union[UnambiguousUser, Snowflake, str]) -> None: """Searches for infractions in the database.""" if isinstance(query, int): - await self.search_user(ctx, disnake.Object(query)) + await self.search_user(ctx, discord.Object(query)) elif isinstance(query, str): await self.search_reason(ctx, query) else: await self.search_user(ctx, query) @infraction_search_group.command(name="user", aliases=("member", "userid")) - async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, disnake.Object]) -> None: + async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, discord.Object]) -> None: """Search for infractions by member.""" infraction_list = await self.bot.api_client.get( 'bot/infractions/expanded', params={'user__id': str(user.id)} ) - if isinstance(user, (disnake.Member, disnake.User)): + if isinstance(user, (discord.Member, discord.User)): user_str = escape_markdown(str(user)) else: if infraction_list: @@ -264,9 +264,9 @@ class ModManagement(commands.Cog): user_str = str(user.id) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = disnake.Embed( + embed = discord.Embed( title=f"Infractions for {user_str} ({formatted_infraction_count} total)", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -279,9 +279,9 @@ class ModManagement(commands.Cog): ) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = disnake.Embed( + embed = discord.Embed( title=f"Infractions matching `{reason}` ({formatted_infraction_count} total)", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -319,9 +319,9 @@ class ModManagement(commands.Cog): ) formatted_infraction_count = self.format_infraction_count(len(infraction_list)) - embed = disnake.Embed( + embed = discord.Embed( title=f"Infractions by {actor} ({formatted_infraction_count} total)", - colour=disnake.Colour.orange() + colour=discord.Colour.orange() ) await self.send_infraction_list(ctx, embed, infraction_list) @@ -344,7 +344,7 @@ class ModManagement(commands.Cog): async def send_infraction_list( self, ctx: Context, - embed: disnake.Embed, + embed: discord.Embed, infractions: t.Iterable[t.Dict[str, t.Any]] ) -> None: """Send a paginated embed of infractions for the specified user.""" @@ -433,7 +433,7 @@ class ModManagement(commands.Cog): async def cog_command_error(self, ctx: Context, error: commands.CommandError) -> None: """Handles errors for commands within this cog.""" if isinstance(error, commands.BadUnionArgument): - if disnake.User in error.converters: + if discord.User in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 41ba52580..b91a5edba 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -4,9 +4,9 @@ import textwrap import typing as t from pathlib import Path -from disnake import Embed, Member -from disnake.ext.commands import Cog, Context, command, has_any_role -from disnake.utils import escape_markdown +from discord import Embed, Member +from discord.ext.commands import Cog, Context, command, has_any_role +from discord.utils import escape_markdown from bot import constants from bot.bot import Bot diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index 482d49b83..ce9c220b3 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -8,7 +8,7 @@ import arrow from aiohttp.client_exceptions import ClientResponseError from arrow import Arrow from async_rediscache import RedisCache -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Metabase as MetabaseConfig, Roles diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index a96638e53..32ea0dc6a 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -5,13 +5,13 @@ import typing as t from datetime import datetime, timezone from itertools import zip_longest -import disnake +import discord from dateutil.relativedelta import relativedelta from deepdiff import DeepDiff -from disnake import Colour, Message, Thread -from disnake.abc import GuildChannel -from disnake.ext.commands import Cog, Context -from disnake.utils import escape_markdown +from discord import Colour, Message, Thread +from discord.abc import GuildChannel +from discord.ext.commands import Cog, Context +from discord.utils import escape_markdown from bot.bot import Bot from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs @@ -21,7 +21,7 @@ from bot.utils.messages import format_user log = get_logger(__name__) -GUILD_CHANNEL = t.Union[disnake.CategoryChannel, disnake.TextChannel, disnake.VoiceChannel] +GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel] CHANNEL_CHANGES_UNSUPPORTED = ("permissions",) CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position") @@ -45,7 +45,7 @@ class ModLog(Cog, name="ModLog"): async def upload_log( self, - messages: t.Iterable[disnake.Message], + messages: t.Iterable[discord.Message], actor_id: int, attachments: t.Iterable[t.List[str]] = None ) -> str: @@ -83,22 +83,22 @@ class ModLog(Cog, name="ModLog"): async def send_log_message( self, icon_url: t.Optional[str], - colour: t.Union[disnake.Colour, int], + colour: t.Union[discord.Colour, int], title: t.Optional[str], text: str, - thumbnail: t.Optional[t.Union[str, disnake.Asset]] = None, + thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, channel_id: int = Channels.mod_log, ping_everyone: bool = False, - files: t.Optional[t.List[disnake.File]] = None, + files: t.Optional[t.List[discord.File]] = None, content: t.Optional[str] = None, - additional_embeds: t.Optional[t.List[disnake.Embed]] = None, + additional_embeds: t.Optional[t.List[discord.Embed]] = None, timestamp_override: t.Optional[datetime] = None, footer: t.Optional[str] = None, ) -> Context: """Generate log embed and send to logging channel.""" await self.bot.wait_until_guild_available() # Truncate string directly here to avoid removing newlines - embed = disnake.Embed( + embed = discord.Embed( description=text[:4093] + "..." if len(text) > 4096 else text ) @@ -143,10 +143,10 @@ class ModLog(Cog, name="ModLog"): if channel.guild.id != GuildConstant.id: return - if isinstance(channel, disnake.CategoryChannel): + if isinstance(channel, discord.CategoryChannel): title = "Category created" message = f"{channel.name} (`{channel.id}`)" - elif isinstance(channel, disnake.VoiceChannel): + elif isinstance(channel, discord.VoiceChannel): title = "Voice channel created" if channel.category: @@ -169,14 +169,14 @@ class ModLog(Cog, name="ModLog"): if channel.guild.id != GuildConstant.id: return - if isinstance(channel, disnake.CategoryChannel): + if isinstance(channel, discord.CategoryChannel): title = "Category deleted" - elif isinstance(channel, disnake.VoiceChannel): + elif isinstance(channel, discord.VoiceChannel): title = "Voice channel deleted" else: title = "Text channel deleted" - if channel.category and not isinstance(channel, disnake.CategoryChannel): + if channel.category and not isinstance(channel, discord.CategoryChannel): message = f"{channel.category}/{channel.name} (`{channel.id}`)" else: message = f"{channel.name} (`{channel.id}`)" @@ -256,7 +256,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_create(self, role: disnake.Role) -> None: + async def on_guild_role_create(self, role: discord.Role) -> None: """Log role create event to mod log.""" if role.guild.id != GuildConstant.id: return @@ -267,7 +267,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_delete(self, role: disnake.Role) -> None: + async def on_guild_role_delete(self, role: discord.Role) -> None: """Log role delete event to mod log.""" if role.guild.id != GuildConstant.id: return @@ -278,7 +278,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_role_update(self, before: disnake.Role, after: disnake.Role) -> None: + async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None: """Log role update event to mod log.""" if before.guild.id != GuildConstant.id: return @@ -331,7 +331,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_guild_update(self, before: disnake.Guild, after: disnake.Guild) -> None: + async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None: """Log guild update event to mod log.""" if before.id != GuildConstant.id: return @@ -382,7 +382,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_ban(self, guild: disnake.Guild, member: disnake.Member) -> None: + async def on_member_ban(self, guild: discord.Guild, member: discord.Member) -> None: """Log ban event to user log.""" if guild.id != GuildConstant.id: return @@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_join(self, member: disnake.Member) -> None: + async def on_member_join(self, member: discord.Member) -> None: """Log member join event to user log.""" if member.guild.id != GuildConstant.id: return @@ -420,7 +420,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_remove(self, member: disnake.Member) -> None: + async def on_member_remove(self, member: discord.Member) -> None: """Log member leave event to user log.""" if member.guild.id != GuildConstant.id: return @@ -437,7 +437,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_member_unban(self, guild: disnake.Guild, member: disnake.User) -> None: + async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None: """Log member unban event to mod log.""" if guild.id != GuildConstant.id: return @@ -454,7 +454,7 @@ class ModLog(Cog, name="ModLog"): ) @staticmethod - def get_role_diff(before: t.List[disnake.Role], after: t.List[disnake.Role]) -> t.List[str]: + def get_role_diff(before: t.List[discord.Role], after: t.List[discord.Role]) -> t.List[str]: """Return a list of strings describing the roles added and removed.""" changes = [] before_roles = set(before) @@ -469,7 +469,7 @@ class ModLog(Cog, name="ModLog"): return changes @Cog.listener() - async def on_member_update(self, before: disnake.Member, after: disnake.Member) -> None: + async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: """Log member update event to user log.""" if before.guild.id != GuildConstant.id: return @@ -552,7 +552,7 @@ class ModLog(Cog, name="ModLog"): return channel.id in GuildConstant.modlog_blacklist - async def log_cached_deleted_message(self, message: disnake.Message) -> None: + async def log_cached_deleted_message(self, message: discord.Message) -> None: """ Log the message's details to message change log. @@ -608,7 +608,7 @@ class ModLog(Cog, name="ModLog"): channel_id=Channels.message_log ) - async def log_uncached_deleted_message(self, event: disnake.RawMessageDeleteEvent) -> None: + async def log_uncached_deleted_message(self, event: discord.RawMessageDeleteEvent) -> None: """ Log the message's details to message change log. @@ -648,7 +648,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_raw_message_delete(self, event: disnake.RawMessageDeleteEvent) -> None: + async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: """Log message deletions to message change log.""" if event.cached_message is not None: await self.log_cached_deleted_message(event.cached_message) @@ -656,7 +656,7 @@ class ModLog(Cog, name="ModLog"): await self.log_uncached_deleted_message(event) @Cog.listener() - async def on_message_edit(self, msg_before: disnake.Message, msg_after: disnake.Message) -> None: + async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: """Log message edit event to message change log.""" if self.is_message_blacklisted(msg_before): return @@ -727,7 +727,7 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_raw_message_edit(self, event: disnake.RawMessageUpdateEvent) -> None: + async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: """Log raw message edit event to message change log.""" if event.guild_id is None: return # ignore DM edits @@ -736,7 +736,7 @@ class ModLog(Cog, name="ModLog"): try: channel = self.bot.get_channel(int(event.data["channel_id"])) message = await channel.fetch_message(event.message_id) - except disnake.NotFound: # Was deleted before we got the event + except discord.NotFound: # Was deleted before we got the event return if self.is_message_blacklisted(message): @@ -860,9 +860,9 @@ class ModLog(Cog, name="ModLog"): @Cog.listener() async def on_voice_state_update( self, - member: disnake.Member, - before: disnake.VoiceState, - after: disnake.VoiceState + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState ) -> None: """Log member voice state changes to the voice log channel.""" if ( diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index 51d161d84..b5cd29b12 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -4,8 +4,8 @@ import datetime import arrow from async_rediscache import RedisCache from dateutil.parser import isoparse, parse as dateutil_parse -from disnake import Embed, Member -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord import Embed, Member +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles @@ -22,12 +22,12 @@ MAXIMUM_WORK_LIMIT = 16 class ModPings(Cog): """Commands for a moderator to turn moderator pings on and off.""" - # RedisCache[disnake.Member.id, 'Naïve ISO 8601 string'] + # RedisCache[discord.Member.id, 'Naïve ISO 8601 string'] # The cache's keys are mods who have pings off. # The cache's values are the times when the role should be re-applied to them, stored in ISO format. pings_off_mods = RedisCache() - # RedisCache[disnake.Member.id, 'start timestamp|total worktime in seconds'] + # RedisCache[discord.Member.id, 'start timestamp|total worktime in seconds'] # The cache's keys are mod's ID # The cache's values are their pings on schedule timestamp and the total seconds (work time) until pings off modpings_schedule = RedisCache() diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 0b677dddb..511520252 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -5,10 +5,10 @@ from datetime import datetime, timedelta, timezone from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache -from disnake import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel -from disnake.ext import commands, tasks -from disnake.ext.commands import Context -from disnake.utils import MISSING +from discord import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel +from discord.ext import commands, tasks +from discord.ext.commands import Context +from discord.utils import MISSING from bot import constants from bot.bot import Bot diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py index 7fcafc01c..b6a771441 100644 --- a/bot/exts/moderation/slowmode.py +++ b/bot/exts/moderation/slowmode.py @@ -1,8 +1,8 @@ from typing import Optional from dateutil.relativedelta import relativedelta -from disnake import TextChannel -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord import TextChannel +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, Emojis, MODERATION_ROLES diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 7afd9f71d..985cc6eb1 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -2,10 +2,10 @@ from datetime import timedelta, timezone from operator import itemgetter import arrow -import disnake +import discord from arrow import Arrow from async_rediscache import RedisCache -from disnake.ext import commands +from discord.ext import commands from bot.bot import Bot from bot.constants import ( @@ -24,7 +24,7 @@ class Stream(commands.Cog): """Grant and revoke streaming permissions from members.""" # Stores tasks to remove streaming permission - # RedisCache[disnake.Member.id, UtcPosixTimestamp] + # RedisCache[discord.Member.id, UtcPosixTimestamp] task_cache = RedisCache() def __init__(self, bot: Bot): @@ -37,10 +37,10 @@ class Stream(commands.Cog): self.reload_task.cancel() self.reload_task.add_done_callback(lambda _: self.scheduler.cancel_all()) - async def _revoke_streaming_permission(self, member: disnake.Member) -> None: + async def _revoke_streaming_permission(self, member: discord.Member) -> None: """Remove the streaming permission from the given Member.""" await self.task_cache.delete(member.id) - await member.remove_roles(disnake.Object(Roles.video), reason="Streaming access revoked") + await member.remove_roles(discord.Object(Roles.video), reason="Streaming access revoked") async def _reload_tasks_from_redis(self) -> None: """Reload outstanding tasks from redis on startup, delete the task if the member has since left the server.""" @@ -66,7 +66,7 @@ class Stream(commands.Cog): self._revoke_streaming_permission(member) ) - async def _suspend_stream(self, ctx: commands.Context, member: disnake.Member) -> None: + async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None: """Suspend a member's stream.""" await self.bot.wait_until_guild_available() voice_state = member.voice @@ -90,7 +90,7 @@ class Stream(commands.Cog): @commands.command(aliases=("streaming",)) @commands.has_any_role(*MODERATION_ROLES) - async def stream(self, ctx: commands.Context, member: disnake.Member, duration: Expiry = None) -> None: + async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None: """ Temporarily grant streaming permissions to a member for a given duration. @@ -128,7 +128,7 @@ class Stream(commands.Cog): self.scheduler.schedule_at(duration, member.id, self._revoke_streaming_permission(member)) await self.task_cache.set(member.id, duration.timestamp()) - await member.add_roles(disnake.Object(Roles.video), reason="Temporary streaming access granted") + await member.add_roles(discord.Object(Roles.video), reason="Temporary streaming access granted") await ctx.send(f"{Emojis.check_mark} {member.mention} can now stream until {time.discord_timestamp(duration)}.") @@ -142,7 +142,7 @@ class Stream(commands.Cog): @commands.command(aliases=("pstream",)) @commands.has_any_role(*MODERATION_ROLES) - async def permanentstream(self, ctx: commands.Context, member: disnake.Member) -> None: + async def permanentstream(self, ctx: commands.Context, member: discord.Member) -> None: """Permanently grants the given member the permission to stream.""" log.trace(f"Attempting to give permanent streaming permission to {member} ({member.id}).") @@ -163,13 +163,13 @@ class Stream(commands.Cog): log.debug(f"{member} ({member.id}) already had permanent streaming permission.") return - await member.add_roles(disnake.Object(Roles.video), reason="Permanent streaming access granted") + await member.add_roles(discord.Object(Roles.video), reason="Permanent streaming access granted") await ctx.send(f"{Emojis.check_mark} Permanently granted {member.mention} the permission to stream.") log.debug(f"Successfully gave {member} ({member.id}) permanent streaming permission.") @commands.command(aliases=("unstream", "rstream")) @commands.has_any_role(*MODERATION_ROLES) - async def revokestream(self, ctx: commands.Context, member: disnake.Member) -> None: + async def revokestream(self, ctx: commands.Context, member: discord.Member) -> None: """Revoke the permission to stream from the given member.""" log.trace(f"Attempting to remove streaming permission from {member} ({member.id}).") @@ -222,7 +222,7 @@ class Stream(commands.Cog): # Only output the message in the pagination lines = [line[1] for line in streamer_info] - embed = disnake.Embed( + embed = discord.Embed( title=f"Members with streaming permission (`{len(lines)}` total)", colour=Colours.soft_green ) diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index c958aa160..37338d19c 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -1,7 +1,7 @@ import typing as t -import disnake -from disnake.ext.commands import Cog, Context, command, has_any_role +import discord +from discord.ext.commands import Cog, Context, command, has_any_role from bot import constants from bot.bot import Bot @@ -51,7 +51,7 @@ async def safe_dm(coro: t.Coroutine) -> None: """ try: await coro - except disnake.HTTPException as discord_exc: + except discord.HTTPException as discord_exc: log.trace(f"DM dispatch failed on status {discord_exc.status} with code: {discord_exc.code}") if discord_exc.code != 50_007: # If any reason other than disabled DMs raise @@ -72,7 +72,7 @@ class Verification(Cog): # region: listeners @Cog.listener() - async def on_member_join(self, member: disnake.Member) -> None: + async def on_member_join(self, member: discord.Member) -> None: """Attempt to send initial direct message to each new member.""" if member.guild.id != constants.Guild.id: return # Only listen for PyDis events @@ -87,11 +87,11 @@ class Verification(Cog): log.trace(f"Sending on join message to new member: {member.id}") try: await safe_dm(member.send(ON_JOIN_MESSAGE)) - except disnake.HTTPException: + except discord.HTTPException: log.exception("DM dispatch failed on unexpected error code") @Cog.listener() - async def on_member_update(self, before: disnake.Member, after: disnake.Member) -> None: + async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: """Check if we need to send a verification DM to a gated user.""" if before.pending is True and after.pending is False: try: @@ -100,7 +100,7 @@ class Verification(Cog): # our alternate welcome DM which includes info such as our welcome # video. await safe_dm(after.send(VERIFIED_MESSAGE)) - except disnake.HTTPException: + except discord.HTTPException: log.exception("DM dispatch failed on unexpected error code") # endregion @@ -108,7 +108,7 @@ class Verification(Cog): @command(name='verify') @has_any_role(*constants.MODERATION_ROLES) - async def perform_manual_verification(self, ctx: Context, user: disnake.Member) -> None: + async def perform_manual_verification(self, ctx: Context, user: discord.Member) -> None: """Command for moderators to verify any user.""" log.trace(f'verify command called by {ctx.author} for {user.id}.') diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index 24ae86bdd..fa66b00dd 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -3,10 +3,10 @@ from contextlib import suppress from datetime import timedelta import arrow -import disnake +import discord from async_rediscache import RedisCache -from disnake import Colour, Member, VoiceState -from disnake.ext.commands import Cog, Context, command +from discord import Colour, Member, VoiceState +from discord.ext.commands import Cog, Context, command from bot.api import ResponseCodeError from bot.bot import Bot @@ -51,7 +51,7 @@ VOICE_PING_DM = ( class VoiceGate(Cog): """Voice channels verification management.""" - # RedisCache[t.Union[disnake.User.id, disnake.Member.id], t.Union[disnake.Message.id, int]] + # RedisCache[t.Union[discord.User.id, discord.Member.id], t.Union[discord.Message.id, int]] # The cache's keys are the IDs of members who are verified or have joined a voice channel # The cache's values are either the message ID of the ping message or 0 (NO_MSG) if no message is present redis_cache = RedisCache() @@ -75,14 +75,14 @@ class VoiceGate(Cog): """ if message_id := await self.redis_cache.get(member_id): log.trace(f"Removing voice gate reminder message for user: {member_id}") - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await self.bot.http.delete_message(Channels.voice_gate, message_id) await self.redis_cache.set(member_id, NO_MSG) else: log.trace(f"Voice gate reminder message for user {member_id} was already removed") @redis_cache.atomic_transaction - async def _ping_newcomer(self, member: disnake.Member) -> tuple: + async def _ping_newcomer(self, member: discord.Member) -> tuple: """ See if `member` should be sent a voice verification notification, and send it if so. @@ -91,7 +91,7 @@ class VoiceGate(Cog): * The `member` is already voice-verified Otherwise, the notification message ID is stored in `redis_cache` and return (True, channel). - channel is either [disnake.TextChannel, disnake.DMChannel]. + channel is either [discord.TextChannel, discord.DMChannel]. """ if await self.redis_cache.contains(member.id): log.trace("User already in cache. Ignore.") @@ -111,7 +111,7 @@ class VoiceGate(Cog): try: message = await member.send(VOICE_PING_DM.format(channel_mention=voice_verification_channel.mention)) - except disnake.Forbidden: + except discord.Forbidden: log.trace("DM failed for Voice ping message. Sending in channel.") message = await voice_verification_channel.send(f"Hello, {member.mention}! {VOICE_PING}") @@ -137,7 +137,7 @@ class VoiceGate(Cog): data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data") except ResponseCodeError as e: if e.status == 404: - embed = disnake.Embed( + embed = discord.Embed( title="Not found", description=( "We were unable to find user data for you. " @@ -148,7 +148,7 @@ class VoiceGate(Cog): ) log.info(f"Unable to find Metricity data about {ctx.author} ({ctx.author.id})") else: - embed = disnake.Embed( + embed = discord.Embed( title="Unexpected response", description=( "We encountered an error while attempting to find data for your user. " @@ -159,7 +159,7 @@ class VoiceGate(Cog): log.warning(f"Got response code {e.status} while trying to get {ctx.author.id} Metricity data.") try: await ctx.author.send(embed=embed) - except disnake.Forbidden: + except discord.Forbidden: log.info("Could not send user DM. Sending in voice-verify channel and scheduling delete.") await ctx.send(embed=embed) @@ -179,7 +179,7 @@ class VoiceGate(Cog): [self.bot.stats.incr(f"voice_gate.failed.{key}") for key, value in checks.items() if value is True] if failed: - embed = disnake.Embed( + embed = discord.Embed( title="Voice Gate failed", description=FAILED_MESSAGE.format(reasons="\n".join(f'• You {reason}.' for reason in failed_reasons)), color=Colour.red() @@ -187,12 +187,12 @@ class VoiceGate(Cog): try: await ctx.author.send(embed=embed) await ctx.send(f"{ctx.author}, please check your DMs.") - except disnake.Forbidden: + except discord.Forbidden: await ctx.channel.send(ctx.author.mention, embed=embed) return self.mod_log.ignore(Event.member_update, ctx.author.id) - embed = disnake.Embed( + embed = discord.Embed( title="Voice gate passed", description="You have been granted permission to use voice channels in Python Discord.", color=Colour.green() @@ -204,17 +204,17 @@ class VoiceGate(Cog): try: await ctx.author.send(embed=embed) await ctx.send(f"{ctx.author}, please check your DMs.") - except disnake.Forbidden: + except discord.Forbidden: await ctx.channel.send(ctx.author.mention, embed=embed) # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it. await asyncio.sleep(3) - await ctx.author.add_roles(disnake.Object(Roles.voice_verified), reason="Voice Gate passed") + await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed") self.bot.stats.incr("voice_gate.passed") @Cog.listener() - async def on_message(self, message: disnake.Message) -> None: + async def on_message(self, message: discord.Message) -> None: """Delete all non-staff messages from voice gate channel that don't invoke voice verify command.""" # Check is channel voice gate if message.channel.id != Channels.voice_gate: @@ -229,7 +229,7 @@ class VoiceGate(Cog): if message.content.endswith(VOICE_PING): log.trace("Message is the voice verification ping. Ignore.") return - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await message.delete(delay=GateConf.bot_message_delete_delay) return @@ -242,7 +242,7 @@ class VoiceGate(Cog): if ctx.command is not None and ctx.command.name == "voice_verify": self.mod_log.ignore(Event.message_delete, message.id) - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await message.delete() @Cog.listener() @@ -257,7 +257,7 @@ class VoiceGate(Cog): log.trace("User not in a voice channel. Ignore.") return - if isinstance(after.channel, disnake.StageChannel): + if isinstance(after.channel, discord.StageChannel): log.trace("User joined a stage channel. Ignore.") return @@ -267,7 +267,7 @@ class VoiceGate(Cog): # Schedule the channel ping notification to be deleted after the configured delay, which is # again delegated to an atomic helper - if notification_sent and isinstance(message_channel, disnake.TextChannel): + if notification_sent and isinstance(message_channel, discord.TextChannel): await asyncio.sleep(GateConf.voice_ping_delete_delay) await self._delete_ping(member.id) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index 88669ccaa..ee9b6ba45 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -6,9 +6,9 @@ from collections import defaultdict, deque from dataclasses import dataclass from typing import Any, Dict, Optional -import disnake -from disnake import Color, DMChannel, Embed, HTTPException, Message, errors -from disnake.ext.commands import Cog, Context +import discord +from discord import Color, DMChannel, Embed, HTTPException, Message, errors +from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError from bot.bot import Bot @@ -104,7 +104,7 @@ class WatchChannel(metaclass=CogABCMeta): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) - except disnake.HTTPException: + except discord.HTTPException: self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") if self.channel is None or self.webhook is None: @@ -217,7 +217,7 @@ class WatchChannel(metaclass=CogABCMeta): username = messages.sub_clyde(username) try: await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed) - except disnake.HTTPException as exc: + except discord.HTTPException as exc: self.log.exception( "Failed to send a message to the webhook", exc_info=exc @@ -265,7 +265,7 @@ class WatchChannel(metaclass=CogABCMeta): username=msg.author.display_name, avatar_url=msg.author.display_avatar.url ) - except disnake.HTTPException as exc: + except discord.HTTPException as exc: self.log.exception( "Failed to send an attachment to the webhook", exc_info=exc diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py index f19e3d103..31b106a20 100644 --- a/bot/exts/moderation/watchchannels/bigbrother.py +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -1,7 +1,7 @@ import textwrap from collections import ChainMap -from disnake.ext.commands import Cog, Context, group, has_any_role +from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Webhooks @@ -94,7 +94,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): await ctx.send(f":x: {user.mention} is already being watched.") return - # disnake.User instances don't have a roles attribute + # discord.User instances don't have a roles attribute if hasattr(user, "roles") and any(role.id in MODERATION_ROLES for role in user.roles): await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I must be kind to my masters.") return diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 3d784ef77..0554bf37a 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -3,10 +3,10 @@ from collections import ChainMap, defaultdict from io import StringIO from typing import Optional, Union -import disnake +import discord from async_rediscache import RedisCache -from disnake import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User -from disnake.ext.commands import BadArgument, Cog, Context, group, has_any_role +from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User +from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role from bot.api import ResponseCodeError from bot.bot import Bot @@ -483,7 +483,7 @@ class TalentPool(Cog, name="Talentpool"): async def get_review(self, ctx: Context, user_id: int) -> None: """Get the user's review as a markdown file.""" review, _, _ = await self.reviewer.make_review(user_id) - file = disnake.File(StringIO(review), f"{user_id}_review.md") + file = discord.File(StringIO(review), f"{user_id}_review.md") await ctx.send(file=file) @nomination_group.command(aliases=('review',)) diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index d496d0eb2..b4d177622 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -10,8 +10,8 @@ from typing import List, Optional, Union import arrow from dateutil.parser import isoparse -from disnake import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel -from disnake.ext.commands import Context +from discord import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel +from discord.ext.commands import Context from bot.api import ResponseCodeError from bot.bot import Bot diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py index 7d18c0ed3..8f0094bc9 100644 --- a/bot/exts/utils/bot.py +++ b/bot/exts/utils/bot.py @@ -1,7 +1,7 @@ from typing import Optional -from disnake import Embed, TextChannel -from disnake.ext.commands import Cog, Context, command, group, has_any_role +from discord import Embed, TextChannel +from discord.ext.commands import Cog, Context, command, group, has_any_role from bot.bot import Bot from bot.constants import Guild, MODERATION_ROLES, URLs diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 3d12ae848..fda1e49e2 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -2,9 +2,9 @@ import functools import typing as t from enum import Enum -from disnake import Colour, Embed -from disnake.ext import commands -from disnake.ext.commands import Context, group +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group from bot import exts from bot.bot import Bot diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index 28c1867ad..e7113c09c 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -9,8 +9,8 @@ from io import StringIO from typing import Any, Optional, Tuple import arrow -import disnake -from disnake.ext.commands import Cog, Context, group, has_any_role, is_owner +import discord +from discord.ext.commands import Cog, Context, group, has_any_role, is_owner from bot.bot import Bot from bot.constants import DEBUG_MODE, Roles @@ -42,7 +42,7 @@ class Internal(Cog): self.socket_event_total += 1 self.socket_events[event_type] += 1 - def _format(self, inp: str, out: Any) -> Tuple[str, Optional[disnake.Embed]]: + def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: """Format the eval output into a string & attempt to format it into an Embed.""" self._ = out @@ -103,7 +103,7 @@ class Internal(Cog): res += f"Out[{self.ln}]: " - if isinstance(out, disnake.Embed): + if isinstance(out, discord.Embed): # We made an embed? Send that as embed res += "" res = (res, out) @@ -136,7 +136,7 @@ class Internal(Cog): return res # Return (text, embed) - async def _eval(self, ctx: Context, code: str) -> Optional[disnake.Message]: + async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: """Eval the input code string & send an embed to the invoking context.""" self.ln += 1 @@ -154,8 +154,7 @@ class Internal(Cog): "self": self, "bot": self.bot, "inspect": inspect, - "discord": disnake, - "disnake": disnake, + "discord": discord, "contextlib": contextlib } @@ -241,10 +240,10 @@ async def func(): # (None,) -> Any per_s = self.socket_event_total / running_s - stats_embed = disnake.Embed( + stats_embed = discord.Embed( title="WebSocket statistics", description=f"Receiving {per_s:0.2f} events per second.", - color=disnake.Color.og_blurple() + color=discord.Color.og_blurple() ) for event_type, count in self.socket_events.most_common(25): diff --git a/bot/exts/utils/ping.py b/bot/exts/utils/ping.py index eeb1d5ff5..9fb5b7b8f 100644 --- a/bot/exts/utils/ping.py +++ b/bot/exts/utils/ping.py @@ -1,7 +1,7 @@ import arrow from aiohttp import client_exceptions -from disnake import Embed -from disnake.ext import commands +from discord import Embed +from discord.ext import commands from bot.bot import Bot from bot.constants import Channels, STAFF_PARTNERS_COMMUNITY_ROLES, URLs diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index bf0e9d2ac..ad82d49c9 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -4,9 +4,9 @@ import typing as t from datetime import datetime, timezone from operator import itemgetter -import disnake +import discord from dateutil.parser import isoparse -from disnake.ext.commands import Cog, Context, Greedy, group +from discord.ext.commands import Cog, Context, Greedy, group from bot.bot import Bot from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES @@ -26,8 +26,8 @@ LOCK_NAMESPACE = "reminder" WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 -Mentionable = t.Union[disnake.Member, disnake.Role] -ReminderMention = t.Union[UnambiguousUser, disnake.Role] +Mentionable = t.Union[discord.Member, discord.Role] +ReminderMention = t.Union[UnambiguousUser, discord.Role] class Reminders(Cog): @@ -66,7 +66,7 @@ class Reminders(Cog): else: self.schedule_reminder(reminder) - def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, disnake.TextChannel]: + def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.TextChannel]: """Ensure reminder channel can be fetched otherwise delete the reminder.""" channel = self.bot.get_channel(reminder['channel_id']) is_valid = True @@ -87,9 +87,9 @@ class Reminders(Cog): reminder_id: t.Union[str, int] ) -> None: """Send an embed confirming the reminder change was made successfully.""" - embed = disnake.Embed( + embed = discord.Embed( description=on_success, - colour=disnake.Colour.green(), + colour=discord.Colour.green(), title=random.choice(POSITIVE_REPLIES) ) @@ -113,7 +113,7 @@ class Reminders(Cog): if await has_no_roles_check(ctx, *STAFF_PARTNERS_COMMUNITY_ROLES): return False, "members/roles" elif await has_no_roles_check(ctx, *MODERATION_ROLES): - return all(isinstance(mention, (disnake.User, disnake.Member)) for mention in mentions), "roles" + return all(isinstance(mention, (discord.User, discord.Member)) for mention in mentions), "roles" else: return True, "" @@ -173,15 +173,15 @@ class Reminders(Cog): if not is_valid: # No need to cancel the task too; it'll simply be done once this coroutine returns. return - embed = disnake.Embed() + embed = discord.Embed() if expected_time: - embed.colour = disnake.Colour.red() + embed.colour = discord.Colour.red() embed.set_author( icon_url=Icons.remind_red, name="Sorry, your reminder should have arrived earlier!" ) else: - embed.colour = disnake.Colour.og_blurple() + embed.colour = discord.Colour.og_blurple() embed.set_author( icon_url=Icons.remind_blurple, name="It has arrived!" @@ -200,7 +200,7 @@ class Reminders(Cog): partial_message = channel.get_partial_message(int(jump_url.split("/")[-1])) try: await partial_message.reply(content=f"{additional_mentions}", embed=embed) - except disnake.HTTPException as e: + except discord.HTTPException as e: log.info( f"There was an error when trying to reply to a reminder invocation message, {e}, " "fall back to using jump_url" @@ -284,7 +284,7 @@ class Reminders(Cog): # If `content` isn't provided then we try to get message content of a replied message if not content: if reference := ctx.message.reference: - if isinstance((resolved_message := reference.resolved), disnake.Message): + if isinstance((resolved_message := reference.resolved), discord.Message): content = resolved_message.content # If we weren't able to get the content of a replied message if content is None: @@ -361,8 +361,8 @@ class Reminders(Cog): lines.append(text) - embed = disnake.Embed() - embed.colour = disnake.Colour.og_blurple() + embed = discord.Embed() + embed.colour = discord.Colour.og_blurple() embed.title = f"Reminders for {ctx.author}" # Remind the user that they have no reminders :^) @@ -372,7 +372,7 @@ class Reminders(Cog): return # Construct the embed and paginate it. - embed.colour = disnake.Colour.og_blurple() + embed.colour = discord.Colour.og_blurple() await LinePaginator.paginate( lines, diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 07d824f87..cc3a2e1d7 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -8,8 +8,8 @@ from signal import Signals from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX -from disnake import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from disnake.ext.commands import Cog, Context, command, guild_only +from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User +from discord.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs diff --git a/bot/exts/utils/thread_bumper.py b/bot/exts/utils/thread_bumper.py index d37b3b51c..35057f1fe 100644 --- a/bot/exts/utils/thread_bumper.py +++ b/bot/exts/utils/thread_bumper.py @@ -1,8 +1,8 @@ import typing as t -import disnake +import discord from async_rediscache import RedisCache -from disnake.ext import commands +from discord.ext import commands from bot import constants from bot.bot import Bot @@ -16,14 +16,14 @@ log = get_logger(__name__) class ThreadBumper(commands.Cog): """Cog that allow users to add the current thread to a list that get reopened on archive.""" - # RedisCache[disnake.Thread.id, "sentinel"] + # RedisCache[discord.Thread.id, "sentinel"] threads_to_bump = RedisCache() def __init__(self, bot: Bot): self.bot = bot self.init_task = scheduling.create_task(self.ensure_bumped_threads_are_active(), event_loop=self.bot.loop) - async def unarchive_threads_not_manually_archived(self, threads: list[disnake.Thread]) -> None: + async def unarchive_threads_not_manually_archived(self, threads: list[discord.Thread]) -> None: """ Iterate through and unarchive any threads that weren't manually archived recently. @@ -35,7 +35,7 @@ class ThreadBumper(commands.Cog): guild = self.bot.get_guild(constants.Guild.id) recent_manually_archived_thread_ids = [] - async for thread_update in guild.audit_logs(limit=200, action=disnake.AuditLogAction.thread_update): + async for thread_update in guild.audit_logs(limit=200, action=discord.AuditLogAction.thread_update): if getattr(thread_update.after, "archived", False): recent_manually_archived_thread_ids.append(thread_update.target.id) @@ -58,7 +58,7 @@ class ThreadBumper(commands.Cog): for thread_id, _ in await self.threads_to_bump.items(): try: thread = await channel.get_or_fetch_channel(thread_id) - except disnake.NotFound: + except discord.NotFound: log.info("Thread %d has been deleted, removing from bumped threads.", thread_id) await self.threads_to_bump.delete(thread_id) continue @@ -75,12 +75,12 @@ class ThreadBumper(commands.Cog): await ctx.send_help(ctx.command) @thread_bump_group.command(name="add", aliases=("a",)) - async def add_thread_to_bump_list(self, ctx: commands.Context, thread: t.Optional[disnake.Thread]) -> None: + async def add_thread_to_bump_list(self, ctx: commands.Context, thread: t.Optional[discord.Thread]) -> None: """Add a thread to the bump list.""" await self.init_task if not thread: - if isinstance(ctx.channel, disnake.Thread): + if isinstance(ctx.channel, discord.Thread): thread = ctx.channel else: raise commands.BadArgument("You must provide a thread, or run this command within a thread.") @@ -92,12 +92,12 @@ class ThreadBumper(commands.Cog): await ctx.send(f":ok_hand:{thread.mention} has been added to the bump list.") @thread_bump_group.command(name="remove", aliases=("r", "rem", "d", "del", "delete")) - async def remove_thread_from_bump_list(self, ctx: commands.Context, thread: t.Optional[disnake.Thread]) -> None: + async def remove_thread_from_bump_list(self, ctx: commands.Context, thread: t.Optional[discord.Thread]) -> None: """Remove a thread from the bump list.""" await self.init_task if not thread: - if isinstance(ctx.channel, disnake.Thread): + if isinstance(ctx.channel, discord.Thread): thread = ctx.channel else: raise commands.BadArgument("You must provide a thread, or run this command within a thread.") @@ -114,14 +114,14 @@ class ThreadBumper(commands.Cog): await self.init_task lines = [f"<#{k}>" for k, _ in await self.threads_to_bump.items()] - embed = disnake.Embed( + embed = discord.Embed( title="Threads in the bump list", colour=constants.Colours.blue ) await LinePaginator.paginate(lines, ctx, embed) @commands.Cog.listener() - async def on_thread_update(self, _: disnake.Thread, after: disnake.Thread) -> None: + async def on_thread_update(self, _: discord.Thread, after: discord.Thread) -> None: """ Listen for thread updates and check if the thread has been archived. diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index 77be3315c..2a074788e 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -3,9 +3,9 @@ import re import unicodedata from typing import Tuple, Union -from disnake import Colour, Embed, utils -from disnake.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role -from disnake.utils import snowflake_time +from discord import Colour, Embed, utils +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role +from discord.utils import snowflake_time from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES diff --git a/bot/log.py b/bot/log.py index 0b1d1aca6..100cd06f6 100644 --- a/bot/log.py +++ b/bot/log.py @@ -74,7 +74,7 @@ def setup() -> None: coloredlogs.install(level=TRACE_LEVEL, logger=root_log, stream=sys.stdout) root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO) - get_logger("disnake").setLevel(logging.WARNING) + get_logger("discord").setLevel(logging.WARNING) get_logger("websockets").setLevel(logging.WARNING) get_logger("chardet").setLevel(logging.WARNING) get_logger("async_rediscache").setLevel(logging.WARNING) diff --git a/bot/monkey_patches.py b/bot/monkey_patches.py index 590be22a2..4840fa454 100644 --- a/bot/monkey_patches.py +++ b/bot/monkey_patches.py @@ -2,8 +2,8 @@ import re from datetime import timedelta import arrow -from disnake import Forbidden, http -from disnake.ext import commands +from discord import Forbidden, http +from discord.ext import commands from bot.log import get_logger @@ -13,7 +13,7 @@ MESSAGE_ID_RE = re.compile(r'(?P[0-9]{15,20})$') class Command(commands.Command): """ - A `disnake.ext.commands.Command` subclass which supports root aliases. + A `discord.ext.commands.Command` subclass which supports root aliases. A `root_aliases` keyword argument is added, which is a sequence of alias names that will act as top-level commands rather than being aliases of the command's group. It's stored as an attribute diff --git a/bot/pagination.py b/bot/pagination.py index 1a014daa1..8f4353eb1 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -3,9 +3,9 @@ import typing as t from contextlib import suppress from functools import partial -import disnake -from disnake.abc import User -from disnake.ext.commands import Context, Paginator +import discord +from discord.abc import User +from discord.ext.commands import Context, Paginator from bot import constants from bot.log import get_logger @@ -55,7 +55,7 @@ class LinePaginator(Paginator): linesep: str = "\n" ) -> None: """ - This function overrides the Paginator.__init__ from inside disnake.ext.commands. + This function overrides the Paginator.__init__ from inside discord.ext.commands. It overrides in order to allow us to configure the maximum number of lines per page. """ @@ -99,7 +99,7 @@ class LinePaginator(Paginator): effort to avoid breaking up single lines across pages, while keeping the total length of the page at a reasonable size. - This function overrides the `Paginator.add_line` from inside `disnake.ext.commands`. + This function overrides the `Paginator.add_line` from inside `discord.ext.commands`. It overrides in order to allow us to configure the maximum number of lines per page. """ @@ -192,7 +192,7 @@ class LinePaginator(Paginator): cls, lines: t.List[str], ctx: Context, - embed: disnake.Embed, + embed: discord.Embed, prefix: str = "", suffix: str = "", max_lines: t.Optional[int] = None, @@ -204,7 +204,7 @@ class LinePaginator(Paginator): footer_text: str = None, url: str = None, exception_on_empty_embed: bool = False, - ) -> t.Optional[disnake.Message]: + ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -219,7 +219,7 @@ class LinePaginator(Paginator): to any user with a moderation role. Example: - >>> embed = disnake.Embed() + >>> embed = discord.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) >>> await LinePaginator.paginate([line for line in lines], ctx, embed) """ @@ -367,5 +367,5 @@ class LinePaginator(Paginator): await message.edit(embed=embed) log.debug("Ending pagination and clearing reactions.") - with suppress(disnake.NotFound): + with suppress(discord.NotFound): await message.clear_reactions() diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 9c890e569..8903c385c 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/burst.py b/bot/rules/burst.py index a943cfdeb..25c5a2f33 100644 --- a/bot/rules/burst.py +++ b/bot/rules/burst.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/burst_shared.py b/bot/rules/burst_shared.py index dee857e18..bbe9271b3 100644 --- a/bot/rules/burst_shared.py +++ b/bot/rules/burst_shared.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/chars.py b/bot/rules/chars.py index 6d2f6eb83..1f587422c 100644 --- a/bot/rules/chars.py +++ b/bot/rules/chars.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/discord_emojis.py b/bot/rules/discord_emojis.py index 4fe4e88f9..d979ac5e7 100644 --- a/bot/rules/discord_emojis.py +++ b/bot/rules/discord_emojis.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message from emoji import demojize DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:") diff --git a/bot/rules/duplicates.py b/bot/rules/duplicates.py index 77e393db0..8e4fbc12d 100644 --- a/bot/rules/duplicates.py +++ b/bot/rules/duplicates.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/links.py b/bot/rules/links.py index 92c13b3f4..c46b783c5 100644 --- a/bot/rules/links.py +++ b/bot/rules/links.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message LINK_RE = re.compile(r"(https?://[^\s]+)") diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 7ee66be31..6f5addad1 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/newlines.py b/bot/rules/newlines.py index 45266648e..4e66e1359 100644 --- a/bot/rules/newlines.py +++ b/bot/rules/newlines.py @@ -1,7 +1,7 @@ import re from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/rules/role_mentions.py b/bot/rules/role_mentions.py index 1f7a6a74d..0649540b6 100644 --- a/bot/rules/role_mentions.py +++ b/bot/rules/role_mentions.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple -from disnake import Member, Message +from discord import Member, Message async def apply( diff --git a/bot/utils/channel.py b/bot/utils/channel.py index ee0c87311..954a10e56 100644 --- a/bot/utils/channel.py +++ b/bot/utils/channel.py @@ -1,6 +1,6 @@ from typing import Union -import disnake +import discord import bot from bot import constants @@ -10,7 +10,7 @@ from bot.log import get_logger log = get_logger(__name__) -def is_help_channel(channel: disnake.TextChannel) -> bool: +def is_help_channel(channel: discord.TextChannel) -> bool: """Return True if `channel` is in one of the help categories (excluding dormant).""" log.trace(f"Checking if #{channel} is a help channel.") categories = (Categories.help_available, Categories.help_in_use) @@ -18,9 +18,9 @@ def is_help_channel(channel: disnake.TextChannel) -> bool: return any(is_in_category(channel, category) for category in categories) -def is_mod_channel(channel: Union[disnake.TextChannel, disnake.Thread]) -> bool: +def is_mod_channel(channel: Union[discord.TextChannel, discord.Thread]) -> bool: """True if channel, or channel.parent for threads, is considered a mod channel.""" - if isinstance(channel, disnake.Thread): + if isinstance(channel, discord.Thread): channel = channel.parent if channel.id in constants.MODERATION_CHANNELS: @@ -36,11 +36,11 @@ def is_mod_channel(channel: Union[disnake.TextChannel, disnake.Thread]) -> bool: return False -def is_staff_channel(channel: disnake.TextChannel) -> bool: +def is_staff_channel(channel: discord.TextChannel) -> bool: """True if `channel` is considered a staff channel.""" guild = bot.instance.get_guild(constants.Guild.id) - if channel.type is disnake.ChannelType.category: + if channel.type is discord.ChannelType.category: return False # Channel is staff-only if staff have explicit read allow perms @@ -52,12 +52,12 @@ def is_staff_channel(channel: disnake.TextChannel) -> bool: ) -def is_in_category(channel: disnake.TextChannel, category_id: int) -> bool: +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: """Return True if `channel` is within a category with `category_id`.""" return getattr(channel, "category_id", None) == category_id -async def get_or_fetch_channel(channel_id: int) -> disnake.abc.GuildChannel: +async def get_or_fetch_channel(channel_id: int) -> discord.abc.GuildChannel: """Attempt to get or fetch a channel and return it.""" log.trace(f"Getting the channel {channel_id}.") diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 9aa9bdc14..188285684 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,6 +1,6 @@ from typing import Callable, Container, Iterable, Optional, Union -from disnake.ext.commands import ( +from discord.ext.commands import ( BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping, NoPrivateMessage, has_any_role ) @@ -135,7 +135,7 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy if any(role.id in bypass for role in ctx.author.roles): return - # cooldown logic, taken from disnake's internals + # cooldown logic, taken from discord.py internals current = ctx.message.created_at.timestamp() bucket = buckets.get_bucket(ctx.message) retry_after = bucket.update_rate_limit(current) diff --git a/bot/utils/function.py b/bot/utils/function.py index bb6d8afe3..55115d7d3 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -94,7 +94,7 @@ def update_wrapper_globals( """ Update globals of `wrapper` with the globals from `wrapped`. - For forwardrefs in command annotations disnake uses the __global__ attribute of the function + For forwardrefs in command annotations discordpy uses the __global__ attribute of the function to resolve their values, with decorators that replace the function this breaks because they have their own globals. @@ -103,7 +103,7 @@ def update_wrapper_globals( An exception will be raised in case `wrapper` and `wrapped` share a global name that is used by `wrapped`'s typehints and is not in `ignored_conflict_names`, - as this can cause incorrect objects being used by disnake's converters. + as this can cause incorrect objects being used by discordpy's converters. """ annotation_global_names = ( ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) @@ -136,7 +136,7 @@ def command_wraps( *, ignored_conflict_names: t.Set[str] = frozenset(), ) -> t.Callable[[types.FunctionType], types.FunctionType]: - """Update the decorated function to look like `wrapped` and update globals for disnake forwardref evaluation.""" + """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation.""" def decorator(wrapper: types.FunctionType) -> types.FunctionType: return functools.update_wrapper( update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py index 859f53fdb..3501a3933 100644 --- a/bot/utils/helpers.py +++ b/bot/utils/helpers.py @@ -1,7 +1,7 @@ from abc import ABCMeta from typing import Optional -from disnake.ext.commands import CogMeta +from discord.ext.commands import CogMeta class CogABCMeta(CogMeta, ABCMeta): diff --git a/bot/utils/members.py b/bot/utils/members.py index d46baae5b..693286045 100644 --- a/bot/utils/members.py +++ b/bot/utils/members.py @@ -1,13 +1,13 @@ import typing as t -import disnake +import discord from bot.log import get_logger log = get_logger(__name__) -async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optional[disnake.Member]: +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optional[discord.Member]: """ Attempt to get a member from cache; on failure fetch from the API. @@ -18,7 +18,7 @@ async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optiona else: try: member = await guild.fetch_member(member_id) - except disnake.errors.NotFound: + except discord.errors.NotFound: log.trace("Failed to fetch %d from API.", member_id) return None log.trace("%s fetched from API.", member) @@ -26,23 +26,23 @@ async def get_or_fetch_member(guild: disnake.Guild, member_id: int) -> t.Optiona async def handle_role_change( - member: disnake.Member, + member: discord.Member, coro: t.Callable[..., t.Coroutine], - role: disnake.Role + role: discord.Role ) -> None: """ Change `member`'s cooldown role via awaiting `coro` and handle errors. - `coro` is intended to be `disnake.Member.add_roles` or `disnake.Member.remove_roles`. + `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. """ try: await coro(role) - except disnake.NotFound: + except discord.NotFound: log.debug(f"Failed to change role for {member} ({member.id}): member not found") - except disnake.Forbidden: + except discord.Forbidden: log.debug( f"Forbidden to change role for {member} ({member.id}); " f"possibly due to role hierarchy" ) - except disnake.HTTPException as e: + except discord.HTTPException as e: log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index edf2111e9..f68d280c9 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -1,7 +1,7 @@ import typing as t from math import ceil -from disnake import Message +from discord import Message class MessageCache: diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 0bdb00a29..e55c07062 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -5,8 +5,8 @@ from functools import partial from io import BytesIO from typing import Callable, List, Optional, Sequence, Union -import disnake -from disnake.ext.commands import Context +import discord +from discord.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES @@ -17,8 +17,8 @@ log = get_logger(__name__) def reaction_check( - reaction: disnake.Reaction, - user: disnake.abc.User, + reaction: discord.Reaction, + user: discord.abc.User, *, message_id: int, allowed_emoji: Sequence[str], @@ -51,14 +51,14 @@ def reaction_check( log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.") scheduling.create_task( reaction.message.remove_reaction(reaction.emoji, user), - suppressed_exceptions=(disnake.HTTPException,), + suppressed_exceptions=(discord.HTTPException,), name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}" ) return False async def wait_for_deletion( - message: disnake.Message, + message: discord.Message, user_ids: Sequence[int], deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, @@ -82,7 +82,7 @@ async def wait_for_deletion( for emoji in deletion_emojis: try: await message.add_reaction(emoji) - except disnake.NotFound: + except discord.NotFound: log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.") return @@ -101,13 +101,13 @@ async def wait_for_deletion( await message.clear_reactions() else: await message.delete() - except disnake.NotFound: + except discord.NotFound: log.trace(f"wait_for_deletion: message {message.id} deleted prematurely.") async def send_attachments( - message: disnake.Message, - destination: Union[disnake.TextChannel, disnake.Webhook], + message: discord.Message, + destination: Union[discord.TextChannel, discord.Webhook], link_large: bool = True, use_cached: bool = False, **kwargs @@ -140,9 +140,9 @@ async def send_attachments( if attachment.size <= destination.guild.filesize_limit - 512: with BytesIO() as file: await attachment.save(file, use_cached=use_cached) - attachment_file = disnake.File(file, filename=attachment.filename) + attachment_file = discord.File(file, filename=attachment.filename) - if isinstance(destination, disnake.TextChannel): + if isinstance(destination, discord.TextChannel): msg = await destination.send(file=attachment_file, **kwargs) urls.append(msg.attachments[0].url) else: @@ -151,7 +151,7 @@ async def send_attachments( large.append(attachment) else: log.info(f"{failure_msg} because it's too large.") - except disnake.HTTPException as e: + except discord.HTTPException as e: if link_large and e.status == 413: large.append(attachment) else: @@ -159,10 +159,10 @@ async def send_attachments( if link_large and large: desc = "\n".join(f"[{attachment.filename}]({attachment.url})" for attachment in large) - embed = disnake.Embed(description=desc) + embed = discord.Embed(description=desc) embed.set_footer(text="Attachments exceed upload size limit.") - if isinstance(destination, disnake.TextChannel): + if isinstance(destination, discord.TextChannel): await destination.send(embed=embed, **kwargs) else: await destination.send(embed=embed, **webhook_send_kwargs) @@ -171,9 +171,9 @@ async def send_attachments( async def count_unique_users_reaction( - message: disnake.Message, - reaction_predicate: Callable[[disnake.Reaction], bool] = lambda _: True, - user_predicate: Callable[[disnake.User], bool] = lambda _: True, + message: discord.Message, + reaction_predicate: Callable[[discord.Reaction], bool] = lambda _: True, + user_predicate: Callable[[discord.User], bool] = lambda _: True, count_bots: bool = True ) -> int: """ @@ -193,7 +193,7 @@ async def count_unique_users_reaction( return len(unique_users) -async def pin_no_system_message(message: disnake.Message) -> bool: +async def pin_no_system_message(message: discord.Message) -> bool: """Pin the given message, wait a couple of seconds and try to delete the system message.""" await message.pin() @@ -201,7 +201,7 @@ async def pin_no_system_message(message: disnake.Message) -> bool: await asyncio.sleep(2) # Search for the system message in the last 10 messages async for historical_message in message.channel.history(limit=10): - if historical_message.type == disnake.MessageType.pins_add: + if historical_message.type == discord.MessageType.pins_add: await historical_message.delete() return True @@ -225,16 +225,16 @@ def sub_clyde(username: Optional[str]) -> Optional[str]: return username # Empty string or None -async def send_denial(ctx: Context, reason: str) -> disnake.Message: +async def send_denial(ctx: Context, reason: str) -> discord.Message: """Send an embed denying the user with the given reason.""" - embed = disnake.Embed() - embed.colour = disnake.Colour.red() + embed = discord.Embed() + embed.colour = discord.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = reason return await ctx.send(embed=embed) -def format_user(user: disnake.abc.User) -> str: +def format_user(user: discord.abc.User) -> str: """Return a string for `user` which has their mention and ID.""" return f"{user.mention} (`{user.id}`)" diff --git a/bot/utils/webhooks.py b/bot/utils/webhooks.py index 8ef929b79..9c916b63a 100644 --- a/bot/utils/webhooks.py +++ b/bot/utils/webhooks.py @@ -1,7 +1,7 @@ from typing import Optional -import disnake -from disnake import Embed +import discord +from discord import Embed from bot.log import get_logger from bot.utils.messages import sub_clyde @@ -10,13 +10,13 @@ log = get_logger(__name__) async def send_webhook( - webhook: disnake.Webhook, + webhook: discord.Webhook, content: Optional[str] = None, username: Optional[str] = None, avatar_url: Optional[str] = None, embed: Optional[Embed] = None, wait: Optional[bool] = False -) -> disnake.Message: +) -> discord.Message: """ Send a message using the provided webhook. @@ -30,5 +30,5 @@ async def send_webhook( embed=embed, wait=wait, ) - except disnake.HTTPException: + except discord.HTTPException: log.exception("Failed to send a message to the webhook!") diff --git a/poetry.lock b/poetry.lock index 087fd739c..a8ee6ef5c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -143,7 +143,7 @@ lxml = ["lxml"] [[package]] name = "bot-core" -version = "2.1.0" +version = "1.2.0" description = "Bot-Core provides the core functionality and utilities for the bots of the Python Discord community." category = "main" optional = false @@ -154,7 +154,7 @@ python-versions = "3.9.*" [package.source] type = "url" -url = "https://github.com/python-discord/bot-core/archive/refs/tags/v2.1.0.zip" +url = "https://github.com/python-discord/bot-core/archive/511bcba1b0196cd498c707a525ea56921bd971db.zip" [[package]] name = "certifi" version = "2021.10.8" @@ -1074,7 +1074,7 @@ toml = ">=0.10.0,<0.11.0" [[package]] name = "testfixtures" -version = "6.18.5" +version = "6.18.3" description = "A collection of helpers and mock objects for unit tests and doc tests." category = "dev" optional = false @@ -1169,7 +1169,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "3.9.*" -content-hash = "b8b28311c13f7a66f028041bae889131d3916ca7f667c9a7539871d21bbcd077" +content-hash = "538a4809b9fc6fa93ee1baccf4016515ae311a886f1b7ec9b3d544bb87c830a3" [metadata.files] aio-pika = [ @@ -1990,8 +1990,8 @@ taskipy = [ {file = "taskipy-1.7.0.tar.gz", hash = "sha256:960e480b1004971e76454ecd1a0484e640744a30073a1069894a311467f85ed8"}, ] testfixtures = [ - {file = "testfixtures-6.18.5-py2.py3-none-any.whl", hash = "sha256:7de200e24f50a4a5d6da7019fb1197aaf5abd475efb2ec2422fdcf2f2eb98c1d"}, - {file = "testfixtures-6.18.5.tar.gz", hash = "sha256:02dae883f567f5b70fd3ad3c9eefb95912e78ac90be6c7444b5e2f46bf572c84"}, + {file = "testfixtures-6.18.3-py2.py3-none-any.whl", hash = "sha256:6ddb7f56a123e1a9339f130a200359092bd0a6455e31838d6c477e8729bb7763"}, + {file = "testfixtures-6.18.3.tar.gz", hash = "sha256:2600100ae96ffd082334b378e355550fef8b4a529a6fa4c34f47130905c7426d"}, ] tldextract = [ {file = "tldextract-3.1.2-py2.py3-none-any.whl", hash = "sha256:f55e05f6bf4cc952a87d13594386d32ad2dd265630a8bdfc3df03bd60425c6b0"}, diff --git a/pyproject.toml b/pyproject.toml index 1f02818ee..90b38ce66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" python = "3.9.*" disnake = "~=2.4" # See https://bot-core.pythondiscord.com/ for docs. -bot-core = {url = "https://github.com/python-discord/bot-core/archive/refs/tags/v2.1.0.zip"} +bot-core = {url = "https://github.com/python-discord/bot-core/archive/511bcba1b0196cd498c707a525ea56921bd971db.zip"} aio-pika = "~=6.1" aiodns = "~=2.0" aiohttp = "~=3.7" diff --git a/tests/README.md b/tests/README.md index fc03b3d43..b7fddfaa2 100644 --- a/tests/README.md +++ b/tests/README.md @@ -121,9 +121,9 @@ As we are trying to test our "units" of code independently, we want to make sure However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". -To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `disnake` types (see the section on the below.). +To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). -An example of mocking is when we provide a command with a mocked version of `disnake.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): +An example of mocking is when we provide a command with a mocked version of `discord.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): ```py import asyncio @@ -152,15 +152,15 @@ class BotCogTests(unittest.TestCase): By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected. -### Special mocks for some `disnake` types +### Special mocks for some `discord.py` types To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. -In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual disnake types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `disnake` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. -**Note:** These mock types only "know" the attributes that are set by default when these `disnake` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: +**Note:** These mock types only "know" the attributes that are set by default when these `discord.py` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: ```py import unittest.mock @@ -245,7 +245,7 @@ All in all, it's not only important to consider if all statements or branches we ### Unit Testing vs Integration Testing -Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `disnake` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. +Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written. diff --git a/tests/base.py b/tests/base.py index dea7dd678..5e304ea9d 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,8 +3,8 @@ import unittest from contextlib import contextmanager from typing import Dict -import disnake -from disnake.ext import commands +import discord +from discord.ext import commands from bot.log import get_logger from tests import helpers @@ -80,7 +80,7 @@ class LoggingTestsMixin: class CommandTestCase(unittest.IsolatedAsyncioTestCase): - """TestCase with additional assertions that are useful for testing disnake commands.""" + """TestCase with additional assertions that are useful for testing Discord commands.""" async def assertHasPermissionsCheck( # noqa: N802 self, @@ -98,7 +98,7 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): permissions = {k: not v for k, v in permissions.items()} ctx = helpers.MockContext() - ctx.channel.permissions_for.return_value = disnake.Permissions(**permissions) + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) with self.assertRaises(commands.MissingPermissions) as cm: await cmd.can_run(ctx) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4ed7de64d..fdd0ab74a 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import disnake +import discord from bot import constants from bot.api import ResponseCodeError @@ -257,9 +257,9 @@ class SyncCogListenerTests(SyncCogTestCase): self.assertTrue(self.cog.on_member_update.__cog_listener__) subtests = ( - ("activities", disnake.Game("Pong"), disnake.Game("Frogger")), + ("activities", discord.Game("Pong"), discord.Game("Frogger")), ("nick", "old nick", "new nick"), - ("status", disnake.Status.online, disnake.Status.offline), + ("status", discord.Status.online, discord.Status.offline), ) for attribute, old_value, new_value in subtests: diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 9ecb8fae0..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import disnake +import discord from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers @@ -34,8 +34,8 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): for role in roles: mock_role = helpers.MockRole(**role) - mock_role.colour = disnake.Colour(role["colour"]) - mock_role.permissions = disnake.Permissions(role["permissions"]) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) guild.roles.append(mock_role) return guild diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index f55f5360f..2fc97af2d 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from disnake.errors import NotFound +from discord.errors import NotFound from bot.exts.backend.sync._syncers import UserSyncer, _Diff from tests import helpers diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 83b5f2749..35fa0ee59 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch -from disnake.ext.commands import errors +from discord.ext.commands import errors from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index fdff36b61..0856546af 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import AsyncMock, MagicMock, create_autospec, patch -from disnake import CategoryChannel -from disnake.ext.commands import BadArgument +from discord import CategoryChannel +from discord.ext.commands import BadArgument from bot.constants import Roles from bot.exts.events import code_jams diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 0cab405d0..06d78de9d 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, Mock -from disnake import NotFound +from discord import NotFound from bot.constants import Channels, STAFF_ROLES from bot.exts.filters import antimalware diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index 46fa82fd7..c0c3baa42 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from disnake.ext.commands import NoPrivateMessage +from discord.ext.commands import NoPrivateMessage from bot.exts.filters import security from tests.helpers import MockBot, MockContext diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index dd56c10dd..4db27269a 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -3,7 +3,7 @@ from re import Match from unittest import mock from unittest.mock import MagicMock -from disnake import Colour, NotFound +from discord import Colour, NotFound from bot import constants from bot.exts.filters import token_remover diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 9a35de7a9..d896b7652 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -3,7 +3,7 @@ import unittest import unittest.mock from datetime import datetime -import disnake +import discord from bot import constants from bot.exts.info import information @@ -43,7 +43,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): embed = kwargs.pop('embed') self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(embed.colour, discord.Colour.og_blurple()) self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") async def test_role_info_command(self): @@ -51,19 +51,19 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): dummy_role = helpers.MockRole( name="Dummy", id=112233445566778899, - colour=disnake.Colour.og_blurple(), + colour=discord.Colour.og_blurple(), position=10, members=[self.ctx.author], - permissions=disnake.Permissions(0) + permissions=discord.Permissions(0) ) admin_role = helpers.MockRole( name="Admins", id=998877665544332211, - colour=disnake.Colour.red(), + colour=discord.Colour.red(), position=3, members=[self.ctx.author], - permissions=disnake.Permissions(0), + permissions=discord.Permissions(0), ) self.ctx.guild.roles.extend([dummy_role, admin_role]) @@ -81,7 +81,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): admin_embed = admin_kwargs["embed"] self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(dummy_embed.colour, discord.Colour.og_blurple()) self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") @@ -91,7 +91,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(dummy_embed.fields[5].value, "0") self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, disnake.Colour.red()) + self.assertEqual(admin_embed.colour, discord.Colour.red()) class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase): @@ -449,7 +449,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, disnake.Colour(100)) + self.assertEqual(embed.colour, discord.Colour(100)) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", @@ -463,11 +463,11 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """The embed should be created with the og blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() - user = helpers.MockMember(id=217, colour=disnake.Colour.default()) + user = helpers.MockMember(id=217, colour=discord.Colour.default()) user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(embed.colour, discord.Colour.og_blurple()) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b85d086c9..052048053 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -3,7 +3,7 @@ import textwrap import unittest from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch -from disnake.errors import NotFound +from discord.errors import NotFound from bot.constants import Event from bot.exts.moderation.clean import Clean diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index eaa0e701e..ff81ddd65 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,7 +3,7 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch -from disnake import Embed, Forbidden, HTTPException, NotFound +from discord import Embed, Forbidden, HTTPException, NotFound from bot.api import ResponseCodeError from bot.constants import Colours, Icons diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 725455bbe..cfe0c4b03 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -7,7 +7,7 @@ from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp -import disnake +import discord from async_rediscache import RedisSession from bot.constants import Colours @@ -24,7 +24,7 @@ class MockAsyncIterable: Helper for mocking asynchronous for loops. It does not appear that the `unittest` library currently provides anything that would - allow us to simply mock an async iterator, such as `disnake.TextChannel.history`. + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. We therefore write our own helper to wrap a regular synchronous iterable, and feed its values via `__anext__` rather than `__next__`. @@ -60,7 +60,7 @@ class MockSignal(enum.Enum): B = "B" -mock_404 = disnake.NotFound( +mock_404 = discord.NotFound( response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response message="Not found", ) @@ -70,8 +70,8 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): """Collection of tests for the `download_file` helper function.""" async def test_download_file_success(self): - """If `to_file` succeeds, function returns the acquired `disnake.File`.""" - file = MagicMock(disnake.File, filename="bigbadlemon.jpg") + """If `to_file` succeeds, function returns the acquired `discord.File`.""" + file = MagicMock(discord.File, filename="bigbadlemon.jpg") attachment = MockAttachment(to_file=AsyncMock(return_value=file)) acquired_file = await incidents.download_file(attachment) @@ -86,7 +86,7 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): async def test_download_file_fail(self): """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" - arbitrary_error = disnake.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) with self.assertLogs(logger=incidents.log, level=logging.ERROR): @@ -121,7 +121,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" - file = MagicMock(disnake.File, filename="bigbadjoe.jpg") + file = MagicMock(discord.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") incident = MockMessage(content="this is an incident", attachments=[attachment]) @@ -394,7 +394,7 @@ class TestArchive(TestIncidents): author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) - built_embed = MagicMock(disnake.Embed, id=123) # We patch `make_embed` to return this + built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) @@ -616,7 +616,7 @@ class TestResolveMessage(TestIncidents): """ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - arbitrary_error = disnake.HTTPException( + arbitrary_error = discord.HTTPException( response=MagicMock(aiohttp.ClientResponse), message="Arbitrary error", ) @@ -649,7 +649,7 @@ class TestOnRawReactionAdd(TestIncidents): super().setUp() # Ensure `cog_instance` is assigned self.payload = MagicMock( - disnake.RawReactionActionEvent, + discord.RawReactionActionEvent, channel_id=123, # Patched at class level message_id=456, member=MockMember(bot=False), diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index 6c9ebed95..79e04837d 100644 --- a/tests/bot/exts/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -1,6 +1,6 @@ import unittest -import disnake +import discord from bot.exts.moderation.modlog import ModLog from tests.helpers import MockBot, MockTextChannel @@ -19,7 +19,7 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): self.bot.get_channel.return_value = self.channel await self.cog.send_log_message( icon_url="foo", - colour=disnake.Colour.blue(), + colour=discord.Colour.blue(), title="bar", text="foo bar" * 3000 ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 539651d6c..92ce3418a 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -7,7 +7,7 @@ from unittest import mock from unittest.mock import AsyncMock, Mock from async_rediscache import RedisSession -from disnake import PermissionOverwrite +from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence @@ -152,7 +152,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. self.assertTrue(self.cog._init_task.cancelled()) - @autospec("disnake.ext.commands", "has_any_role") + @autospec("discord.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py index 5cb071d58..f8e120262 100644 --- a/tests/bot/exts/test_cogs.py +++ b/tests/bot/exts/test_cogs.py @@ -8,7 +8,7 @@ from collections import defaultdict from types import ModuleType from unittest import mock -from disnake.ext import commands +from discord.ext import commands from bot import exts @@ -34,7 +34,7 @@ class CommandNameTests(unittest.TestCase): raise ImportError(name=name) # pragma: no cover # The mock prevents asyncio.get_event_loop() from being called. - with mock.patch("disnake.ext.tasks.loop"): + with mock.patch("discord.ext.tasks.loop"): prefix = f"{exts.__name__}." for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): if not module.ispkg: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index bec7574fb..8bdeedd27 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -2,8 +2,8 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch -from disnake import AllowedMentions -from disnake.ext import commands +from discord import AllowedMentions +from discord.ext import commands from bot import constants from bot.exts.utils import snekbox diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index afb8a973d..1bb678db2 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -4,7 +4,7 @@ from datetime import MAXYEAR, datetime, timezone from unittest.mock import MagicMock, patch from dateutil.relativedelta import relativedelta -from disnake.ext.commands import BadArgument +from discord.ext.commands import BadArgument from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 5675e10ec..4ae11d5d3 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from disnake import DMChannel +from discord import DMChannel from bot.utils import checks from bot.utils.checks import InWhitelistCheckFailure diff --git a/tests/helpers.py b/tests/helpers.py index bd1418ab9..9d4988d23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,9 +7,9 @@ import unittest.mock from asyncio import AbstractEventLoop from typing import Iterable, Optional -import disnake +import discord from aiohttp import ClientSession -from disnake.ext.commands import Context +from discord.ext.commands import Context from bot.api import APIClient from bot.async_stats import AsyncStatsClient @@ -26,11 +26,11 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -class HashableMixin(disnake.mixins.EqualityComparable): +class HashableMixin(discord.mixins.EqualityComparable): """ - Mixin that provides similar hashing and equality functionality as disnake's `Hashable` mixin. + Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. - Note: disnake`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions + Note: discord.py`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions for the relative small `id` integers we generally use in tests, this bit-shift is omitted. """ @@ -39,22 +39,22 @@ class HashableMixin(disnake.mixins.EqualityComparable): class ColourMixin: - """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like disnake does.""" + """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like discord.py does.""" @property - def color(self) -> disnake.Colour: + def color(self) -> discord.Colour: return self.colour @color.setter - def color(self, color: disnake.Colour) -> None: + def color(self, color: discord.Colour) -> None: self.colour = color @property - def accent_color(self) -> disnake.Colour: + def accent_color(self) -> discord.Colour: return self.accent_colour @accent_color.setter - def accent_color(self, color: disnake.Colour) -> None: + def accent_color(self, color: discord.Colour) -> None: self.accent_colour = color @@ -63,7 +63,7 @@ class CustomMockMixin: Provides common functionality for our custom Mock types. The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock - object. As disnake also uses synchronous methods that nonetheless return coroutine objects, the + object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The class method `spec_set` can be overwritten with the object that should be uses as the specification @@ -119,7 +119,7 @@ class CustomMockMixin: return klass(**kw) -# Create a guild instance to get a realistic Mock of `disnake.Guild` +# Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { 'id': 1, 'name': 'guild', @@ -139,20 +139,20 @@ guild_data = { 'owner_id': 1, 'afk_channel_id': 464033278631084042, } -guild_instance = disnake.Guild(data=guild_data, state=unittest.mock.MagicMock()) +guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock()) class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ - A `Mock` subclass to mock `disnake.Guild` objects. + A `Mock` subclass to mock `discord.Guild` objects. - A MockGuild instance will follow the specifications of a `disnake.Guild` instance. This means + A MockGuild instance will follow the specifications of a `discord.Guild` instance. This means that if the code you're testing tries to access an attribute or method that normally does not - exist for a `disnake.Guild` object this will raise an `AttributeError`. This is to make sure our - tests fail if the code we're testing uses a `disnake.Guild` object in the wrong way. + exist for a `discord.Guild` object this will raise an `AttributeError`. This is to make sure our + tests fail if the code we're testing uses a `discord.Guild` object in the wrong way. One restriction of that is that if the code tries to access an attribute that normally does not - exist for `disnake.Guild` instance but was added dynamically, this will raise an exception with + exist for `discord.Guild` instance but was added dynamically, this will raise an exception with the mocked object. To get around that, you can set the non-standard attribute explicitly for the instance of `MockGuild`: @@ -160,10 +160,10 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): >>> guild.attribute_that_normally_does_not_exist = unittest.mock.MagicMock() In addition to attribute simulation, mocked guild object will pass an `isinstance` check against - `disnake.Guild`: + `discord.Guild`: >>> guild = MockGuild() - >>> isinstance(guild, disnake.Guild) + >>> isinstance(guild, discord.Guild) True For more info, see the `Mocking` section in `tests/README.md`. @@ -179,16 +179,16 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): self.roles.extend(roles) -# Create a Role instance to get a realistic Mock of `disnake.Role` +# Create a Role instance to get a realistic Mock of `discord.Role` role_data = {'name': 'role', 'id': 1} -role_instance = disnake.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) +role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ - A Mock subclass to mock `disnake.Role` objects. + A Mock subclass to mock `discord.Role` objects. - Instances of this class will follow the specifications of `disnake.Role` instances. For more + Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ spec_set = role_instance @@ -198,40 +198,40 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): 'id': next(self.discord_id), 'name': 'role', 'position': 1, - 'colour': disnake.Colour(0xdeadbf), - 'permissions': disnake.Permissions(), + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), } super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if isinstance(self.colour, int): - self.colour = disnake.Colour(self.colour) + self.colour = discord.Colour(self.colour) if isinstance(self.permissions, int): - self.permissions = disnake.Permissions(self.permissions) + self.permissions = discord.Permissions(self.permissions) if 'mention' not in kwargs: self.mention = f'&{self.name}' def __lt__(self, other): - """Simplified position-based comparisons similar to those of `disnake.Role`.""" + """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position < other.position def __ge__(self, other): - """Simplified position-based comparisons similar to those of `disnake.Role`.""" + """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position >= other.position -# Create a Member instance to get a realistic Mock of `disnake.Member` +# Create a Member instance to get a realistic Mock of `discord.Member` member_data = {'user': 'lemon', 'roles': [1]} state_mock = unittest.mock.MagicMock() -member_instance = disnake.Member(data=member_data, guild=guild_instance, state=state_mock) +member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock Member objects. - Instances of this class will follow the specifications of `disnake.Member` instances. For more + Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ spec_set = member_instance @@ -249,11 +249,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" -# Create a User instance to get a realistic Mock of `disnake.User` +# Create a User instance to get a realistic Mock of `discord.User` _user_data_mock = collections.defaultdict(unittest.mock.MagicMock, { "accent_color": 0 }) -user_instance = disnake.User( +user_instance = discord.User( data=unittest.mock.MagicMock(get=unittest.mock.Mock(side_effect=_user_data_mock.get)), state=unittest.mock.MagicMock() ) @@ -263,7 +263,7 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock User objects. - Instances of this class will follow the specifications of `disnake.User` instances. For more + Instances of this class will follow the specifications of `discord.User` instances. For more information, see the `MockGuild` docstring. """ spec_set = user_instance @@ -305,7 +305,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Bot objects. - Instances of this class will follow the specifications of `disnake.ext.commands.Bot` instances. + Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ spec_set = Bot( @@ -324,7 +324,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) -# Create a TextChannel instance to get a realistic MagicMock of `disnake.TextChannel` +# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` channel_data = { 'id': 1, 'type': 'TextChannel', @@ -337,17 +337,17 @@ channel_data = { } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -text_channel_instance = disnake.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) channel_data["type"] = "VoiceChannel" -voice_channel_instance = disnake.VoiceChannel(state=state, guild=guild, data=channel_data) +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `disnake.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = text_channel_instance @@ -364,7 +364,7 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock VoiceChannel objects. - Instances of this class will follow the specifications of `disnake.VoiceChannel` instances. For + Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = voice_channel_instance @@ -381,14 +381,14 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): state = unittest.mock.MagicMock() me = unittest.mock.MagicMock() dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} -dm_channel_instance = disnake.DMChannel(me=me, state=state, data=dm_channel_data) +dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `disnake.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = dm_channel_instance @@ -398,17 +398,17 @@ class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(kwargs, default_kwargs)) -# Create CategoryChannel instance to get a realistic MagicMock of `disnake.CategoryChannel` +# Create CategoryChannel instance to get a realistic MagicMock of `discord.CategoryChannel` category_channel_data = { 'id': 1, - 'type': disnake.ChannelType.category, + 'type': discord.ChannelType.category, 'name': 'category', 'position': 1, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -category_channel_instance = disnake.CategoryChannel( +category_channel_instance = discord.CategoryChannel( state=state, guild=guild, data=category_channel_data ) @@ -419,7 +419,7 @@ class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(default_kwargs, kwargs)) -# Create a Message instance to get a realistic MagicMock of `disnake.Message` +# Create a Message instance to get a realistic MagicMock of `discord.Message` message_data = { 'id': 1, 'webhook_id': 431341013479718912, @@ -438,10 +438,10 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() -message_instance = disnake.Message(state=state, channel=channel, data=message_data) +message_instance = discord.Message(state=state, channel=channel, data=message_data) -# Create a Context instance to get a realistic MagicMock of `disnake.ext.commands.Context` +# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` context_instance = Context( message=unittest.mock.MagicMock(), prefix="$", @@ -455,7 +455,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Context objects. - Instances of this class will follow the specifications of `disnake.ext.commands.Context` + Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ spec_set = context_instance @@ -471,14 +471,14 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) -attachment_instance = disnake.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Attachment objects. - Instances of this class will follow the specifications of `disnake.Attachment` instances. For + Instances of this class will follow the specifications of `discord.Attachment` instances. For more information, see the `MockGuild` docstring. """ spec_set = attachment_instance @@ -488,7 +488,7 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. - Instances of this class will follow the specifications of `disnake.Message` instances. For more + Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ spec_set = message_instance @@ -501,14 +501,14 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} -emoji_instance = disnake.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) +emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Emoji objects. - Instances of this class will follow the specifications of `disnake.Emoji` instances. For more + Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = emoji_instance @@ -518,27 +518,27 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): self.guild = kwargs.get('guild', MockGuild()) -partial_emoji_instance = disnake.PartialEmoji(animated=False, name='guido') +partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock PartialEmoji objects. - Instances of this class will follow the specifications of `disnake.PartialEmoji` instances. For + Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = partial_emoji_instance -reaction_instance = disnake.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) +reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) class MockReaction(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Reaction objects. - Instances of this class will follow the specifications of `disnake.Reaction` instances. For + Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ spec_set = reaction_instance @@ -556,14 +556,14 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.__str__.return_value = str(self.emoji) -webhook_instance = disnake.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. - Instances of this class will follow the specifications of `disnake.Webhook` instances. For + Instances of this class will follow the specifications of `discord.Webhook` instances. For more information, see the `MockGuild` docstring. """ spec_set = webhook_instance diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c5e799a85..81285e009 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,20 +2,20 @@ import asyncio import unittest import unittest.mock -import disnake +import discord from tests import helpers class DiscordMocksTests(unittest.TestCase): - """Tests for our specialized disnake mocks.""" + """Tests for our specialized discord.py mocks.""" def test_mock_role_default_initialization(self): """Test if the default initialization of MockRole results in the correct object.""" role = helpers.MockRole() - # The `spec` argument makes sure `isistance` checks with `disnake.Role` pass - self.assertIsInstance(role, disnake.Role) + # The `spec` argument makes sure `isistance` checks with `discord.Role` pass + self.assertIsInstance(role, discord.Role) self.assertEqual(role.name, "role") self.assertEqual(role.position, 1) @@ -61,8 +61,8 @@ class DiscordMocksTests(unittest.TestCase): """Test if the default initialization of Mockmember results in the correct object.""" member = helpers.MockMember() - # The `spec` argument makes sure `isistance` checks with `disnake.Member` pass - self.assertIsInstance(member, disnake.Member) + # The `spec` argument makes sure `isistance` checks with `discord.Member` pass + self.assertIsInstance(member, discord.Member) self.assertEqual(member.name, "member") self.assertListEqual(member.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) @@ -86,18 +86,18 @@ class DiscordMocksTests(unittest.TestCase): """Test if MockMember accepts and sets abitrary keyword arguments.""" member = helpers.MockMember( nick="Dino Man", - colour=disnake.Colour.default(), + colour=discord.Colour.default(), ) self.assertEqual(member.nick, "Dino Man") - self.assertEqual(member.colour, disnake.Colour.default()) + self.assertEqual(member.colour, discord.Colour.default()) def test_mock_guild_default_initialization(self): """Test if the default initialization of Mockguild results in the correct object.""" guild = helpers.MockGuild() - # The `spec` argument makes sure `isistance` checks with `disnake.Guild` pass - self.assertIsInstance(guild, disnake.Guild) + # The `spec` argument makes sure `isistance` checks with `discord.Guild` pass + self.assertIsInstance(guild, discord.Guild) self.assertListEqual(guild.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) self.assertListEqual(guild.members, []) @@ -127,15 +127,15 @@ class DiscordMocksTests(unittest.TestCase): """Tests if MockBot initializes with the correct values.""" bot = helpers.MockBot() - # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Bot` pass - self.assertIsInstance(bot, disnake.ext.commands.Bot) + # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Bot` pass + self.assertIsInstance(bot, discord.ext.commands.Bot) def test_mock_context_default_initialization(self): """Tests if MockContext initializes with the correct values.""" context = helpers.MockContext() - # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Context` pass - self.assertIsInstance(context, disnake.ext.commands.Context) + # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Context` pass + self.assertIsInstance(context, discord.ext.commands.Context) self.assertIsInstance(context.bot, helpers.MockBot) self.assertIsInstance(context.guild, helpers.MockGuild) -- cgit v1.2.3 From 43b6fee9eba12a6836530029a642cba6e7e505f0 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 19 Mar 2022 16:34:01 +0000 Subject: Use bot-core scheduling and member util functions --- bot/__init__.py | 17 +- bot/async_stats.py | 3 +- bot/converters.py | 2 +- bot/decorators.py | 3 +- bot/exts/backend/logging.py | 2 +- bot/exts/filters/antispam.py | 3 +- bot/exts/filters/filtering.py | 4 +- bot/exts/filters/token_remover.py | 5 +- bot/exts/fun/off_topic_names.py | 2 +- bot/exts/help_channels/_cog.py | 3 +- bot/exts/info/codeblock/_cog.py | 3 +- bot/exts/info/doc/_batch_parser.py | 2 +- bot/exts/info/doc/_cog.py | 4 +- bot/exts/info/subscribe.py | 2 +- bot/exts/moderation/defcon.py | 5 +- bot/exts/moderation/incidents.py | 2 +- bot/exts/moderation/infraction/_scheduler.py | 3 +- bot/exts/moderation/metabase.py | 5 +- bot/exts/moderation/modpings.py | 5 +- bot/exts/moderation/silence.py | 4 +- bot/exts/moderation/stream.py | 3 +- bot/exts/moderation/watchchannels/_watchchannel.py | 3 +- bot/exts/recruitment/talentpool/_cog.py | 3 +- bot/exts/recruitment/talentpool/_review.py | 2 +- bot/exts/utils/reminders.py | 5 +- bot/exts/utils/snekbox.py | 5 +- bot/monkey_patches.py | 76 -------- bot/utils/messages.py | 2 +- bot/utils/scheduling.py | 194 --------------------- tests/bot/exts/backend/sync/test_cog.py | 2 +- tests/bot/exts/filters/test_filtering.py | 2 +- 31 files changed, 53 insertions(+), 323 deletions(-) delete mode 100644 bot/monkey_patches.py delete mode 100644 bot/utils/scheduling.py (limited to 'tests') diff --git a/bot/__init__.py b/bot/__init__.py index 17d99105a..c652897be 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,11 +1,10 @@ import asyncio import os -from functools import partial, partialmethod from typing import TYPE_CHECKING -from discord.ext import commands +from botcore.utils import apply_monkey_patches -from bot import log, monkey_patches +from bot import log if TYPE_CHECKING: from bot.bot import Bot @@ -16,16 +15,6 @@ log.setup() if os.name == "nt": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -monkey_patches.patch_typing() - -# This patches any convertors that use PartialMessage, but not the PartialMessageConverter itself -# as library objects are made by this mapping. -# https://github.com/Rapptz/discord.py/blob/1a4e73d59932cdbe7bf2c281f25e32529fc7ae1f/discord/ext/commands/converter.py#L984-L1004 -commands.converter.PartialMessageConverter = monkey_patches.FixedPartialMessageConverter - -# Monkey-patch discord.py decorators to use the Command subclass which supports root aliases. -# Must be patched before any cogs are added. -commands.command = partial(commands.command, cls=monkey_patches.Command) -commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=monkey_patches.Command) +apply_monkey_patches() instance: "Bot" = None # Global Bot instance. diff --git a/bot/async_stats.py b/bot/async_stats.py index 2af832e5b..0303de7a1 100644 --- a/bot/async_stats.py +++ b/bot/async_stats.py @@ -1,10 +1,9 @@ import asyncio import socket +from botcore.utils import scheduling from statsd.client.base import StatsClientBase -from bot.utils import scheduling - class AsyncStatsClient(StatsClientBase): """An async transport method for statsd communication.""" diff --git a/bot/converters.py b/bot/converters.py index 3522a32aa..e819e4713 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -8,7 +8,7 @@ from ssl import CertificateError import dateutil.parser import discord from aiohttp import ClientConnectorError -from botcore.regex import DISCORD_INVITE +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter from discord.utils import escape_markdown, snowflake_time diff --git a/bot/decorators.py b/bot/decorators.py index 8971898b3..466770c3a 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -5,13 +5,14 @@ import typing as t from contextlib import suppress import arrow +from botcore.utils import scheduling from discord import Member, NotFound from discord.ext import commands from discord.ext.commands import Cog, Context from bot.constants import Channels, DEBUG_MODE, RedirectOutput from bot.log import get_logger -from bot.utils import function, scheduling +from bot.utils import function from bot.utils.checks import ContextCheckFailure, in_whitelist_check from bot.utils.function import command_wraps diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py index 2d03cd580..469331ae5 100644 --- a/bot/exts/backend/logging.py +++ b/bot/exts/backend/logging.py @@ -1,10 +1,10 @@ +from botcore.utils import scheduling from discord import Embed from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE from bot.log import get_logger -from bot.utils import scheduling log = get_logger(__name__) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index bcd845a43..d9e23b25e 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -8,6 +8,7 @@ from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set import arrow +from botcore.utils import scheduling from discord import Colour, Member, Message, NotFound, Object, TextChannel from discord.ext.commands import Cog @@ -20,7 +21,7 @@ from bot.converters import Duration from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import lock, scheduling +from bot.utils import lock from bot.utils.message_cache import MessageCache from bot.utils.messages import format_user, send_attachments diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index b9f2a0e51..32efcc307 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -9,7 +9,8 @@ import dateutil.parser import regex import tldextract from async_rediscache import RedisCache -from botcore.regex import DISCORD_INVITE +from botcore.utils import scheduling +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member, Message, NotFound, TextChannel from discord.ext.commands import Cog @@ -21,7 +22,6 @@ from bot.constants import Channels, Colours, Filter, Guild, Icons, URLs from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import scheduling from bot.utils.messages import format_user log = get_logger(__name__) diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py index 520283ba3..436e6dc19 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filters/token_remover.py @@ -1,5 +1,4 @@ import base64 -import binascii import re import typing as t @@ -182,7 +181,7 @@ class TokenRemover(Cog): # that means it's not a valid user id. return None return int(string) - except (binascii.Error, ValueError): + except ValueError: return None @staticmethod @@ -198,7 +197,7 @@ class TokenRemover(Cog): try: decoded_bytes = base64.urlsafe_b64decode(b64_content) timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: + except ValueError as e: log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") return False diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index 7df1d172d..33f43f2a8 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,6 +2,7 @@ import difflib from datetime import timedelta import arrow +from botcore.utils import scheduling from discord import Colour, Embed from discord.ext.commands import Cog, Context, group, has_any_role from discord.utils import sleep_until @@ -12,7 +13,6 @@ from bot.constants import Channels, MODERATION_ROLES from bot.converters import OffTopicName from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) log = get_logger(__name__) diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index a93acffb6..fc80c968c 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -7,6 +7,7 @@ from operator import attrgetter import arrow import discord import discord.abc +from botcore.utils import members, scheduling from discord.ext import commands from bot import constants @@ -14,7 +15,7 @@ from bot.bot import Bot from bot.constants import Channels, RedirectOutput from bot.exts.help_channels import _caches, _channel, _message, _name, _stats from bot.log import get_logger -from bot.utils import channel as channel_utils, lock, members, scheduling +from bot.utils import channel as channel_utils, lock log = get_logger(__name__) diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index a859d8cef..9027105d9 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -2,6 +2,7 @@ import time from typing import Optional import discord +from botcore.utils import scheduling from discord import Message, RawMessageUpdateEvent from discord.ext.commands import Cog @@ -11,7 +12,7 @@ from bot.exts.filters.token_remover import TokenRemover from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE from bot.exts.info.codeblock._instructions import get_instructions from bot.log import get_logger -from bot.utils import has_lines, scheduling +from bot.utils import has_lines from bot.utils.channel import is_help_channel from bot.utils.messages import wait_for_deletion diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py index c27f28eac..41a15fb6e 100644 --- a/bot/exts/info/doc/_batch_parser.py +++ b/bot/exts/info/doc/_batch_parser.py @@ -8,12 +8,12 @@ from operator import attrgetter from typing import Deque, Dict, List, NamedTuple, Optional, Union import discord +from botcore.utils import scheduling from bs4 import BeautifulSoup import bot from bot.constants import Channels from bot.log import get_logger -from bot.utils import scheduling from . import _cog, doc_cache from ._parsing import get_symbol_markdown diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 4dc5276d9..3789fdbe3 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -10,6 +10,8 @@ from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp import discord +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord.ext import commands from bot.api import ResponseCodeError @@ -18,10 +20,8 @@ from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import Inventory, PackageName, ValidURL, allowed_strings from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling from bot.utils.lock import SharedEvent, lock from bot.utils.messages import send_denial, wait_for_deletion -from bot.utils.scheduling import Scheduler from . import NAMESPACE, PRIORITY_PACKAGES, _batch_parser, doc_cache from ._inventory_parser import InvalidHeaderError, InventoryDict, fetch_inventory diff --git a/bot/exts/info/subscribe.py b/bot/exts/info/subscribe.py index eff0c13b8..ed134ff78 100644 --- a/bot/exts/info/subscribe.py +++ b/bot/exts/info/subscribe.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import arrow import discord +from botcore.utils import members, scheduling from discord.ext import commands from discord.interactions import Interaction @@ -12,7 +13,6 @@ from bot import constants from bot.bot import Bot from bot.decorators import redirect_output from bot.log import get_logger -from bot.utils import members, scheduling @dataclass(frozen=True) diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 178be734d..a8640cb1b 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -7,6 +7,8 @@ from typing import Optional, Union import arrow from aioredis import RedisError from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Forbidden, Member, TextChannel, User from discord.ext import tasks @@ -17,9 +19,8 @@ from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_RO from bot.converters import DurationDelta, Expiry from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.messages import format_user -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index b579416a6..d34c1c7fa 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -6,12 +6,12 @@ from typing import Optional import discord from async_rediscache import RedisCache +from botcore.utils import scheduling from discord.ext.commands import Cog, Context, MessageConverter, MessageNotFound from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks from bot.log import get_logger -from bot.utils import scheduling from bot.utils.messages import format_user, sub_clyde log = get_logger(__name__) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 2fc54856f..9f5800e2a 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -6,6 +6,7 @@ from gettext import ngettext import arrow import dateutil.parser import discord +from botcore.utils import scheduling from discord.ext.commands import Context from bot import constants @@ -16,7 +17,7 @@ from bot.converters import MemberOrUser from bot.exts.moderation.infraction import _utils from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import messages, scheduling, time +from bot.utils import messages, time from bot.utils.channel import is_mod_channel log = get_logger(__name__) diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index ce9c220b3..d68726faf 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -8,15 +8,16 @@ import arrow from aiohttp.client_exceptions import ClientResponseError from arrow import Arrow from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Metabase as MetabaseConfig, Roles from bot.converters import allowed_strings from bot.log import get_logger -from bot.utils import scheduling, send_to_paste_service +from bot.utils import send_to_paste_service from bot.utils.channel import is_mod_channel -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index b5cd29b12..cb1e4fd05 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -3,6 +3,8 @@ import datetime import arrow from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse, parse as dateutil_parse from discord import Embed, Member from discord.ext.commands import Cog, Context, group, has_any_role @@ -11,8 +13,7 @@ from bot.bot import Bot from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles from bot.converters import Expiry from bot.log import get_logger -from bot.utils import scheduling, time -from bot.utils.scheduling import Scheduler +from bot.utils import time log = get_logger(__name__) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 511520252..307729181 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -5,6 +5,8 @@ from datetime import datetime, timedelta, timezone from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel from discord.ext import commands, tasks from discord.ext.commands import Context @@ -14,9 +16,7 @@ from bot import constants from bot.bot import Bot from bot.converters import HushDurationConverter from bot.log import get_logger -from bot.utils import scheduling from bot.utils.lock import LockedResourceError, lock, lock_arg -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 985cc6eb1..17d24eb89 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -5,6 +5,7 @@ import arrow import discord from arrow import Arrow from async_rediscache import RedisCache +from botcore.utils import scheduling from discord.ext import commands from bot.bot import Bot @@ -14,7 +15,7 @@ from bot.constants import ( from bot.converters import Expiry from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.members import get_or_fetch_member log = get_logger(__name__) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index ee9b6ba45..bae7ecd02 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional import discord +from botcore.utils import scheduling from discord import Color, DMChannel, Embed, HTTPException, Message, errors from discord.ext.commands import Cog, Context @@ -18,7 +19,7 @@ from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE from bot.exts.moderation.modlog import ModLog from bot.log import CustomLogger, get_logger from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages, scheduling, time +from bot.utils import CogABCMeta, messages, time from bot.utils.members import get_or_fetch_member log = get_logger(__name__) diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 0554bf37a..0d51af2ca 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -5,6 +5,7 @@ from typing import Optional, Union import discord from async_rediscache import RedisCache +from botcore.utils import scheduling from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role @@ -15,7 +16,7 @@ from bot.converters import MemberOrUser, UnambiguousMemberOrUser from bot.exts.recruitment.talentpool._review import Reviewer from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.members import get_or_fetch_member AUTOREVIEW_ENABLED_KEY = "autoreview_enabled" diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index b4d177622..214d85851 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from typing import List, Optional, Union import arrow +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel from discord.ext.commands import Context @@ -20,7 +21,6 @@ from bot.log import get_logger from bot.utils import time from bot.utils.members import get_or_fetch_member from bot.utils.messages import count_unique_users_reaction, pin_no_system_message -from bot.utils.scheduling import Scheduler if typing.TYPE_CHECKING: from bot.exts.recruitment.talentpool._cog import TalentPool diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index ad82d49c9..62603697c 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone from operator import itemgetter import discord +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord.ext.commands import Cog, Context, Greedy, group @@ -13,12 +15,11 @@ from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Role from bot.converters import Duration, UnambiguousUser from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.checks import has_any_role_check, has_no_roles_check from bot.utils.lock import lock_arg from bot.utils.members import get_or_fetch_member from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 3c1009d2a..2b073ed72 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -7,7 +7,8 @@ from signal import Signals from textwrap import dedent from typing import Optional, Tuple -from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX +from botcore.utils import scheduling +from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only @@ -15,7 +16,7 @@ from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs from bot.decorators import redirect_output from bot.log import get_logger -from bot.utils import scheduling, send_to_paste_service +from bot.utils import send_to_paste_service from bot.utils.messages import wait_for_deletion log = get_logger(__name__) diff --git a/bot/monkey_patches.py b/bot/monkey_patches.py deleted file mode 100644 index 4840fa454..000000000 --- a/bot/monkey_patches.py +++ /dev/null @@ -1,76 +0,0 @@ -import re -from datetime import timedelta - -import arrow -from discord import Forbidden, http -from discord.ext import commands - -from bot.log import get_logger - -log = get_logger(__name__) -MESSAGE_ID_RE = re.compile(r'(?P[0-9]{15,20})$') - - -class Command(commands.Command): - """ - A `discord.ext.commands.Command` subclass which supports root aliases. - - A `root_aliases` keyword argument is added, which is a sequence of alias names that will act as - top-level commands rather than being aliases of the command's group. It's stored as an attribute - also named `root_aliases`. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.root_aliases = kwargs.get("root_aliases", []) - - if not isinstance(self.root_aliases, (list, tuple)): - raise TypeError("Root aliases of a command must be a list or a tuple of strings.") - - -def patch_typing() -> None: - """ - Sometimes discord turns off typing events by throwing 403's. - - Handle those issues by patching the trigger_typing method so it ignores 403's in general. - """ - log.debug("Patching send_typing, which should fix things breaking when discord disables typing events. Stay safe!") - - original = http.HTTPClient.send_typing - last_403 = None - - async def honeybadger_type(self, channel_id: int) -> None: # noqa: ANN001 - nonlocal last_403 - if last_403 and (arrow.utcnow() - last_403) < timedelta(minutes=5): - log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") - return - try: - await original(self, channel_id) - except Forbidden: - last_403 = arrow.utcnow() - log.warning("Got a 403 from typing event!") - pass - - http.HTTPClient.send_typing = honeybadger_type - - -class FixedPartialMessageConverter(commands.PartialMessageConverter): - """ - Make the Message converter infer channelID from the given context if only a messageID is given. - - Discord.py's Message converter is supposed to infer channelID based - on ctx.channel if only a messageID is given. A refactor commit, linked below, - a few weeks before d.py's archival broke this defined behaviour of the converter. - Currently, if only a messageID is given to the converter, it will only find that message - if it's in the bot's cache. - - https://github.com/Rapptz/discord.py/commit/1a4e73d59932cdbe7bf2c281f25e32529fc7ae1f - """ - - @staticmethod - def _get_id_matches(ctx: commands.Context, argument: str) -> tuple[int, int, int]: - """Inserts ctx.channel.id before calling super method if argument is just a messageID.""" - match = MESSAGE_ID_RE.match(argument) - if match: - argument = f"{ctx.channel.id}-{match.group('message_id')}" - return commands.PartialMessageConverter._get_id_matches(ctx, argument) diff --git a/bot/utils/messages.py b/bot/utils/messages.py index e55c07062..a5ed84351 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -6,12 +6,12 @@ from io import BytesIO from typing import Callable, List, Optional, Sequence, Union import discord +from botcore.utils import scheduling from discord.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES from bot.log import get_logger -from bot.utils import scheduling log = get_logger(__name__) diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py deleted file mode 100644 index 23acacf74..000000000 --- a/bot/utils/scheduling.py +++ /dev/null @@ -1,194 +0,0 @@ -import asyncio -import contextlib -import inspect -import typing as t -from datetime import datetime -from functools import partial - -from arrow import Arrow - -from bot.log import get_logger - - -class Scheduler: - """ - Schedule the execution of coroutines and keep track of them. - - When instantiating a Scheduler, a name must be provided. This name is used to distinguish the - instance's log messages from other instances. Using the name of the class or module containing - the instance is suggested. - - Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` - or `schedule_later`. A unique ID is required to be given in order to keep track of the - resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing - the same ID used to schedule it. The `in` operator is supported for checking if a task with a - given ID is currently scheduled. - - Any exception raised in a scheduled task is logged when the task is done. - """ - - def __init__(self, name: str): - self.name = name - - self._log = get_logger(f"{__name__}.{name}") - self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} - - def __contains__(self, task_id: t.Hashable) -> bool: - """Return True if a task with the given `task_id` is currently scheduled.""" - return task_id in self._scheduled_tasks - - def schedule(self, task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule the execution of a `coroutine`. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - self._log.trace(f"Scheduling task #{task_id}...") - - msg = f"Cannot schedule an already started coroutine for #{task_id}" - assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg - - if task_id in self._scheduled_tasks: - self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") - coroutine.close() - return - - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") - task.add_done_callback(partial(self._task_done_callback, task_id)) - - self._scheduled_tasks[task_id] = task - self._log.debug(f"Scheduled task #{task_id} {id(task)}.") - - def schedule_at(self, time: t.Union[datetime, Arrow], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule `coroutine` to be executed at the given `time`. - - If `time` is timezone aware, then use that timezone to calculate now() when subtracting. - If `time` is naïve, then use UTC. - - If `time` is in the past, schedule `coroutine` immediately. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() - delay = (time - now_datetime).total_seconds() - if delay > 0: - coroutine = self._await_later(delay, task_id, coroutine) - - self.schedule(task_id, coroutine) - - def schedule_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule `coroutine` to be executed after the given `delay` number of seconds. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - - def cancel(self, task_id: t.Hashable) -> None: - """Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.""" - self._log.trace(f"Cancelling task #{task_id}...") - - try: - task = self._scheduled_tasks.pop(task_id) - except KeyError: - self._log.warning(f"Failed to unschedule {task_id} (no task found).") - else: - task.cancel() - - self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") - - def cancel_all(self) -> None: - """Unschedule all known tasks.""" - self._log.debug("Unscheduling all tasks") - - for task_id in self._scheduled_tasks.copy(): - self.cancel(task_id) - - async def _await_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """Await `coroutine` after the given `delay` number of seconds.""" - try: - self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") - await asyncio.sleep(delay) - - # Use asyncio.shield to prevent the coroutine from cancelling itself. - self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") - await asyncio.shield(coroutine) - finally: - # Close it to prevent unawaited coroutine warnings, - # which would happen if the task was cancelled during the sleep. - # Only close it if it's not been awaited yet. This check is important because the - # coroutine may cancel this task, which would also trigger the finally block. - state = inspect.getcoroutinestate(coroutine) - if state == "CORO_CREATED": - self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") - coroutine.close() - else: - self._log.debug(f"Finally block reached for #{task_id}; {state=}") - - def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None: - """ - Delete the task and raise its exception if one exists. - - If `done_task` and the task associated with `task_id` are different, then the latter - will not be deleted. In this case, a new task was likely rescheduled with the same ID. - """ - self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") - - scheduled_task = self._scheduled_tasks.get(task_id) - - if scheduled_task and done_task is scheduled_task: - # A task for the ID exists and is the same as the done task. - # Since this is the done callback, the task is already done so no need to cancel it. - self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") - del self._scheduled_tasks[task_id] - elif scheduled_task: - # A new task was likely rescheduled with the same ID. - self._log.debug( - f"The scheduled task #{task_id} {id(scheduled_task)} " - f"and the done task {id(done_task)} differ." - ) - elif not done_task.cancelled(): - self._log.warning( - f"Task #{task_id} not found while handling task {id(done_task)}! " - f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." - ) - - with contextlib.suppress(asyncio.CancelledError): - exception = done_task.exception() - # Log the exception if one exists. - if exception: - self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) - - -def create_task( - coro: t.Awaitable, - *, - suppressed_exceptions: tuple[t.Type[Exception]] = (), - event_loop: t.Optional[asyncio.AbstractEventLoop] = None, - **kwargs, -) -> asyncio.Task: - """ - Wrapper for creating asyncio `Task`s which logs exceptions raised in the task. - - If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used. - """ - if event_loop is not None: - task = event_loop.create_task(coro, **kwargs) - else: - task = asyncio.create_task(coro, **kwargs) - task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) - return task - - -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None: - """Retrieve and log the exception raised in `task` if one exists.""" - with contextlib.suppress(asyncio.CancelledError): - exception = task.exception() - # Log the exception if one exists. - if exception and not isinstance(exception, suppressed_exceptions): - log = get_logger(__name__) - log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index fdd0ab74a..7dff38f96 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("bot.utils.scheduling.create_task") + @mock.patch("botcore.utils.scheduling.create_task") @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild, create_task): """Should instantiate syncers and run a sync for the guild.""" diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py index 8ae59c1f1..bd26532f1 100644 --- a/tests/bot/exts/filters/test_filtering.py +++ b/tests/bot/exts/filters/test_filtering.py @@ -11,7 +11,7 @@ class FilteringCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Instantiate the bot and cog.""" self.bot = MockBot() - with patch("bot.utils.scheduling.create_task", new=lambda task, **_: task.close()): + with patch("botcore.utils.scheduling.create_task", new=lambda task, **_: task.close()): self.cog = filtering.Filtering(self.bot) @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) -- cgit v1.2.3 From 047705ac91c2997ccb509ea4e1fb3fad38840412 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Thu, 31 Mar 2022 20:46:45 +0100 Subject: Remove async stats and site api wrapper We now source them from bot-core, so no need to have them here too. --- bot/api.py | 102 --------------------- bot/async_stats.py | 40 -------- bot/converters.py | 2 +- bot/exts/backend/error_handler.py | 2 +- bot/exts/backend/sync/_cog.py | 2 +- bot/exts/backend/sync/_syncers.py | 2 +- bot/exts/filters/filter_lists.py | 2 +- bot/exts/filters/filtering.py | 2 +- bot/exts/fun/off_topic_names.py | 2 +- bot/exts/info/doc/_cog.py | 2 +- bot/exts/info/information.py | 2 +- bot/exts/moderation/infraction/_scheduler.py | 2 +- bot/exts/moderation/infraction/_utils.py | 2 +- bot/exts/moderation/voice_gate.py | 2 +- bot/exts/moderation/watchchannels/_watchchannel.py | 2 +- bot/exts/recruitment/talentpool/_cog.py | 2 +- bot/exts/recruitment/talentpool/_review.py | 2 +- tests/bot/exts/backend/sync/test_base.py | 3 +- tests/bot/exts/backend/sync/test_cog.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 2 +- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- tests/bot/test_api.py | 66 ------------- tests/helpers.py | 4 +- 23 files changed, 22 insertions(+), 229 deletions(-) delete mode 100644 bot/api.py delete mode 100644 bot/async_stats.py delete mode 100644 tests/bot/test_api.py (limited to 'tests') diff --git a/bot/api.py b/bot/api.py deleted file mode 100644 index 856f7c865..000000000 --- a/bot/api.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -from typing import Optional -from urllib.parse import quote as quote_url - -import aiohttp - -from bot.log import get_logger - -from .constants import Keys, URLs - -log = get_logger(__name__) - - -class ResponseCodeError(ValueError): - """Raised when a non-OK HTTP response is received.""" - - def __init__( - self, - response: aiohttp.ClientResponse, - response_json: Optional[dict] = None, - response_text: str = "" - ): - self.status = response.status - self.response_json = response_json or {} - self.response_text = response_text - self.response = response - - def __str__(self): - response = self.response_json if self.response_json else self.response_text - return f"Status: {self.status} Response: {response}" - - -class APIClient: - """Django Site API wrapper.""" - - # These are class attributes so they can be seen when being mocked for tests. - # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details. - session: Optional[aiohttp.ClientSession] = None - loop: asyncio.AbstractEventLoop = None - - def __init__(self, **session_kwargs): - auth_headers = { - 'Authorization': f"Token {Keys.site_api}" - } - - if 'headers' in session_kwargs: - session_kwargs['headers'].update(auth_headers) - else: - session_kwargs['headers'] = auth_headers - - # aiohttp will complain if APIClient gets instantiated outside a coroutine. Thankfully, we - # don't and shouldn't need to do that, so we can avoid scheduling a task to create it. - self.session = aiohttp.ClientSession(**session_kwargs) - - @staticmethod - def _url_for(endpoint: str) -> str: - return f"{URLs.site_api_schema}{URLs.site_api}/{quote_url(endpoint)}" - - async def close(self) -> None: - """Close the aiohttp session.""" - await self.session.close() - - async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: - """Raise ResponseCodeError for non-OK response if an exception should be raised.""" - if should_raise and response.status >= 400: - try: - response_json = await response.json() - raise ResponseCodeError(response=response, response_json=response_json) - except aiohttp.ContentTypeError: - response_text = await response.text() - raise ResponseCodeError(response=response, response_text=response_text) - - async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Send an HTTP request to the site API and return the JSON response.""" - async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API GET.""" - return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API PATCH.""" - return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API POST.""" - return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Site API PUT.""" - return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]: - """Site API DELETE.""" - async with self.session.delete(self._url_for(endpoint), **kwargs) as resp: - if resp.status == 204: - return None - - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() diff --git a/bot/async_stats.py b/bot/async_stats.py deleted file mode 100644 index 0303de7a1..000000000 --- a/bot/async_stats.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import socket - -from botcore.utils import scheduling -from statsd.client.base import StatsClientBase - - -class AsyncStatsClient(StatsClientBase): - """An async transport method for statsd communication.""" - - def __init__( - self, - loop: asyncio.AbstractEventLoop, - host: str = 'localhost', - port: int = 8125, - prefix: str = None - ): - """Create a new client.""" - family, _, _, _, addr = socket.getaddrinfo( - host, port, socket.AF_INET, socket.SOCK_DGRAM)[0] - self._addr = addr - self._prefix = prefix - self._loop = loop - self._transport = None - - async def create_socket(self) -> None: - """Use the loop.create_datagram_endpoint method to create a socket.""" - self._transport, _ = await self._loop.create_datagram_endpoint( - asyncio.DatagramProtocol, - family=socket.AF_INET, - remote_addr=self._addr - ) - - def _send(self, data: str) -> None: - """Start an async task to send data to statsd.""" - scheduling.create_task(self._async_send(data), event_loop=self._loop) - - async def _async_send(self, data: str) -> None: - """Send data to the statsd server using the async transport.""" - self._transport.sendto(data.encode('ascii'), self._addr) diff --git a/bot/converters.py b/bot/converters.py index e819e4713..a3f4630a0 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -8,13 +8,13 @@ from ssl import CertificateError import dateutil.parser import discord from aiohttp import ClientConnectorError +from botcore.site_api import ResponseCodeError from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter from discord.utils import escape_markdown, snowflake_time from bot import exts -from bot.api import ResponseCodeError from bot.constants import URLs from bot.errors import InvalidInfraction from bot.exts.info.doc import _inventory_parser diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index fabb2dbb5..5391a7f15 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,10 +1,10 @@ import difflib +from botcore.site_api import ResponseCodeError from discord import Embed from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from sentry_sdk import push_scope -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Colours, Icons, MODERATION_ROLES from bot.errors import InvalidInfractedUserError, LockedResourceError diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index 58aabc141..a5bf82397 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,11 +1,11 @@ from typing import Any, Dict +from botcore.site_api import ResponseCodeError from discord import Member, Role, User from discord.ext import commands from discord.ext.commands import Cog, Context from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.exts.backend.sync import _syncers from bot.log import get_logger diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 45301b098..e1c4541ef 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,12 +2,12 @@ import abc import typing as t from collections import namedtuple +from botcore.site_api import ResponseCodeError from discord import Guild from discord.ext.commands import Context from more_itertools import chunked import bot -from bot.api import ResponseCodeError from bot.log import get_logger from bot.utils.members import get_or_fetch_member diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py index 3e3f5c562..fc9cfbeca 100644 --- a/bot/exts/filters/filter_lists.py +++ b/bot/exts/filters/filter_lists.py @@ -1,11 +1,11 @@ import re from typing import Optional +from botcore.site_api import ResponseCodeError from discord import Colour, Embed from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels from bot.converters import ValidDiscordServerInvite, ValidFilterListType diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index cabb7f0b6..6982f5948 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -9,6 +9,7 @@ import dateutil.parser import regex import tldextract from async_rediscache import RedisCache +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta @@ -16,7 +17,6 @@ from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member from discord.ext.commands import Cog from discord.utils import escape_markdown -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Colours, Filter, Guild, Icons, URLs from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index ac172f2a8..d8111bdf5 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,12 +2,12 @@ import difflib from datetime import timedelta import arrow +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord import Colour, Embed from discord.ext.commands import Cog, Context, group, has_any_role from discord.utils import sleep_until -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES from bot.converters import OffTopicName diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 8c3038c5b..bbdc4e82a 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -10,11 +10,11 @@ from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp import discord +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from botcore.utils.scheduling import Scheduler from discord.ext import commands -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import Inventory, PackageName, ValidURL, allowed_strings diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index b56fd171a..e7d17c971 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,12 +6,12 @@ from textwrap import shorten from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union import rapidfuzz +from botcore.site_api import ResponseCodeError from discord import AllowedMentions, Colour, Embed, Guild, Message, Role from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role from discord.utils import escape_markdown from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.converters import MemberOrUser from bot.decorators import in_whitelist diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 137358ec3..9c73bde5f 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -6,11 +6,11 @@ from gettext import ngettext import arrow import dateutil.parser import discord +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord.ext.commands import Context from bot import constants -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Colours from bot.converters import MemberOrUser diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index c1be18362..3a2485ec2 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -3,10 +3,10 @@ from datetime import datetime import arrow import discord +from botcore.site_api import ResponseCodeError from discord.ext.commands import Context import bot -from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.converters import MemberOrUser from bot.errors import InvalidInfractedUserError diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index 33096e7e0..9b1621c01 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -5,10 +5,10 @@ from datetime import timedelta import arrow import discord from async_rediscache import RedisCache +from botcore.site_api import ResponseCodeError from discord import Colour, Member, VoiceState from discord.ext.commands import Cog, Context, command -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Roles, VoiceGate as GateConf from bot.decorators import has_no_roles, in_whitelist diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index ab5ce62f9..bc78b3934 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -7,11 +7,11 @@ from dataclasses import dataclass from typing import Any, Dict, Optional import discord +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord import Color, DMChannel, Embed, HTTPException, Message, errors from discord.ext.commands import Cog, Context -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.exts.filters.token_remover import TokenRemover diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 8aa124536..24496af54 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -5,11 +5,11 @@ from typing import Optional, Union import discord from async_rediscache import RedisCache +from botcore.site_api import ResponseCodeError from botcore.utils import scheduling from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles, STAFF_ROLES from bot.converters import MemberOrUser, UnambiguousMemberOrUser diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index d0edf5388..be181d005 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -9,12 +9,12 @@ from datetime import datetime, timedelta from typing import List, Optional, Union import arrow +from botcore.site_api import ResponseCodeError from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord import Embed, Emoji, Member, Message, NotFound, PartialMessage, TextChannel from discord.ext.commands import Context -from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles from bot.log import get_logger diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 9dc46005b..a17c1fa10 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -1,7 +1,8 @@ import unittest from unittest import mock -from bot.api import ResponseCodeError +from botcore.site_api import ResponseCodeError + from bot.exts.backend.sync._syncers import Syncer from tests import helpers diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 7dff38f96..4ec36e39f 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -2,9 +2,9 @@ import unittest from unittest import mock import discord +from botcore.site_api import ResponseCodeError from bot import constants -from bot.api import ResponseCodeError from bot.exts.backend import sync from bot.exts.backend.sync._cog import Sync from bot.exts.backend.sync._syncers import Syncer diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 35fa0ee59..04a018289 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,9 +1,9 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch +from botcore.site_api import ResponseCodeError from discord.ext.commands import errors -from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError from bot.exts.backend.error_handler import ErrorHandler, setup from bot.exts.info.tags import Tags diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index ff81ddd65..5cf02033d 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,9 +3,9 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch +from botcore.site_api import ResponseCodeError from discord import Embed, Forbidden, HTTPException, NotFound -from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.exts.moderation.infraction import _utils as utils from tests.helpers import MockBot, MockContext, MockMember, MockUser diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py deleted file mode 100644 index 76bcb481d..000000000 --- a/tests/bot/test_api.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from bot import api - - -class APIClientTests(unittest.IsolatedAsyncioTestCase): - """Tests for the bot's API client.""" - - @classmethod - def setUpClass(cls): - """Sets up the shared fixtures for the tests.""" - cls.error_api_response = MagicMock() - cls.error_api_response.status = 999 - - def test_response_code_error_default_initialization(self): - """Test the default initialization of `ResponseCodeError` without `text` or `json`""" - error = api.ResponseCodeError(response=self.error_api_response) - - self.assertIs(error.status, self.error_api_response.status) - self.assertEqual(error.response_json, {}) - self.assertEqual(error.response_text, "") - self.assertIs(error.response, self.error_api_response) - - def test_response_code_error_string_representation_default_initialization(self): - """Test the string representation of `ResponseCodeError` initialized without text or json.""" - error = api.ResponseCodeError(response=self.error_api_response) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") - - def test_response_code_error_initialization_with_json(self): - """Test the initialization of `ResponseCodeError` with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data, - ) - self.assertEqual(error.response_json, json_data) - self.assertEqual(error.response_text, "") - - def test_response_code_error_string_representation_with_nonempty_response_json(self): - """Test the string representation of `ResponseCodeError` initialized with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") - - def test_response_code_error_initialization_with_text(self): - """Test the initialization of `ResponseCodeError` with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data, - ) - self.assertEqual(error.response_text, text_data) - self.assertEqual(error.response_json, {}) - - def test_response_code_error_string_representation_with_nonempty_response_text(self): - """Test the string representation of `ResponseCodeError` initialized with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") diff --git a/tests/helpers.py b/tests/helpers.py index 9d4988d23..3e6290e58 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,10 +9,10 @@ from typing import Iterable, Optional import discord from aiohttp import ClientSession +from botcore.async_stats import AsyncStatsClient +from botcore.site_api import APIClient from discord.ext.commands import Context -from bot.api import APIClient -from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module -- cgit v1.2.3 From 277bb011d8ae6b1c58c9de80d10b61791ee1fc49 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:42:58 +0100 Subject: Adding missing kwargs required by BotBase in test helper --- tests/helpers.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 3e6290e58..e6e95c20c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -312,6 +312,9 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop(), redis_session=unittest.mock.MagicMock(), + http_session=unittest.mock.MagicMock(), + allowed_roles=[1], + guild_id=1, ) additional_spec_asyncs = ("wait_for", "redis_ready") -- cgit v1.2.3 From 56e38eeb38de10611611f0f81f5cbc429d7f2bc8 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:44:53 +0100 Subject: Update test helpers with breaking d.py changes region was removed from the guild object, so this has been replaced with features add_cog is now async, so it is now an async_mock during tests Two new required voice_channel attrs were added channel.type is required to be set to ChannelType due to a new isinstance check in d.py --- tests/helpers.py | 4 ++++ tests/test_helpers.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index e6e95c20c..2f0c9b4ad 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -325,6 +325,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.api_client = MockAPIClient(loop=self.loop) self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) + self.add_cog = unittest.mock.AsyncMock() # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` @@ -337,6 +338,8 @@ channel_data = { 'position': 1, 'nsfw': False, 'last_message_id': 1, + 'bitrate': 1337, + 'user_limit': 25, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() @@ -441,6 +444,7 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() +channel.type = discord.ChannelType.text message_instance = discord.Message(state=state, channel=channel, data=message_data) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 81285e009..f3040b305 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -327,7 +327,7 @@ class MockObjectTests(unittest.TestCase): def test_spec_propagation_of_mock_subclasses(self): """Test if the `spec` does not propagate to attributes of the mock object.""" test_values = ( - (helpers.MockGuild, "region"), + (helpers.MockGuild, "features"), (helpers.MockRole, "mentionable"), (helpers.MockMember, "display_name"), (helpers.MockBot, "owner_id"), -- cgit v1.2.3 From 450c9ce5b9bb711681ce87508d5b33a0ad6aed52 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:47:13 +0100 Subject: Update tests to use new async cog setup function --- tests/bot/exts/backend/sync/test_cog.py | 6 +++--- tests/bot/exts/backend/test_error_handler.py | 8 ++++---- tests/bot/exts/events/test_code_jams.py | 8 ++++---- tests/bot/exts/filters/test_antimalware.py | 8 ++++---- tests/bot/exts/filters/test_security.py | 11 +++++------ tests/bot/exts/filters/test_token_remover.py | 8 ++++---- tests/bot/exts/utils/test_snekbox.py | 8 ++++---- 7 files changed, 28 insertions(+), 29 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4ec36e39f..ce620aa8d 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -16,11 +16,11 @@ class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @staticmethod - def test_extension_setup(): + async def test_extension_setup(): """The Sync cog should be added.""" bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() + await sync.setup(bot) + bot.add_cog.assert_awaited_once() class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 04a018289..193f1d822 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -544,11 +544,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): push_scope_mock.set_extra.has_calls(set_extra_calls) -class ErrorHandlerSetupTests(unittest.TestCase): +class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): """Tests for `ErrorHandler` `setup` function.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - setup(bot) - bot.add_cog.assert_called_once() + await setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index 0856546af..684f7abcd 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -160,11 +160,11 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase): member.add_roles.assert_not_awaited() -class CodeJamSetup(unittest.TestCase): +class CodeJamSetup(unittest.IsolatedAsyncioTestCase): """Test for `setup` function of `CodeJam` cog.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog`.""" bot = MockBot() - code_jams.setup(bot) - bot.add_cog.assert_called_once() + await code_jams.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 06d78de9d..7282334e2 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -192,11 +192,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) -class AntiMalwareSetupTests(unittest.TestCase): +class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `AntiMalware` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() + await antimalware.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index c0c3baa42..007b7b1eb 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,5 +1,4 @@ import unittest -from unittest.mock import MagicMock from discord.ext.commands import NoPrivateMessage @@ -44,11 +43,11 @@ class SecurityCogTests(unittest.TestCase): self.assertTrue(self.cog.check_on_guild(self.ctx)) -class SecurityCogLoadTests(unittest.TestCase): +class SecurityCogLoadTests(unittest.IsolatedAsyncioTestCase): """Tests loading the `Security` cog.""" - def test_security_cog_load(self): + async def test_security_cog_load(self): """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() + bot = MockBot() + await security.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index 4db27269a..c1f3762ac 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -395,15 +395,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.channel.send.assert_not_awaited() -class TokenRemoverExtensionTests(unittest.TestCase): +class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the token_remover extension.""" @autospec("bot.exts.filters.token_remover", "TokenRemover") - def test_extension_setup(self, cog): + async def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() - token_remover.setup(bot) + await token_remover.setup(bot) cog.assert_called_once_with(bot) - bot.add_cog.assert_called_once() + bot.add_cog.assert_awaited_once() self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index f68a20089..3c555c051 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -403,11 +403,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(actual, expected) -class SnekboxSetupTests(unittest.TestCase): +class SnekboxSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `Snekbox` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - snekbox.setup(bot) - bot.add_cog.assert_called_once() + await snekbox.setup(bot) + bot.add_cog.assert_awaited_once() -- cgit v1.2.3 From 4db508e7abdc9d5835940f15d5faecdd2f045d0a Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:49:36 +0100 Subject: Update tests to use new async cog_load function --- tests/bot/exts/backend/sync/test_cog.py | 2 +- tests/bot/exts/moderation/test_silence.py | 40 +++++++++++++------------------ 2 files changed, 17 insertions(+), 25 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index ce620aa8d..4e9941d88 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -87,7 +87,7 @@ class SyncCogTests(SyncCogTestCase): self.bot.get_guild = mock.MagicMock(return_value=guild) - await self.cog.sync_guild() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_called_once() self.bot.get_guild.assert_called_once_with(constants.Guild.id) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 92ce3418a..2ebb16978 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -114,44 +114,36 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.cog = silence.Silence(self.bot) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_guild(self): + async def testcog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog._async_init() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_channels(self): + async def testcog_load_got_channels(self): """Got channels from bot.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog._async_init() + await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") - async def test_async_init_got_notifier(self, notifier): + async def testcog_load_got_notifier(self, notifier): """Notifier was started with channel.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog._async_init() + await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_rescheduled(self): + async def testcog_load_rescheduled(self): """`_reschedule_` coroutine was awaited.""" self.cog._reschedule = mock.create_autospec(self.cog._reschedule) - await self.cog._async_init() + await self.cog.cog_load() self.cog._reschedule.assert_awaited_once_with() - def test_cog_unload_cancelled_tasks(self): - """The init task was cancelled.""" - self.cog._init_task = asyncio.Future() - self.cog.cog_unload() - - # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. - self.assertTrue(self.cog._init_task.cancelled()) - @autospec("discord.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): @@ -165,7 +157,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync(self): """Tests the _force_voice_sync helper function.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -187,7 +179,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync_no_channel(self): """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" - await self.cog._async_init() + await self.cog.cog_load() channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) new_channel = MockVoiceChannel(delete=AsyncMock()) @@ -206,7 +198,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_voice_kick(self): """Test to ensure kick function can remove all members from a voice channel.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -236,7 +228,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_kick_move_to_error(self): """Test to ensure move_to gets called on all members during kick, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() _, members = self.create_erroneous_members() await self.cog._kick_voice_members(MockVoiceChannel(members=members)) @@ -245,7 +237,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_sync_move_to_error(self): """Test to ensure move_to gets called on all members during sync, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() failing_member, members = self.create_erroneous_members() await self.cog._force_voice_sync(MockVoiceChannel(members=members)) @@ -339,7 +331,7 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase): self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -428,7 +420,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( @@ -701,7 +693,7 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. self.cog.scheduler.__contains__.return_value = True overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' -- cgit v1.2.3 From 0afe07d0734e50854d7abd8685086e79858fcadf Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:50:33 +0100 Subject: Remove sync cog init test Discord.py now implicitly calls the new async cog_load function from within it's internals on load. There is no longer a need to test that this happens. --- tests/bot/exts/backend/sync/test_cog.py | 17 ----------------- 1 file changed, 17 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4e9941d88..28afeebeb 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,23 +60,6 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("botcore.utils.scheduling.create_task") - @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild, create_task): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. - self.RoleSyncer.reset_mock() - self.UserSyncer.reset_mock() - - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - Sync(self.bot) - - sync_guild.assert_called_once_with() - create_task.assert_called_once() - self.assertEqual(create_task.call_args.args[0], mock_sync_guild_coro) - async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" for guild in (helpers.MockGuild(), None): -- cgit v1.2.3 From 4a9e2819929908182f8c6a148502671f281357ca Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 2 Apr 2022 22:51:01 +0100 Subject: Don't try to overwrite a read-only attr in help command test --- tests/bot/exts/info/test_help.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py index 604c69671..21d124f3a 100644 --- a/tests/bot/exts/info/test_help.py +++ b/tests/bot/exts/info/test_help.py @@ -1,4 +1,5 @@ import unittest +import unittest.mock import rapidfuzz @@ -12,7 +13,6 @@ class HelpCogTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = help.Help(self.bot) self.ctx = MockContext(bot=self.bot) - self.bot.help_command.context = self.ctx @autospec(help.CustomHelpCommand, "get_all_help_choices", return_value={"help"}, pass_mocks=False) async def test_help_fuzzy_matching(self): -- cgit v1.2.3 From 87edeff7347f011c2317cd4b3681bea5bbe07185 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 18 Apr 2022 17:39:20 +0100 Subject: Test that sync cog syncers run when sync cog is loaded --- tests/bot/exts/backend/sync/test_cog.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 28afeebeb..87b76c6b4 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,6 +60,19 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" + async def test_sync_cog_sync_on_load(self): + """Roles and users should be synced on cog load.""" + guild = helpers.MockGuild() + self.bot.get_guild = mock.MagicMock(return_value=guild) + + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + + await self.cog.cog_load() + + self.RoleSyncer.sync.assert_called_once_with(guild) + self.UserSyncer.sync.assert_called_once_with(guild) + async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" for guild in (helpers.MockGuild(), None): -- cgit v1.2.3 From 267f6d94cb10a45f873ea1d0e22f812267be5e69 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 18 Apr 2022 17:40:41 +0100 Subject: Add missing underscores to test function names --- tests/bot/exts/moderation/test_silence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 2ebb16978..65aecad28 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -114,14 +114,14 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.cog = silence.Silence(self.bot) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def testcog_load_got_guild(self): + async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def testcog_load_got_channels(self): + async def test_cog_load_got_channels(self): """Got channels from bot.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) @@ -129,7 +129,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") - async def testcog_load_got_notifier(self, notifier): + async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -- cgit v1.2.3 From 4c1a076dd0b7c021ac1b352589bb353139f86a6f Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 19 Apr 2022 17:23:28 +0100 Subject: Pass the now required intents kwarg when creating MockBot --- tests/helpers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 2f0c9b4ad..a6e4bdd66 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -315,6 +315,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): http_session=unittest.mock.MagicMock(), allowed_roles=[1], guild_id=1, + intents=discord.Intents.all(), ) additional_spec_asyncs = ("wait_for", "redis_ready") -- cgit v1.2.3 From 03a8c8138c53b6582a0921b3dcf7a1b7d55877de Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Wed, 20 Apr 2022 21:06:13 +0100 Subject: remove unneeded import in tests --- tests/bot/exts/info/test_help.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py index 21d124f3a..2644ae40d 100644 --- a/tests/bot/exts/info/test_help.py +++ b/tests/bot/exts/info/test_help.py @@ -1,5 +1,4 @@ import unittest -import unittest.mock import rapidfuzz -- cgit v1.2.3 From 8f9f25a796f6cc07f01f8f7f56e825cb5ebf56c8 Mon Sep 17 00:00:00 2001 From: Hassan Abouelela Date: Sat, 23 Apr 2022 14:24:20 +0400 Subject: Speed Up Sync Cog Loading The user syncer was blocking the startup of the sync cog due to having to perform thousands of pointless member fetch requests. This speeds up that process by increasing the probability that the cache is up-to-date using `Guild.chunked`, and limiting the fetches to members who were in the guild during the previous sync only. Co-authored-by: ChrisJL Co-authored-by: wookie184 Signed-off-by: Hassan Abouelela --- bot/exts/backend/sync/_cog.py | 6 ++++++ bot/exts/backend/sync/_syncers.py | 13 +++++++++++-- tests/helpers.py | 2 +- 3 files changed, 18 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index a5bf82397..4ec822d3f 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict from botcore.site_api import ResponseCodeError @@ -27,6 +28,11 @@ class Sync(Cog): if guild is None: return + log.info("Waiting for guild to be chunked to start syncers.") + while not guild.chunked: + await asyncio.sleep(10) + log.info("Starting syncers.") + for syncer in (_syncers.RoleSyncer, _syncers.UserSyncer): await syncer.sync(guild) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index e1c4541ef..799137cb9 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -2,6 +2,7 @@ import abc import typing as t from collections import namedtuple +import discord.errors from botcore.site_api import ResponseCodeError from discord import Guild from discord.ext.commands import Context @@ -9,7 +10,6 @@ from more_itertools import chunked import bot from bot.log import get_logger -from bot.utils.members import get_or_fetch_member log = get_logger(__name__) @@ -157,7 +157,16 @@ class UserSyncer(Syncer): if db_user[db_field] != guild_value: updated_fields[db_field] = guild_value - if guild_user := await get_or_fetch_member(guild, db_user["id"]): + guild_user = guild.get_member(db_user["id"]) + if not guild_user and db_user["in_guild"]: + # The member was in the guild during the last sync. + # We try to fetch them to verify cache integrity. + try: + guild_user = await guild.fetch_member(db_user["id"]) + except discord.errors.NotFound: + guild_user = None + + if guild_user: seen_guild_users.add(guild_user.id) maybe_update("name", guild_user.name) diff --git a/tests/helpers.py b/tests/helpers.py index a6e4bdd66..5f3111616 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -171,7 +171,7 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): spec_set = guild_instance def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'members': []} + default_kwargs = {'id': next(self.discord_id), 'members': [], "chunked": True} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] -- cgit v1.2.3 From e4ece6e598803b26770eb46e39024ccb34367705 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 2 May 2022 16:16:05 +0100 Subject: Fix tests --- tests/bot/exts/utils/test_snekbox.py | 6 +++--- tests/bot/utils/test_services.py | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 3c555c051..b870a9945 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -35,15 +35,15 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): resp.json.assert_awaited_once() async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + """Reject output longer than MAX_PASTE_LENGTH.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) self.assertEqual(result, "too long to upload") @patch("bot.exts.utils.snekbox.send_to_paste_service") async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with("Test output.", extension="txt") + mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) async def test_codeblock_converter(self): ctx = MockContext() diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 3b71022db..de166c813 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch from aiohttp import ClientConnectorError -from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from bot.utils.services import ( + FAILED_REQUEST_ATTEMPTS, MAX_PASTE_LENGTH, PasteTooLongError, PasteUploadError, send_to_paste_service +) from tests.helpers import MockBot @@ -55,23 +57,28 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): for error_json in test_cases: with self.subTest(error_json=error_json): response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) self.bot.http_session.post.reset_mock() async def test_request_repeated_on_connection_errors(self): """Requests are repeated in the case of connection errors.""" self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) async def test_general_error_handled_and_request_repeated(self): """All `Exception`s are handled, logged and request repeated.""" self.bot.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertLogs("bot.utils", logging.ERROR) - self.assertIsNone(result) + + async def test_raises_error_on_too_long_input(self): + contents = "a" * (MAX_PASTE_LENGTH+1) + with self.assertRaises(PasteTooLongError): + await send_to_paste_service(contents) -- cgit v1.2.3 From 890d2578d3d3d65814edfb350e981ce2fbe2957e Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sat, 21 May 2022 17:40:01 +0100 Subject: Bump malformed API response from debug to error log (#2175) --- bot/exts/backend/error_handler.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 5391a7f15..35dddd8dc 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -285,7 +285,7 @@ class ErrorHandler(Cog): ctx.bot.stats.incr("errors.api_error_404") elif e.status == 400: content = await e.response.json() - log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + log.error(f"API responded with 400 for command {ctx.command}: %r.", content) await ctx.send("According to the API, your request is malformed.") ctx.bot.stats.incr("errors.api_error_400") elif 500 <= e.status < 600: diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 193f1d822..d02bd7c34 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -477,11 +477,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.backend.error_handler.log") async def test_handle_api_error(self, log_mock): - """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" + """Should `ctx.send` on HTTP error codes, and log at correct level.""" test_cases = ( { "error": ResponseCodeError(AsyncMock(status=400)), - "log_level": "debug" + "log_level": "error" }, { "error": ResponseCodeError(AsyncMock(status=404)), @@ -505,6 +505,8 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_awaited_once() if case["log_level"] == "warning": log_mock.warning.assert_called_once() + elif case["log_level"] == "error": + log_mock.error.assert_called_once() else: log_mock.debug.assert_called_once() -- cgit v1.2.3 From 96c7deab22f5018a17b97cd68e3914c37f926be5 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sat, 28 May 2022 16:46:42 +0100 Subject: Fix tests --- tests/bot/exts/backend/test_error_handler.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index d02bd7c34..0a58126e7 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -48,6 +48,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): cog = ErrorHandler(self.bot) cog.try_silence = AsyncMock() cog.try_get_tag = AsyncMock() + cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): @@ -76,6 +77,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): cog = ErrorHandler(self.bot) cog.try_silence = AsyncMock() cog.try_get_tag = AsyncMock() + cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() @@ -83,6 +85,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): cog.try_silence.assert_not_awaited() cog.try_get_tag.assert_not_awaited() + cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): -- cgit v1.2.3 From 381dfc1e7fd1849c3381971e9332f2fec20c7a7e Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sun, 29 May 2022 15:56:46 +0100 Subject: Raise ValueError if max_length greater than allowed by paste service --- bot/utils/services.py | 9 ++++++--- tests/bot/utils/test_services.py | 5 +++++ 2 files changed, 11 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/utils/services.py b/bot/utils/services.py index 3a6833e72..82c7d284c 100644 --- a/bot/utils/services.py +++ b/bot/utils/services.py @@ -25,17 +25,20 @@ async def send_to_paste_service(contents: str, *, extension: str = "", max_lengt `extension` is added to the output URL. `max_length` can be used to limit the allowed contents length to lower than the maximum allowed by the paste service. + Raises `ValueError` if `max_length` is greater than the maximum allowed by the paste service. Raises `PasteTooLongError` if contents is too long to upload, and `PasteUploadError` if uploading fails. Returns the generated URL with the extension. """ + if max_length > MAX_PASTE_LENGTH: + raise ValueError(f"`max_length` must not be greater than {MAX_PASTE_LENGTH}") + extension = extension and f".{extension}" - max_size = min(max_length, MAX_PASTE_LENGTH) contents_size = len(contents.encode()) - if contents_size > max_size: + if contents_size > max_length: log.info("Contents too large to send to paste service.") - raise PasteTooLongError(f"Contents of size {contents_size} greater than maximum size {max_size}") + raise PasteTooLongError(f"Contents of size {contents_size} greater than maximum size {max_length}") log.debug(f"Sending contents of size {contents_size} bytes to paste service.") paste_url = URLs.paste_service.format(key="documents") diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index de166c813..e6b95be8e 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -82,3 +82,8 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): contents = "a" * (MAX_PASTE_LENGTH+1) with self.assertRaises(PasteTooLongError): await send_to_paste_service(contents) + + async def test_raises_on_too_large_max_length(self): + """Ensure ValueError is raised if `max_length` passed is greater than `MAX_PASTE_LENGTH`.""" + with self.assertRaises(ValueError): + await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH+1) -- cgit v1.2.3 From 3fe10aa1e07fcdd8a2efb4a00118b7ec0fcdedce Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sun, 29 May 2022 16:08:45 +0100 Subject: Make small wording and style changes --- bot/utils/services.py | 10 +++++----- tests/bot/utils/test_services.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/bot/utils/services.py b/bot/utils/services.py index 82c7d284c..a752ac0ec 100644 --- a/bot/utils/services.py +++ b/bot/utils/services.py @@ -22,13 +22,13 @@ async def send_to_paste_service(contents: str, *, extension: str = "", max_lengt """ Upload `contents` to the paste service. - `extension` is added to the output URL. `max_length` can be used to limit the allowed contents length + Add `extension` to the output URL. Use `max_length` to limit the allowed contents length to lower than the maximum allowed by the paste service. - Raises `ValueError` if `max_length` is greater than the maximum allowed by the paste service. - Raises `PasteTooLongError` if contents is too long to upload, and `PasteUploadError` if uploading fails. + Raise `ValueError` if `max_length` is greater than the maximum allowed by the paste service. + Raise `PasteTooLongError` if `contents` is too long to upload, and `PasteUploadError` if uploading fails. - Returns the generated URL with the extension. + Return the generated URL with the extension. """ if max_length > MAX_PASTE_LENGTH: raise ValueError(f"`max_length` must not be greater than {MAX_PASTE_LENGTH}") @@ -80,4 +80,4 @@ async def send_to_paste_service(contents: str, *, extension: str = "", max_lengt f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})." ) - raise PasteUploadError("Failed to upload contents to pastebin") + raise PasteUploadError("Failed to upload contents to paste service") diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index e6b95be8e..d0e801299 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -79,11 +79,12 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): self.assertLogs("bot.utils", logging.ERROR) async def test_raises_error_on_too_long_input(self): - contents = "a" * (MAX_PASTE_LENGTH+1) + """Ensure PasteTooLongError is raised if `contents` is longer than `MAX_PASTE_LENGTH`.""" + contents = "a" * (MAX_PASTE_LENGTH + 1) with self.assertRaises(PasteTooLongError): await send_to_paste_service(contents) async def test_raises_on_too_large_max_length(self): """Ensure ValueError is raised if `max_length` passed is greater than `MAX_PASTE_LENGTH`.""" with self.assertRaises(ValueError): - await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH+1) + await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH + 1) -- cgit v1.2.3