diff options
| -rw-r--r-- | bot/cogs/moderation/management.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/scheduler.py | 8 | ||||
| -rw-r--r-- | bot/cogs/moderation/superstarify.py | 2 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 16 | ||||
| -rw-r--r-- | bot/utils/scheduling.py | 64 | 
5 files changed, 52 insertions, 42 deletions
| diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index f2964cd78..279c8b809 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -1,4 +1,3 @@ -import asyncio  import logging  import textwrap  import typing as t @@ -133,8 +132,7 @@ class ModManagement(commands.Cog):              # If the infraction was not marked as permanent, schedule a new expiration task              if request_data['expires_at']: -                loop = asyncio.get_event_loop() -                self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction) +                self.infractions_cog.schedule_task(new_infraction['id'], new_infraction)              log_text += f"""                  Previous expiry: {old_infraction['expires_at'] or "Permanent"} diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 3c5185468..162159af8 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -48,7 +48,7 @@ class InfractionScheduler(Scheduler):          )          for infraction in infractions:              if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: -                self.schedule_task(self.bot.loop, infraction["id"], infraction) +                self.schedule_task(infraction["id"], infraction)      async def reapply_infraction(          self, @@ -150,7 +150,7 @@ class InfractionScheduler(Scheduler):                  await action_coro                  if expiry:                      # Schedule the expiration of the infraction. -                    self.schedule_task(ctx.bot.loop, infraction["id"], infraction) +                    self.schedule_task(infraction["id"], infraction)              except discord.HTTPException as e:                  # Accordingly display that applying the infraction failed.                  confirm_msg = f":x: failed to apply" @@ -427,4 +427,6 @@ class InfractionScheduler(Scheduler):          expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None)          await time.wait_until(expiry) -        await self.deactivate_infraction(infraction) +        # Because deactivate_infraction() explicitly cancels this scheduled task, it runs in +        # a separate task to avoid prematurely cancelling itself. +        self.bot.loop.create_task(self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index c41874a95..893cb7f13 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -146,7 +146,7 @@ class Superstarify(InfractionScheduler, Cog):          log.debug(f"Changing nickname of {member} to {forced_nick}.")          self.mod_log.ignore(constants.Event.member_update, member.id)          await member.edit(nick=forced_nick, reason=reason) -        self.schedule_task(ctx.bot.loop, id_, infraction) +        self.schedule_task(id_, infraction)          # Send a DM to the user to notify them of their new infraction.          await utils.notify_infraction( diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 041791056..24c279357 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -43,7 +43,6 @@ class Reminders(Scheduler, Cog):          )          now = datetime.utcnow() -        loop = asyncio.get_event_loop()          for reminder in response:              is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) @@ -57,7 +56,7 @@ class Reminders(Scheduler, Cog):                  late = relativedelta(now, remind_at)                  await self.send_reminder(reminder, late)              else: -                self.schedule_task(loop, reminder["id"], reminder) +                self.schedule_task(reminder["id"], reminder)      def ensure_valid_reminder(          self, @@ -112,9 +111,6 @@ class Reminders(Scheduler, Cog):          log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).")          await self._delete_reminder(reminder_id) -        # Now we can begone with it from our schedule list. -        self.cancel_task(reminder_id) -      async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None:          """Delete a reminder from the database, given its ID, and cancel the running task."""          await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) @@ -125,10 +121,11 @@ class Reminders(Scheduler, Cog):      async def _reschedule_reminder(self, reminder: dict) -> None:          """Reschedule a reminder object.""" -        loop = asyncio.get_event_loop() - +        log.trace(f"Cancelling old task #{reminder['id']}")          self.cancel_task(reminder["id"]) -        self.schedule_task(loop, reminder["id"], reminder) + +        log.trace(f"Scheduling new task #{reminder['id']}") +        self.schedule_task(reminder["id"], reminder)      async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:          """Send the reminder.""" @@ -226,8 +223,7 @@ class Reminders(Scheduler, Cog):              delivery_dt=expiration,          ) -        loop = asyncio.get_event_loop() -        self.schedule_task(loop, reminder["id"], reminder) +        self.schedule_task(reminder["id"], reminder)      @remind_group.command(name="list")      async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]: diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index ee6c0a8e6..0d66952eb 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -2,7 +2,8 @@ import asyncio  import contextlib  import logging  from abc import abstractmethod -from typing import Coroutine, Dict, Union +from functools import partial +from typing import Dict  from bot.utils import CogABCMeta @@ -13,9 +14,10 @@ class Scheduler(metaclass=CogABCMeta):      """Task scheduler."""      def __init__(self): +        # Keep track of the child cog's name so the logs are clear. +        self.cog_name = self.__class__.__name__ -        self.cog_name = self.__class__.__name__  # keep track of the child cog's name so the logs are clear. -        self.scheduled_tasks: Dict[str, asyncio.Task] = {} +        self._scheduled_tasks: Dict[str, asyncio.Task] = {}      @abstractmethod      async def _scheduled_task(self, task_object: dict) -> None: @@ -29,46 +31,58 @@ class Scheduler(metaclass=CogABCMeta):          then make a site API request to delete the reminder from the database.          """ -    def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict) -> None: +    def schedule_task(self, task_id: str, task_data: dict) -> None:          """          Schedules a task. -        `task_data` is passed to `Scheduler._scheduled_expiration` +        `task_data` is passed to the `Scheduler._scheduled_task()` coroutine.          """ -        if task_id in self.scheduled_tasks: +        log.trace(f"{self.cog_name}: scheduling task #{task_id}...") + +        if task_id in self._scheduled_tasks:              log.debug(                  f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled."              )              return -        task: asyncio.Task = create_task(loop, self._scheduled_task(task_data)) +        task = asyncio.create_task(self._scheduled_task(task_data)) +        task.add_done_callback(partial(self._task_done_callback, task_id)) -        self.scheduled_tasks[task_id] = task -        log.debug(f"{self.cog_name}: scheduled task #{task_id}.") +        self._scheduled_tasks[task_id] = task +        log.debug(f"{self.cog_name}: scheduled task #{task_id} {id(task)}.")      def cancel_task(self, task_id: str) -> None:          """Un-schedules a task.""" -        task = self.scheduled_tasks.get(task_id) +        log.trace(f"{self.cog_name}: cancelling task #{task_id}...") + +        task = self._scheduled_tasks.get(task_id)          if task is None:              log.warning(f"{self.cog_name}: Failed to unschedule {task_id} (no task found).")              return          task.cancel() -        log.debug(f"{self.cog_name}: unscheduled task #{task_id}.") -        del self.scheduled_tasks[task_id] - - -def create_task(loop: asyncio.AbstractEventLoop, coro_or_future: Union[Coroutine, asyncio.Future]) -> asyncio.Task: -    """Creates an asyncio.Task object from a coroutine or future object.""" -    task: asyncio.Task = asyncio.ensure_future(coro_or_future, loop=loop) - -    # Silently ignore exceptions in a callback (handles the CancelledError nonsense) -    task.add_done_callback(_silent_exception) -    return task +        log.debug(f"{self.cog_name}: unscheduled task #{task_id} {id(task)}.") +        del self._scheduled_tasks[task_id] +    def _task_done_callback(self, task_id: str, task: asyncio.Task) -> None: +        """ +        Unschedule the task and raise its exception if one exists. -def _silent_exception(future: asyncio.Future) -> None: -    """Suppress future's exception.""" -    with contextlib.suppress(Exception): -        future.exception() +        If the task was cancelled, the CancelledError is retrieved and suppressed. In this case, +        the task is already assumed to have been unscheduled. +        """ +        log.trace(f"{self.cog_name}: performing done callback for task #{task_id} {id(task)}") + +        if task.cancelled(): +            with contextlib.suppress(asyncio.CancelledError): +                task.exception() +        else: +            # Check if it exists to avoid logging a warning. +            if task_id in self._scheduled_tasks: +                # Only cancel if the task is not cancelled to avoid a race condition when a new +                # task is scheduled using the same ID. Reminders do this when re-scheduling after +                # editing. +                self.cancel_task(task_id) + +            task.exception() | 
