diff options
-rw-r--r-- | bot/cogs/moderation/management.py | 4 | ||||
-rw-r--r-- | bot/cogs/moderation/scheduler.py | 9 | ||||
-rw-r--r-- | bot/cogs/moderation/superstarify.py | 2 | ||||
-rw-r--r-- | bot/cogs/reminders.py | 16 | ||||
-rw-r--r-- | bot/utils/scheduling.py | 83 |
5 files changed, 70 insertions, 44 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index f74089056..35448f682 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -1,4 +1,3 @@ -import asyncio import logging import textwrap import typing as t @@ -135,8 +134,7 @@ class ModManagement(commands.Cog): # If the infraction was not marked as permanent, schedule a new expiration task if request_data['expires_at']: - loop = asyncio.get_event_loop() - self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction) + self.infractions_cog.schedule_task(new_infraction['id'], new_infraction) log_text += f""" Previous expiry: {old_infraction['expires_at'] or "Permanent"} diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 45cf5ec8a..f0b6b2c48 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -1,3 +1,4 @@ +import asyncio import logging import textwrap import typing as t @@ -48,7 +49,7 @@ class InfractionScheduler(Scheduler): ) for infraction in infractions: if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_task(self.bot.loop, infraction["id"], infraction) + self.schedule_task(infraction["id"], infraction) async def reapply_infraction( self, @@ -150,7 +151,7 @@ class InfractionScheduler(Scheduler): await action_coro if expiry: # Schedule the expiration of the infraction. - self.schedule_task(ctx.bot.loop, infraction["id"], infraction) + self.schedule_task(infraction["id"], infraction) except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = f":x: failed to apply" @@ -427,4 +428,6 @@ class InfractionScheduler(Scheduler): expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) await time.wait_until(expiry) - await self.deactivate_infraction(infraction) + # Because deactivate_infraction() explicitly cancels this scheduled task, it is shielded + # to avoid prematurely cancelling itself. + await asyncio.shield(self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index c41874a95..893cb7f13 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -146,7 +146,7 @@ class Superstarify(InfractionScheduler, Cog): log.debug(f"Changing nickname of {member} to {forced_nick}.") self.mod_log.ignore(constants.Event.member_update, member.id) await member.edit(nick=forced_nick, reason=reason) - self.schedule_task(ctx.bot.loop, id_, infraction) + self.schedule_task(id_, infraction) # Send a DM to the user to notify them of their new infraction. await utils.notify_infraction( diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 041791056..24c279357 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -43,7 +43,6 @@ class Reminders(Scheduler, Cog): ) now = datetime.utcnow() - loop = asyncio.get_event_loop() for reminder in response: is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) @@ -57,7 +56,7 @@ class Reminders(Scheduler, Cog): late = relativedelta(now, remind_at) await self.send_reminder(reminder, late) else: - self.schedule_task(loop, reminder["id"], reminder) + self.schedule_task(reminder["id"], reminder) def ensure_valid_reminder( self, @@ -112,9 +111,6 @@ class Reminders(Scheduler, Cog): log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") await self._delete_reminder(reminder_id) - # Now we can begone with it from our schedule list. - self.cancel_task(reminder_id) - async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: """Delete a reminder from the database, given its ID, and cancel the running task.""" await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) @@ -125,10 +121,11 @@ class Reminders(Scheduler, Cog): async def _reschedule_reminder(self, reminder: dict) -> None: """Reschedule a reminder object.""" - loop = asyncio.get_event_loop() - + log.trace(f"Cancelling old task #{reminder['id']}") self.cancel_task(reminder["id"]) - self.schedule_task(loop, reminder["id"], reminder) + + log.trace(f"Scheduling new task #{reminder['id']}") + self.schedule_task(reminder["id"], reminder) async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: """Send the reminder.""" @@ -226,8 +223,7 @@ class Reminders(Scheduler, Cog): delivery_dt=expiration, ) - loop = asyncio.get_event_loop() - self.schedule_task(loop, reminder["id"], reminder) + self.schedule_task(reminder["id"], reminder) @remind_group.command(name="list") async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]: diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index ee6c0a8e6..5760ec2d4 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -1,8 +1,9 @@ import asyncio import contextlib import logging +import typing as t from abc import abstractmethod -from typing import Coroutine, Dict, Union +from functools import partial from bot.utils import CogABCMeta @@ -13,12 +14,13 @@ class Scheduler(metaclass=CogABCMeta): """Task scheduler.""" def __init__(self): + # Keep track of the child cog's name so the logs are clear. + self.cog_name = self.__class__.__name__ - self.cog_name = self.__class__.__name__ # keep track of the child cog's name so the logs are clear. - self.scheduled_tasks: Dict[str, asyncio.Task] = {} + self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} @abstractmethod - async def _scheduled_task(self, task_object: dict) -> None: + async def _scheduled_task(self, task_object: t.Any) -> None: """ A coroutine which handles the scheduling. @@ -29,46 +31,73 @@ class Scheduler(metaclass=CogABCMeta): then make a site API request to delete the reminder from the database. """ - def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict) -> None: + def schedule_task(self, task_id: t.Hashable, task_data: t.Any) -> None: """ Schedules a task. - `task_data` is passed to `Scheduler._scheduled_expiration` + `task_data` is passed to the `Scheduler._scheduled_task()` coroutine. """ - if task_id in self.scheduled_tasks: + log.trace(f"{self.cog_name}: scheduling task #{task_id}...") + + if task_id in self._scheduled_tasks: log.debug( f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled." ) return - task: asyncio.Task = create_task(loop, self._scheduled_task(task_data)) + task = asyncio.create_task(self._scheduled_task(task_data)) + task.add_done_callback(partial(self._task_done_callback, task_id)) - self.scheduled_tasks[task_id] = task - log.debug(f"{self.cog_name}: scheduled task #{task_id}.") + self._scheduled_tasks[task_id] = task + log.debug(f"{self.cog_name}: scheduled task #{task_id} {id(task)}.") - def cancel_task(self, task_id: str) -> None: - """Un-schedules a task.""" - task = self.scheduled_tasks.get(task_id) + def cancel_task(self, task_id: t.Hashable) -> None: + """Unschedule the task identified by `task_id`.""" + log.trace(f"{self.cog_name}: cancelling task #{task_id}...") + task = self._scheduled_tasks.get(task_id) - if task is None: - log.warning(f"{self.cog_name}: Failed to unschedule {task_id} (no task found).") + if not task: + log.warning(f"{self.cog_name}: failed to unschedule {task_id} (no task found).") return task.cancel() - log.debug(f"{self.cog_name}: unscheduled task #{task_id}.") - del self.scheduled_tasks[task_id] + del self._scheduled_tasks[task_id] + + log.debug(f"{self.cog_name}: unscheduled task #{task_id} {id(task)}.") + def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None: + """ + Delete the task and raise its exception if one exists. -def create_task(loop: asyncio.AbstractEventLoop, coro_or_future: Union[Coroutine, asyncio.Future]) -> asyncio.Task: - """Creates an asyncio.Task object from a coroutine or future object.""" - task: asyncio.Task = asyncio.ensure_future(coro_or_future, loop=loop) + If `done_task` and the task associated with `task_id` are different, then the latter + will not be deleted. In this case, a new task was likely rescheduled with the same ID. + """ + log.trace(f"{self.cog_name}: performing done callback for task #{task_id} {id(done_task)}.") - # Silently ignore exceptions in a callback (handles the CancelledError nonsense) - task.add_done_callback(_silent_exception) - return task + scheduled_task = self._scheduled_tasks.get(task_id) + if scheduled_task and done_task is scheduled_task: + # A task for the ID exists and its the same as the done task. + # Since this is the done callback, the task is already done so no need to cancel it. + log.trace(f"{self.cog_name}: deleting task #{task_id} {id(done_task)}.") + del self._scheduled_tasks[task_id] + elif scheduled_task: + # A new task was likely rescheduled with the same ID. + log.debug( + f"{self.cog_name}: the scheduled task #{task_id} {id(scheduled_task)} " + f"and the done task {id(done_task)} differ." + ) + elif not done_task.cancelled(): + log.warning( + f"{self.cog_name}: task #{task_id} not found while handling task {id(done_task)}! " + f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." + ) -def _silent_exception(future: asyncio.Future) -> None: - """Suppress future's exception.""" - with contextlib.suppress(Exception): - future.exception() + with contextlib.suppress(asyncio.CancelledError): + exception = done_task.exception() + # Log the exception if one exists. + if exception: + log.error( + f"{self.cog_name}: error in task #{task_id} {id(scheduled_task)}!", + exc_info=exception + ) |