diff options
author | 2022-03-19 16:34:01 +0000 | |
---|---|---|
committer | 2022-04-18 17:44:24 +0100 | |
commit | 43b6fee9eba12a6836530029a642cba6e7e505f0 (patch) | |
tree | 680e5d6c42d4b81504359722f7e9b3f8b6b82d73 | |
parent | Bump d.py and bot-core (diff) |
Use bot-core scheduling and member util functions
31 files changed, 53 insertions, 323 deletions
diff --git a/bot/__init__.py b/bot/__init__.py index 17d99105a..c652897be 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,11 +1,10 @@ import asyncio import os -from functools import partial, partialmethod from typing import TYPE_CHECKING -from discord.ext import commands +from botcore.utils import apply_monkey_patches -from bot import log, monkey_patches +from bot import log if TYPE_CHECKING: from bot.bot import Bot @@ -16,16 +15,6 @@ log.setup() if os.name == "nt": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -monkey_patches.patch_typing() - -# This patches any convertors that use PartialMessage, but not the PartialMessageConverter itself -# as library objects are made by this mapping. -# https://github.com/Rapptz/discord.py/blob/1a4e73d59932cdbe7bf2c281f25e32529fc7ae1f/discord/ext/commands/converter.py#L984-L1004 -commands.converter.PartialMessageConverter = monkey_patches.FixedPartialMessageConverter - -# Monkey-patch discord.py decorators to use the Command subclass which supports root aliases. -# Must be patched before any cogs are added. -commands.command = partial(commands.command, cls=monkey_patches.Command) -commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=monkey_patches.Command) +apply_monkey_patches() instance: "Bot" = None # Global Bot instance. diff --git a/bot/async_stats.py b/bot/async_stats.py index 2af832e5b..0303de7a1 100644 --- a/bot/async_stats.py +++ b/bot/async_stats.py @@ -1,10 +1,9 @@ import asyncio import socket +from botcore.utils import scheduling from statsd.client.base import StatsClientBase -from bot.utils import scheduling - class AsyncStatsClient(StatsClientBase): """An async transport method for statsd communication.""" diff --git a/bot/converters.py b/bot/converters.py index 3522a32aa..e819e4713 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -8,7 +8,7 @@ from ssl import CertificateError import dateutil.parser import discord from aiohttp import ClientConnectorError -from botcore.regex import DISCORD_INVITE +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter from discord.utils import escape_markdown, snowflake_time diff --git a/bot/decorators.py b/bot/decorators.py index 8971898b3..466770c3a 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -5,13 +5,14 @@ import typing as t from contextlib import suppress import arrow +from botcore.utils import scheduling from discord import Member, NotFound from discord.ext import commands from discord.ext.commands import Cog, Context from bot.constants import Channels, DEBUG_MODE, RedirectOutput from bot.log import get_logger -from bot.utils import function, scheduling +from bot.utils import function from bot.utils.checks import ContextCheckFailure, in_whitelist_check from bot.utils.function import command_wraps diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py index 2d03cd580..469331ae5 100644 --- a/bot/exts/backend/logging.py +++ b/bot/exts/backend/logging.py @@ -1,10 +1,10 @@ +from botcore.utils import scheduling from discord import Embed from discord.ext.commands import Cog from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE from bot.log import get_logger -from bot.utils import scheduling log = get_logger(__name__) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index bcd845a43..d9e23b25e 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -8,6 +8,7 @@ from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set import arrow +from botcore.utils import scheduling from discord import Colour, Member, Message, NotFound, Object, TextChannel from discord.ext.commands import Cog @@ -20,7 +21,7 @@ from bot.converters import Duration from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import lock, scheduling +from bot.utils import lock from bot.utils.message_cache import MessageCache from bot.utils.messages import format_user, send_attachments diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index b9f2a0e51..32efcc307 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -9,7 +9,8 @@ import dateutil.parser import regex import tldextract from async_rediscache import RedisCache -from botcore.regex import DISCORD_INVITE +from botcore.utils import scheduling +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member, Message, NotFound, TextChannel from discord.ext.commands import Cog @@ -21,7 +22,6 @@ from bot.constants import Channels, Colours, Filter, Guild, Icons, URLs from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import scheduling from bot.utils.messages import format_user log = get_logger(__name__) diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filters/token_remover.py index 520283ba3..436e6dc19 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filters/token_remover.py @@ -1,5 +1,4 @@ import base64 -import binascii import re import typing as t @@ -182,7 +181,7 @@ class TokenRemover(Cog): # that means it's not a valid user id. return None return int(string) - except (binascii.Error, ValueError): + except ValueError: return None @staticmethod @@ -198,7 +197,7 @@ class TokenRemover(Cog): try: decoded_bytes = base64.urlsafe_b64decode(b64_content) timestamp = int.from_bytes(decoded_bytes, byteorder="big") - except (binascii.Error, ValueError) as e: + except ValueError as e: log.debug(f"Failed to decode token timestamp '{b64_content}': {e}") return False diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index 7df1d172d..33f43f2a8 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -2,6 +2,7 @@ import difflib from datetime import timedelta import arrow +from botcore.utils import scheduling from discord import Colour, Embed from discord.ext.commands import Cog, Context, group, has_any_role from discord.utils import sleep_until @@ -12,7 +13,6 @@ from bot.constants import Channels, MODERATION_ROLES from bot.converters import OffTopicName from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) log = get_logger(__name__) diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index a93acffb6..fc80c968c 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -7,6 +7,7 @@ from operator import attrgetter import arrow import discord import discord.abc +from botcore.utils import members, scheduling from discord.ext import commands from bot import constants @@ -14,7 +15,7 @@ from bot.bot import Bot from bot.constants import Channels, RedirectOutput from bot.exts.help_channels import _caches, _channel, _message, _name, _stats from bot.log import get_logger -from bot.utils import channel as channel_utils, lock, members, scheduling +from bot.utils import channel as channel_utils, lock log = get_logger(__name__) diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index a859d8cef..9027105d9 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -2,6 +2,7 @@ import time from typing import Optional import discord +from botcore.utils import scheduling from discord import Message, RawMessageUpdateEvent from discord.ext.commands import Cog @@ -11,7 +12,7 @@ from bot.exts.filters.token_remover import TokenRemover from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE from bot.exts.info.codeblock._instructions import get_instructions from bot.log import get_logger -from bot.utils import has_lines, scheduling +from bot.utils import has_lines from bot.utils.channel import is_help_channel from bot.utils.messages import wait_for_deletion diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py index c27f28eac..41a15fb6e 100644 --- a/bot/exts/info/doc/_batch_parser.py +++ b/bot/exts/info/doc/_batch_parser.py @@ -8,12 +8,12 @@ from operator import attrgetter from typing import Deque, Dict, List, NamedTuple, Optional, Union import discord +from botcore.utils import scheduling from bs4 import BeautifulSoup import bot from bot.constants import Channels from bot.log import get_logger -from bot.utils import scheduling from . import _cog, doc_cache from ._parsing import get_symbol_markdown diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 4dc5276d9..3789fdbe3 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -10,6 +10,8 @@ from typing import Dict, NamedTuple, Optional, Tuple, Union import aiohttp import discord +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord.ext import commands from bot.api import ResponseCodeError @@ -18,10 +20,8 @@ from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import Inventory, PackageName, ValidURL, allowed_strings from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling from bot.utils.lock import SharedEvent, lock from bot.utils.messages import send_denial, wait_for_deletion -from bot.utils.scheduling import Scheduler from . import NAMESPACE, PRIORITY_PACKAGES, _batch_parser, doc_cache from ._inventory_parser import InvalidHeaderError, InventoryDict, fetch_inventory diff --git a/bot/exts/info/subscribe.py b/bot/exts/info/subscribe.py index eff0c13b8..ed134ff78 100644 --- a/bot/exts/info/subscribe.py +++ b/bot/exts/info/subscribe.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import arrow import discord +from botcore.utils import members, scheduling from discord.ext import commands from discord.interactions import Interaction @@ -12,7 +13,6 @@ from bot import constants from bot.bot import Bot from bot.decorators import redirect_output from bot.log import get_logger -from bot.utils import members, scheduling @dataclass(frozen=True) diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 178be734d..a8640cb1b 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -7,6 +7,8 @@ from typing import Optional, Union import arrow from aioredis import RedisError from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Forbidden, Member, TextChannel, User from discord.ext import tasks @@ -17,9 +19,8 @@ from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_RO from bot.converters import DurationDelta, Expiry from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.messages import format_user -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index b579416a6..d34c1c7fa 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -6,12 +6,12 @@ from typing import Optional import discord from async_rediscache import RedisCache +from botcore.utils import scheduling from discord.ext.commands import Cog, Context, MessageConverter, MessageNotFound from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks from bot.log import get_logger -from bot.utils import scheduling from bot.utils.messages import format_user, sub_clyde log = get_logger(__name__) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 2fc54856f..9f5800e2a 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -6,6 +6,7 @@ from gettext import ngettext import arrow import dateutil.parser import discord +from botcore.utils import scheduling from discord.ext.commands import Context from bot import constants @@ -16,7 +17,7 @@ from bot.converters import MemberOrUser from bot.exts.moderation.infraction import _utils from bot.exts.moderation.modlog import ModLog from bot.log import get_logger -from bot.utils import messages, scheduling, time +from bot.utils import messages, time from bot.utils.channel import is_mod_channel log = get_logger(__name__) diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index ce9c220b3..d68726faf 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -8,15 +8,16 @@ import arrow from aiohttp.client_exceptions import ClientResponseError from arrow import Arrow from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot from bot.constants import Metabase as MetabaseConfig, Roles from bot.converters import allowed_strings from bot.log import get_logger -from bot.utils import scheduling, send_to_paste_service +from bot.utils import send_to_paste_service from bot.utils.channel import is_mod_channel -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index b5cd29b12..cb1e4fd05 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -3,6 +3,8 @@ import datetime import arrow from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse, parse as dateutil_parse from discord import Embed, Member from discord.ext.commands import Cog, Context, group, has_any_role @@ -11,8 +13,7 @@ from bot.bot import Bot from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles from bot.converters import Expiry from bot.log import get_logger -from bot.utils import scheduling, time -from bot.utils.scheduling import Scheduler +from bot.utils import time log = get_logger(__name__) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 511520252..307729181 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -5,6 +5,8 @@ from datetime import datetime, timedelta, timezone from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from discord import Guild, PermissionOverwrite, TextChannel, Thread, VoiceChannel from discord.ext import commands, tasks from discord.ext.commands import Context @@ -14,9 +16,7 @@ from bot import constants from bot.bot import Bot from bot.converters import HushDurationConverter from bot.log import get_logger -from bot.utils import scheduling from bot.utils.lock import LockedResourceError, lock, lock_arg -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 985cc6eb1..17d24eb89 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -5,6 +5,7 @@ import arrow import discord from arrow import Arrow from async_rediscache import RedisCache +from botcore.utils import scheduling from discord.ext import commands from bot.bot import Bot @@ -14,7 +15,7 @@ from bot.constants import ( from bot.converters import Expiry from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.members import get_or_fetch_member log = get_logger(__name__) diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index ee9b6ba45..bae7ecd02 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional import discord +from botcore.utils import scheduling from discord import Color, DMChannel, Embed, HTTPException, Message, errors from discord.ext.commands import Cog, Context @@ -18,7 +19,7 @@ from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE from bot.exts.moderation.modlog import ModLog from bot.log import CustomLogger, get_logger from bot.pagination import LinePaginator -from bot.utils import CogABCMeta, messages, scheduling, time +from bot.utils import CogABCMeta, messages, time from bot.utils.members import get_or_fetch_member log = get_logger(__name__) diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 0554bf37a..0d51af2ca 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -5,6 +5,7 @@ from typing import Optional, Union import discord from async_rediscache import RedisCache +from botcore.utils import scheduling from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User from discord.ext.commands import BadArgument, Cog, Context, group, has_any_role @@ -15,7 +16,7 @@ from bot.converters import MemberOrUser, UnambiguousMemberOrUser from bot.exts.recruitment.talentpool._review import Reviewer from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.members import get_or_fetch_member AUTOREVIEW_ENABLED_KEY = "autoreview_enabled" diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index b4d177622..214d85851 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from typing import List, Optional, Union import arrow +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord import Embed, Emoji, Member, Message, NoMoreItems, NotFound, PartialMessage, TextChannel from discord.ext.commands import Context @@ -20,7 +21,6 @@ from bot.log import get_logger from bot.utils import time from bot.utils.members import get_or_fetch_member from bot.utils.messages import count_unique_users_reaction, pin_no_system_message -from bot.utils.scheduling import Scheduler if typing.TYPE_CHECKING: from bot.exts.recruitment.talentpool._cog import TalentPool diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index ad82d49c9..62603697c 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone from operator import itemgetter import discord +from botcore.utils import scheduling +from botcore.utils.scheduling import Scheduler from dateutil.parser import isoparse from discord.ext.commands import Cog, Context, Greedy, group @@ -13,12 +15,11 @@ from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Role from bot.converters import Duration, UnambiguousUser from bot.log import get_logger from bot.pagination import LinePaginator -from bot.utils import scheduling, time +from bot.utils import time from bot.utils.checks import has_any_role_check, has_no_roles_check from bot.utils.lock import lock_arg from bot.utils.members import get_or_fetch_member from bot.utils.messages import send_denial -from bot.utils.scheduling import Scheduler log = get_logger(__name__) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 3c1009d2a..2b073ed72 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -7,7 +7,8 @@ from signal import Signals from textwrap import dedent from typing import Optional, Tuple -from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX +from botcore.utils import scheduling +from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only @@ -15,7 +16,7 @@ from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs from bot.decorators import redirect_output from bot.log import get_logger -from bot.utils import scheduling, send_to_paste_service +from bot.utils import send_to_paste_service from bot.utils.messages import wait_for_deletion log = get_logger(__name__) diff --git a/bot/monkey_patches.py b/bot/monkey_patches.py deleted file mode 100644 index 4840fa454..000000000 --- a/bot/monkey_patches.py +++ /dev/null @@ -1,76 +0,0 @@ -import re -from datetime import timedelta - -import arrow -from discord import Forbidden, http -from discord.ext import commands - -from bot.log import get_logger - -log = get_logger(__name__) -MESSAGE_ID_RE = re.compile(r'(?P<message_id>[0-9]{15,20})$') - - -class Command(commands.Command): - """ - A `discord.ext.commands.Command` subclass which supports root aliases. - - A `root_aliases` keyword argument is added, which is a sequence of alias names that will act as - top-level commands rather than being aliases of the command's group. It's stored as an attribute - also named `root_aliases`. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.root_aliases = kwargs.get("root_aliases", []) - - if not isinstance(self.root_aliases, (list, tuple)): - raise TypeError("Root aliases of a command must be a list or a tuple of strings.") - - -def patch_typing() -> None: - """ - Sometimes discord turns off typing events by throwing 403's. - - Handle those issues by patching the trigger_typing method so it ignores 403's in general. - """ - log.debug("Patching send_typing, which should fix things breaking when discord disables typing events. Stay safe!") - - original = http.HTTPClient.send_typing - last_403 = None - - async def honeybadger_type(self, channel_id: int) -> None: # noqa: ANN001 - nonlocal last_403 - if last_403 and (arrow.utcnow() - last_403) < timedelta(minutes=5): - log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") - return - try: - await original(self, channel_id) - except Forbidden: - last_403 = arrow.utcnow() - log.warning("Got a 403 from typing event!") - pass - - http.HTTPClient.send_typing = honeybadger_type - - -class FixedPartialMessageConverter(commands.PartialMessageConverter): - """ - Make the Message converter infer channelID from the given context if only a messageID is given. - - Discord.py's Message converter is supposed to infer channelID based - on ctx.channel if only a messageID is given. A refactor commit, linked below, - a few weeks before d.py's archival broke this defined behaviour of the converter. - Currently, if only a messageID is given to the converter, it will only find that message - if it's in the bot's cache. - - https://github.com/Rapptz/discord.py/commit/1a4e73d59932cdbe7bf2c281f25e32529fc7ae1f - """ - - @staticmethod - def _get_id_matches(ctx: commands.Context, argument: str) -> tuple[int, int, int]: - """Inserts ctx.channel.id before calling super method if argument is just a messageID.""" - match = MESSAGE_ID_RE.match(argument) - if match: - argument = f"{ctx.channel.id}-{match.group('message_id')}" - return commands.PartialMessageConverter._get_id_matches(ctx, argument) diff --git a/bot/utils/messages.py b/bot/utils/messages.py index e55c07062..a5ed84351 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -6,12 +6,12 @@ from io import BytesIO from typing import Callable, List, Optional, Sequence, Union import discord +from botcore.utils import scheduling from discord.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES from bot.log import get_logger -from bot.utils import scheduling log = get_logger(__name__) diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py deleted file mode 100644 index 23acacf74..000000000 --- a/bot/utils/scheduling.py +++ /dev/null @@ -1,194 +0,0 @@ -import asyncio -import contextlib -import inspect -import typing as t -from datetime import datetime -from functools import partial - -from arrow import Arrow - -from bot.log import get_logger - - -class Scheduler: - """ - Schedule the execution of coroutines and keep track of them. - - When instantiating a Scheduler, a name must be provided. This name is used to distinguish the - instance's log messages from other instances. Using the name of the class or module containing - the instance is suggested. - - Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` - or `schedule_later`. A unique ID is required to be given in order to keep track of the - resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing - the same ID used to schedule it. The `in` operator is supported for checking if a task with a - given ID is currently scheduled. - - Any exception raised in a scheduled task is logged when the task is done. - """ - - def __init__(self, name: str): - self.name = name - - self._log = get_logger(f"{__name__}.{name}") - self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} - - def __contains__(self, task_id: t.Hashable) -> bool: - """Return True if a task with the given `task_id` is currently scheduled.""" - return task_id in self._scheduled_tasks - - def schedule(self, task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule the execution of a `coroutine`. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - self._log.trace(f"Scheduling task #{task_id}...") - - msg = f"Cannot schedule an already started coroutine for #{task_id}" - assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg - - if task_id in self._scheduled_tasks: - self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") - coroutine.close() - return - - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") - task.add_done_callback(partial(self._task_done_callback, task_id)) - - self._scheduled_tasks[task_id] = task - self._log.debug(f"Scheduled task #{task_id} {id(task)}.") - - def schedule_at(self, time: t.Union[datetime, Arrow], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule `coroutine` to be executed at the given `time`. - - If `time` is timezone aware, then use that timezone to calculate now() when subtracting. - If `time` is naïve, then use UTC. - - If `time` is in the past, schedule `coroutine` immediately. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() - delay = (time - now_datetime).total_seconds() - if delay > 0: - coroutine = self._await_later(delay, task_id, coroutine) - - self.schedule(task_id, coroutine) - - def schedule_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """ - Schedule `coroutine` to be executed after the given `delay` number of seconds. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - """ - self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - - def cancel(self, task_id: t.Hashable) -> None: - """Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.""" - self._log.trace(f"Cancelling task #{task_id}...") - - try: - task = self._scheduled_tasks.pop(task_id) - except KeyError: - self._log.warning(f"Failed to unschedule {task_id} (no task found).") - else: - task.cancel() - - self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") - - def cancel_all(self) -> None: - """Unschedule all known tasks.""" - self._log.debug("Unscheduling all tasks") - - for task_id in self._scheduled_tasks.copy(): - self.cancel(task_id) - - async def _await_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: - """Await `coroutine` after the given `delay` number of seconds.""" - try: - self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") - await asyncio.sleep(delay) - - # Use asyncio.shield to prevent the coroutine from cancelling itself. - self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") - await asyncio.shield(coroutine) - finally: - # Close it to prevent unawaited coroutine warnings, - # which would happen if the task was cancelled during the sleep. - # Only close it if it's not been awaited yet. This check is important because the - # coroutine may cancel this task, which would also trigger the finally block. - state = inspect.getcoroutinestate(coroutine) - if state == "CORO_CREATED": - self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") - coroutine.close() - else: - self._log.debug(f"Finally block reached for #{task_id}; {state=}") - - def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None: - """ - Delete the task and raise its exception if one exists. - - If `done_task` and the task associated with `task_id` are different, then the latter - will not be deleted. In this case, a new task was likely rescheduled with the same ID. - """ - self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") - - scheduled_task = self._scheduled_tasks.get(task_id) - - if scheduled_task and done_task is scheduled_task: - # A task for the ID exists and is the same as the done task. - # Since this is the done callback, the task is already done so no need to cancel it. - self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") - del self._scheduled_tasks[task_id] - elif scheduled_task: - # A new task was likely rescheduled with the same ID. - self._log.debug( - f"The scheduled task #{task_id} {id(scheduled_task)} " - f"and the done task {id(done_task)} differ." - ) - elif not done_task.cancelled(): - self._log.warning( - f"Task #{task_id} not found while handling task {id(done_task)}! " - f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." - ) - - with contextlib.suppress(asyncio.CancelledError): - exception = done_task.exception() - # Log the exception if one exists. - if exception: - self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) - - -def create_task( - coro: t.Awaitable, - *, - suppressed_exceptions: tuple[t.Type[Exception]] = (), - event_loop: t.Optional[asyncio.AbstractEventLoop] = None, - **kwargs, -) -> asyncio.Task: - """ - Wrapper for creating asyncio `Task`s which logs exceptions raised in the task. - - If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used. - """ - if event_loop is not None: - task = event_loop.create_task(coro, **kwargs) - else: - task = asyncio.create_task(coro, **kwargs) - task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) - return task - - -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None: - """Retrieve and log the exception raised in `task` if one exists.""" - with contextlib.suppress(asyncio.CancelledError): - exception = task.exception() - # Log the exception if one exists. - if exception and not isinstance(exception, suppressed_exceptions): - log = get_logger(__name__) - log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index fdd0ab74a..7dff38f96 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -60,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("bot.utils.scheduling.create_task") + @mock.patch("botcore.utils.scheduling.create_task") @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild, create_task): """Should instantiate syncers and run a sync for the guild.""" diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py index 8ae59c1f1..bd26532f1 100644 --- a/tests/bot/exts/filters/test_filtering.py +++ b/tests/bot/exts/filters/test_filtering.py @@ -11,7 +11,7 @@ class FilteringCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Instantiate the bot and cog.""" self.bot = MockBot() - with patch("bot.utils.scheduling.create_task", new=lambda task, **_: task.close()): + with patch("botcore.utils.scheduling.create_task", new=lambda task, **_: task.close()): self.cog = filtering.Filtering(self.bot) @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) |