diff options
| author | 2020-07-13 11:50:19 -0400 | |
|---|---|---|
| committer | 2020-07-13 11:50:19 -0400 | |
| commit | c8a4e912f3a2863fb1c255642deeb8a07a9a2474 (patch) | |
| tree | 1734cb255a1441ca9b7547b142311fe195c1ae4a | |
| parent | Revert "Ping @Moderators in ModLog" (diff) | |
| parent | Fix rescheduling of edited infractions (diff) | |
Merge branch 'master' into 1038_allow_role_mentions_in_specific_areas
| -rw-r--r-- | bot/cogs/filtering.py | 16 | ||||
| -rw-r--r-- | bot/cogs/help_channels.py | 70 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/scheduler.py | 23 | ||||
| -rw-r--r-- | bot/cogs/moderation/silence.py | 32 | ||||
| -rw-r--r-- | bot/cogs/moderation/superstarify.py | 2 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 29 | ||||
| -rw-r--r-- | bot/utils/scheduling.py | 146 | 
8 files changed, 153 insertions, 169 deletions
diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 76ea68660..099606b82 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -19,7 +19,6 @@ from bot.constants import (  )  from bot.utils.redis_cache import RedisCache  from bot.utils.scheduling import Scheduler -from bot.utils.time import wait_until  log = logging.getLogger(__name__) @@ -60,7 +59,7 @@ def expand_spoilers(text: str) -> str:  OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) -class Filtering(Cog, Scheduler): +class Filtering(Cog):      """Filtering out invites, blacklisting domains, and warning us of certain regular expressions."""      # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent @@ -68,8 +67,7 @@ class Filtering(Cog, Scheduler):      def __init__(self, bot: Bot):          self.bot = bot -        super().__init__() - +        self.scheduler = Scheduler(self.__class__.__name__)          self.name_lock = asyncio.Lock()          staff_mistake_str = "If you believe this was a mistake, please let staff know!" @@ -268,7 +266,7 @@ class Filtering(Cog, Scheduler):                              }                              await self.bot.api_client.post('bot/offensive-messages', json=data) -                            self.schedule_task(msg.id, data) +                            self.schedule_msg_delete(data)                              log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}")                          if is_private: @@ -457,12 +455,10 @@ class Filtering(Cog, Scheduler):          except discord.errors.Forbidden:              await channel.send(f"{filtered_member.mention} {reason}") -    async def _scheduled_task(self, msg: dict) -> None: +    def schedule_msg_delete(self, msg: dict) -> None:          """Delete an offensive message once its deletion date is reached."""          delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - -        await wait_until(delete_at) -        await self.delete_offensive_msg(msg) +        self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg))      async def reschedule_offensive_msg_deletion(self) -> None:          """Get all the pending message deletion from the API and reschedule them.""" @@ -477,7 +473,7 @@ class Filtering(Cog, Scheduler):              if delete_at < now:                  await self.delete_offensive_msg(msg)              else: -                self.schedule_task(msg['id'], msg) +                self.schedule_msg_delete(msg)      async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None:          """Delete an offensive message, and then delete it from the db.""" diff --git a/bot/cogs/help_channels.py b/bot/cogs/help_channels.py index fd1a449c1..4d0c534b0 100644 --- a/bot/cogs/help_channels.py +++ b/bot/cogs/help_channels.py @@ -1,5 +1,4 @@  import asyncio -import inspect  import json  import logging  import random @@ -57,14 +56,7 @@ through our guide for [asking a good question]({ASKING_GUIDE_URL}).  CoroutineFunc = t.Callable[..., t.Coroutine] -class TaskData(t.NamedTuple): -    """Data for a scheduled task.""" - -    wait_time: int -    callback: t.Awaitable - - -class HelpChannels(Scheduler, commands.Cog): +class HelpChannels(commands.Cog):      """      Manage the help channel system of the guild. @@ -114,9 +106,8 @@ class HelpChannels(Scheduler, commands.Cog):      claim_times = RedisCache()      def __init__(self, bot: Bot): -        super().__init__() -          self.bot = bot +        self.scheduler = Scheduler(self.__class__.__name__)          # Categories          self.available_category: discord.CategoryChannel = None @@ -145,7 +136,7 @@ class HelpChannels(Scheduler, commands.Cog):          for task in self.queue_tasks:              task.cancel() -        self.cancel_all() +        self.scheduler.cancel_all()      def create_channel_queue(self) -> asyncio.Queue:          """ @@ -229,10 +220,11 @@ class HelpChannels(Scheduler, commands.Cog):                  await self.remove_cooldown_role(ctx.author)                  # Ignore missing task when cooldown has passed but the channel still isn't dormant. -                self.cancel_task(ctx.author.id, ignore_missing=True) +                if ctx.author.id in self.scheduler: +                    self.scheduler.cancel(ctx.author.id)                  await self.move_to_dormant(ctx.channel, "command") -                self.cancel_task(ctx.channel.id) +                self.scheduler.cancel(ctx.channel.id)          else:              log.debug(f"{ctx.author} invoked command 'dormant' outside an in-use help channel") @@ -474,16 +466,15 @@ class HelpChannels(Scheduler, commands.Cog):          else:              # Cancel the existing task, if any.              if has_task: -                self.cancel_task(channel.id) - -            data = TaskData(idle_seconds - time_elapsed, self.move_idle_channel(channel)) +                self.scheduler.cancel(channel.id) +            delay = idle_seconds - time_elapsed              log.info(                  f"#{channel} ({channel.id}) is still active; " -                f"scheduling it to be moved after {data.wait_time} seconds." +                f"scheduling it to be moved after {delay} seconds."              ) -            self.schedule_task(channel.id, data) +            self.scheduler.schedule_later(delay, channel.id, self.move_idle_channel(channel))      async def move_to_bottom_position(self, channel: discord.TextChannel, category_id: int, **options) -> None:          """ @@ -588,8 +579,7 @@ class HelpChannels(Scheduler, commands.Cog):          timeout = constants.HelpChannels.idle_minutes * 60          log.trace(f"Scheduling #{channel} ({channel.id}) to become dormant in {timeout} sec.") -        data = TaskData(timeout, self.move_idle_channel(channel)) -        self.schedule_task(channel.id, data) +        self.scheduler.schedule_later(timeout, channel.id, self.move_idle_channel(channel))          self.report_stats()      async def notify(self) -> None: @@ -724,10 +714,10 @@ class HelpChannels(Scheduler, commands.Cog):          log.info(f"Claimant of #{msg.channel} ({msg.author}) deleted message, channel is empty now. Rescheduling task.")          # Cancel existing dormant task before scheduling new. -        self.cancel_task(msg.channel.id) +        self.scheduler.cancel(msg.channel.id) -        task = TaskData(constants.HelpChannels.deleted_idle_minutes * 60, self.move_idle_channel(msg.channel)) -        self.schedule_task(msg.channel.id, task) +        delay = constants.HelpChannels.deleted_idle_minutes * 60 +        self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel))      async def is_empty(self, channel: discord.TextChannel) -> bool:          """Return True if the most recent message in `channel` is the bot's `AVAILABLE_MSG`.""" @@ -754,8 +744,8 @@ class HelpChannels(Scheduler, commands.Cog):                  await self.remove_cooldown_role(member)              else:                  # The member is still on a cooldown; re-schedule it for the remaining time. -                remaining = cooldown - in_use_time.seconds -                await self.schedule_cooldown_expiration(member, remaining) +                delay = cooldown - in_use_time.seconds +                self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member))      async def add_cooldown_role(self, member: discord.Member) -> None:          """Add the help cooldown role to `member`.""" @@ -806,16 +796,11 @@ class HelpChannels(Scheduler, commands.Cog):          # Cancel the existing task, if any.          # Would mean the user somehow bypassed the lack of permissions (e.g. user is guild owner). -        self.cancel_task(member.id, ignore_missing=True) +        if member.id in self.scheduler: +            self.scheduler.cancel(member.id) -        await self.schedule_cooldown_expiration(member, constants.HelpChannels.claim_minutes * 60) - -    async def schedule_cooldown_expiration(self, member: discord.Member, seconds: int) -> None: -        """Schedule the cooldown role for `member` to be removed after a duration of `seconds`.""" -        log.trace(f"Scheduling removal of {member}'s ({member.id}) cooldown.") - -        callback = self.remove_cooldown_role(member) -        self.schedule_task(member.id, TaskData(seconds, callback)) +        delay = constants.HelpChannels.claim_minutes * 60 +        self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member))      async def send_available_message(self, channel: discord.TextChannel) -> None:          """Send the available message by editing a dormant message or sending a new message.""" @@ -857,21 +842,6 @@ class HelpChannels(Scheduler, commands.Cog):          return channel -    async def _scheduled_task(self, data: TaskData) -> None: -        """Await the `data.callback` coroutine after waiting for `data.wait_time` seconds.""" -        try: -            log.trace(f"Waiting {data.wait_time} seconds before awaiting callback.") -            await asyncio.sleep(data.wait_time) - -            # Use asyncio.shield to prevent callback from cancelling itself. -            # The parent task (_scheduled_task) will still get cancelled. -            log.trace("Done waiting; now awaiting the callback.") -            await asyncio.shield(data.callback) -        finally: -            if inspect.iscoroutine(data.callback): -                log.trace("Explicitly closing coroutine.") -                data.callback.close() -  def validate_config() -> None:      """Raise a ValueError if the cog's config is invalid.""" diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 617d957ed..672bb0e9c 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -135,11 +135,11 @@ class ModManagement(commands.Cog):          if 'expires_at' in request_data:              # A scheduled task should only exist if the old infraction wasn't permanent              if old_infraction['expires_at']: -                self.infractions_cog.cancel_task(new_infraction['id']) +                self.infractions_cog.scheduler.cancel(new_infraction['id'])              # If the infraction was not marked as permanent, schedule a new expiration task              if request_data['expires_at']: -                self.infractions_cog.schedule_task(new_infraction['id'], new_infraction) +                self.infractions_cog.schedule_expiration(new_infraction)              log_text += f"""                  Previous expiry: {old_infraction['expires_at'] or "Permanent"} diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index d75a72ddb..601e238c9 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -1,4 +1,3 @@ -import asyncio  import logging  import textwrap  import typing as t @@ -23,13 +22,13 @@ from .utils import UserSnowflake  log = logging.getLogger(__name__) -class InfractionScheduler(Scheduler): +class InfractionScheduler:      """Handles the application, pardoning, and expiration of infractions."""      def __init__(self, bot: Bot, supported_infractions: t.Container[str]): -        super().__init__() -          self.bot = bot +        self.scheduler = Scheduler(self.__class__.__name__) +          self.bot.loop.create_task(self.reschedule_infractions(supported_infractions))      @property @@ -49,7 +48,7 @@ class InfractionScheduler(Scheduler):          )          for infraction in infractions:              if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: -                self.schedule_task(infraction["id"], infraction) +                self.schedule_expiration(infraction)      async def reapply_infraction(          self, @@ -155,7 +154,7 @@ class InfractionScheduler(Scheduler):                  await action_coro                  if expiry:                      # Schedule the expiration of the infraction. -                    self.schedule_task(infraction["id"], infraction) +                    self.schedule_expiration(infraction)              except discord.HTTPException as e:                  # Accordingly display that applying the infraction failed.                  confirm_msg = ":x: failed to apply" @@ -278,7 +277,7 @@ class InfractionScheduler(Scheduler):                  # Cancel pending expiration task.                  if infraction["expires_at"] is not None: -                    self.cancel_task(infraction["id"]) +                    self.scheduler.cancel(infraction["id"])          # Accordingly display whether the user was successfully notified via DM.          dm_emoji = "" @@ -415,7 +414,7 @@ class InfractionScheduler(Scheduler):          # Cancel the expiration task.          if infraction["expires_at"] is not None: -            self.cancel_task(infraction["id"]) +            self.scheduler.cancel(infraction["id"])          # Send a log message to the mod log.          if send_log: @@ -449,7 +448,7 @@ class InfractionScheduler(Scheduler):          """          raise NotImplementedError -    async def _scheduled_task(self, infraction: utils.Infraction) -> None: +    def schedule_expiration(self, infraction: utils.Infraction) -> None:          """          Marks an infraction expired after the delay from time of scheduling to time of expiration. @@ -457,8 +456,4 @@ class InfractionScheduler(Scheduler):          expiration task is cancelled.          """          expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) -        await time.wait_until(expiry) - -        # Because deactivate_infraction() explicitly cancels this scheduled task, it is shielded -        # to avoid prematurely cancelling itself. -        await asyncio.shield(self.deactivate_infraction(infraction)) +        self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index c8ab6443b..ae4fb7b64 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -1,7 +1,7 @@  import asyncio  import logging  from contextlib import suppress -from typing import NamedTuple, Optional +from typing import Optional  from discord import TextChannel  from discord.ext import commands, tasks @@ -16,13 +16,6 @@ from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) -class TaskData(NamedTuple): -    """Data for a scheduled task.""" - -    delay: int -    ctx: Context - -  class SilenceNotifier(tasks.Loop):      """Loop notifier for posting notices to `alert_channel` containing added channels.""" @@ -61,25 +54,17 @@ class SilenceNotifier(tasks.Loop):              await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") -class Silence(Scheduler, commands.Cog): +class Silence(commands.Cog):      """Commands for stopping channel messages for `verified` role in a channel."""      def __init__(self, bot: Bot): -        super().__init__()          self.bot = bot +        self.scheduler = Scheduler(self.__class__.__name__)          self.muted_channels = set() +          self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars())          self._get_instance_vars_event = asyncio.Event() -    async def _scheduled_task(self, task: TaskData) -> None: -        """Calls `self.unsilence` on expired silenced channel to unsilence it.""" -        await asyncio.sleep(task.delay) -        log.info("Unsilencing channel after set delay.") - -        # Because `self.unsilence` explicitly cancels this scheduled task, it is shielded -        # to avoid prematurely cancelling itself -        await asyncio.shield(task.ctx.invoke(self.unsilence)) -      async def _get_instance_vars(self) -> None:          """Get instance variables after they're available to get from the guild."""          await self.bot.wait_until_guild_available() @@ -109,12 +94,7 @@ class Silence(Scheduler, commands.Cog):          await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") -        task_data = TaskData( -            delay=duration*60, -            ctx=ctx -        ) - -        self.schedule_task(ctx.channel.id, task_data) +        self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence))      @commands.command(aliases=("unhush",))      async def unsilence(self, ctx: Context) -> None: @@ -164,7 +144,7 @@ class Silence(Scheduler, commands.Cog):          if current_overwrite.send_messages is False:              await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None))              log.info(f"Unsilenced channel #{channel} ({channel.id}).") -            self.cancel_task(channel.id) +            self.scheduler.cancel(channel.id)              self.notifier.remove_channel(channel)              self.muted_channels.discard(channel)              return True diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 45a010f00..867de815a 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -146,7 +146,7 @@ class Superstarify(InfractionScheduler, Cog):          log.debug(f"Changing nickname of {member} to {forced_nick}.")          self.mod_log.ignore(constants.Event.member_update, member.id)          await member.edit(nick=forced_nick, reason=reason) -        self.schedule_task(id_, infraction) +        self.schedule_expiration(infraction)          # Send a DM to the user to notify them of their new infraction.          await utils.notify_infraction( diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index c242d2920..0d20bdb2b 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -17,7 +17,7 @@ from bot.converters import Duration  from bot.pagination import LinePaginator  from bot.utils.checks import without_role_check  from bot.utils.scheduling import Scheduler -from bot.utils.time import humanize_delta, wait_until +from bot.utils.time import humanize_delta  log = logging.getLogger(__name__) @@ -25,12 +25,12 @@ WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5 -class Reminders(Scheduler, Cog): +class Reminders(Cog):      """Provide in-channel reminder functionality."""      def __init__(self, bot: Bot):          self.bot = bot -        super().__init__() +        self.scheduler = Scheduler(self.__class__.__name__)          self.bot.loop.create_task(self.reschedule_reminders()) @@ -56,7 +56,7 @@ class Reminders(Scheduler, Cog):                  late = relativedelta(now, remind_at)                  await self.send_reminder(reminder, late)              else: -                self.schedule_task(reminder["id"], reminder) +                self.schedule_reminder(reminder)      def ensure_valid_reminder(          self, @@ -99,17 +99,18 @@ class Reminders(Scheduler, Cog):          await ctx.send(embed=embed) -    async def _scheduled_task(self, reminder: dict) -> None: +    def schedule_reminder(self, reminder: dict) -> None:          """A coroutine which sends the reminder once the time is reached, and cancels the running task."""          reminder_id = reminder["id"]          reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) -        # Send the reminder message once the desired duration has passed -        await wait_until(reminder_datetime) -        await self.send_reminder(reminder) +        async def _remind() -> None: +            await self.send_reminder(reminder) -        log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") -        await self._delete_reminder(reminder_id) +            log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") +            await self._delete_reminder(reminder_id) + +        self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind())      async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None:          """Delete a reminder from the database, given its ID, and cancel the running task.""" @@ -117,15 +118,15 @@ class Reminders(Scheduler, Cog):          if cancel_task:              # Now we can remove it from the schedule list -            self.cancel_task(reminder_id) +            self.scheduler.cancel(reminder_id)      async def _reschedule_reminder(self, reminder: dict) -> None:          """Reschedule a reminder object."""          log.trace(f"Cancelling old task #{reminder['id']}") -        self.cancel_task(reminder["id"]) +        self.scheduler.cancel(reminder["id"])          log.trace(f"Scheduling new task #{reminder['id']}") -        self.schedule_task(reminder["id"], reminder) +        self.schedule_reminder(reminder)      async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:          """Send the reminder.""" @@ -223,7 +224,7 @@ class Reminders(Scheduler, Cog):              delivery_dt=expiration,          ) -        self.schedule_task(reminder["id"], reminder) +        self.schedule_reminder(reminder)      @remind_group.command(name="list")      async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]: diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index 8b778a093..03f31d78f 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -1,81 +1,126 @@  import asyncio  import contextlib +import inspect  import logging  import typing as t -from abc import abstractmethod +from datetime import datetime  from functools import partial -from bot.utils import CogABCMeta -log = logging.getLogger(__name__) +class Scheduler: +    """ +    Schedule the execution of coroutines and keep track of them. +    When instantiating a Scheduler, a name must be provided. This name is used to distinguish the +    instance's log messages from other instances. Using the name of the class or module containing +    the instance is suggested. -class Scheduler(metaclass=CogABCMeta): -    """Task scheduler.""" +    Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` +    or `schedule_later`. A unique ID is required to be given in order to keep track of the +    resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing +    the same ID used to schedule it.  The `in` operator is supported for checking if a task with a +    given ID is currently scheduled. -    def __init__(self): -        # Keep track of the child cog's name so the logs are clear. -        self.cog_name = self.__class__.__name__ +    Any exception raised in a scheduled task is logged when the task is done. +    """ -        self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} +    def __init__(self, name: str): +        self.name = name -    @abstractmethod -    async def _scheduled_task(self, task_object: t.Any) -> None: -        """ -        A coroutine which handles the scheduling. +        self._log = logging.getLogger(f"{__name__}.{name}") +        self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} -        This is added to the scheduled tasks, and should wait the task duration, execute the desired -        code, then clean up the task. +    def __contains__(self, task_id: t.Hashable) -> bool: +        """Return True if a task with the given `task_id` is currently scheduled.""" +        return task_id in self._scheduled_tasks -        For example, in Reminders this will wait for the reminder duration, send the reminder, -        then make a site API request to delete the reminder from the database. +    def schedule(self, task_id: t.Hashable, coroutine: t.Coroutine) -> None:          """ +        Schedule the execution of a `coroutine`. -    def schedule_task(self, task_id: t.Hashable, task_data: t.Any) -> None: +        If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This +        prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.          """ -        Schedules a task. +        self._log.trace(f"Scheduling task #{task_id}...") -        `task_data` is passed to the `Scheduler._scheduled_task()` coroutine. -        """ -        log.trace(f"{self.cog_name}: scheduling task #{task_id}...") +        msg = f"Cannot schedule an already started coroutine for #{task_id}" +        assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg          if task_id in self._scheduled_tasks: -            log.debug( -                f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled." -            ) +            self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") +            coroutine.close()              return -        task = asyncio.create_task(self._scheduled_task(task_data)) +        task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}")          task.add_done_callback(partial(self._task_done_callback, task_id))          self._scheduled_tasks[task_id] = task -        log.debug(f"{self.cog_name}: scheduled task #{task_id} {id(task)}.") +        self._log.debug(f"Scheduled task #{task_id} {id(task)}.") + +    def schedule_at(self, time: datetime, task_id: t.Hashable, coroutine: t.Coroutine) -> None: +        """ +        Schedule `coroutine` to be executed at the given naïve UTC `time`. + +        If `time` is in the past, schedule `coroutine` immediately. + +        If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This +        prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. +        """ +        delay = (time - datetime.utcnow()).total_seconds() +        if delay > 0: +            coroutine = self._await_later(delay, task_id, coroutine) + +        self.schedule(task_id, coroutine) -    def cancel_task(self, task_id: t.Hashable, ignore_missing: bool = False) -> None: +    def schedule_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None:          """ -        Unschedule the task identified by `task_id`. +        Schedule `coroutine` to be executed after the given `delay` number of seconds. -        If `ignore_missing` is True, a warning will not be sent if a task isn't found. +        If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This +        prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.          """ -        log.trace(f"{self.cog_name}: cancelling task #{task_id}...") -        task = self._scheduled_tasks.get(task_id) +        self.schedule(task_id, self._await_later(delay, task_id, coroutine)) -        if not task: -            if not ignore_missing: -                log.warning(f"{self.cog_name}: failed to unschedule {task_id} (no task found).") -            return +    def cancel(self, task_id: t.Hashable) -> None: +        """Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.""" +        self._log.trace(f"Cancelling task #{task_id}...") -        del self._scheduled_tasks[task_id] -        task.cancel() +        try: +            task = self._scheduled_tasks.pop(task_id) +        except KeyError: +            self._log.warning(f"Failed to unschedule {task_id} (no task found).") +        else: +            task.cancel() -        log.debug(f"{self.cog_name}: unscheduled task #{task_id} {id(task)}.") +            self._log.debug(f"Unscheduled task #{task_id} {id(task)}.")      def cancel_all(self) -> None:          """Unschedule all known tasks.""" -        log.debug(f"{self.cog_name}: unscheduling all tasks") +        self._log.debug("Unscheduling all tasks")          for task_id in self._scheduled_tasks.copy(): -            self.cancel_task(task_id, ignore_missing=True) +            self.cancel(task_id) + +    async def _await_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: +        """Await `coroutine` after the given `delay` number of seconds.""" +        try: +            self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") +            await asyncio.sleep(delay) + +            # Use asyncio.shield to prevent the coroutine from cancelling itself. +            self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") +            await asyncio.shield(coroutine) +        finally: +            # Close it to prevent unawaited coroutine warnings, +            # which would happen if the task was cancelled during the sleep. +            # Only close it if it's not been awaited yet. This check is important because the +            # coroutine may cancel this task, which would also trigger the finally block. +            state = inspect.getcoroutinestate(coroutine) +            if state == "CORO_CREATED": +                self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") +                coroutine.close() +            else: +                self._log.debug(f"Finally block reached for #{task_id}; {state=}")      def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None:          """ @@ -84,24 +129,24 @@ class Scheduler(metaclass=CogABCMeta):          If `done_task` and the task associated with `task_id` are different, then the latter          will not be deleted. In this case, a new task was likely rescheduled with the same ID.          """ -        log.trace(f"{self.cog_name}: performing done callback for task #{task_id} {id(done_task)}.") +        self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.")          scheduled_task = self._scheduled_tasks.get(task_id)          if scheduled_task and done_task is scheduled_task: -            # A task for the ID exists and its the same as the done task. +            # A task for the ID exists and is the same as the done task.              # Since this is the done callback, the task is already done so no need to cancel it. -            log.trace(f"{self.cog_name}: deleting task #{task_id} {id(done_task)}.") +            self._log.trace(f"Deleting task #{task_id} {id(done_task)}.")              del self._scheduled_tasks[task_id]          elif scheduled_task:              # A new task was likely rescheduled with the same ID. -            log.debug( -                f"{self.cog_name}: the scheduled task #{task_id} {id(scheduled_task)} " +            self._log.debug( +                f"The scheduled task #{task_id} {id(scheduled_task)} "                  f"and the done task {id(done_task)} differ."              )          elif not done_task.cancelled(): -            log.warning( -                f"{self.cog_name}: task #{task_id} not found while handling task {id(done_task)}! " +            self._log.warning( +                f"Task #{task_id} not found while handling task {id(done_task)}! "                  f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)."              ) @@ -109,7 +154,4 @@ class Scheduler(metaclass=CogABCMeta):              exception = done_task.exception()              # Log the exception if one exists.              if exception: -                log.error( -                    f"{self.cog_name}: error in task #{task_id} {id(done_task)}!", -                    exc_info=exception -                ) +                self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)  |