diff options
Diffstat (limited to 'bot/cogs/moderation/scheduler.py')
| -rw-r--r-- | bot/cogs/moderation/scheduler.py | 30 |
1 files changed, 12 insertions, 18 deletions
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index b03d89537..601e238c9 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -1,4 +1,3 @@ -import asyncio import logging import textwrap import typing as t @@ -23,13 +22,13 @@ from .utils import UserSnowflake log = logging.getLogger(__name__) -class InfractionScheduler(Scheduler): +class InfractionScheduler: """Handles the application, pardoning, and expiration of infractions.""" def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - super().__init__() - self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) @property @@ -49,7 +48,7 @@ class InfractionScheduler(Scheduler): ) for infraction in infractions: if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_task(infraction["id"], infraction) + self.schedule_expiration(infraction) async def reapply_infraction( self, @@ -127,18 +126,17 @@ class InfractionScheduler(Scheduler): dm_result = ":incoming_envelope: " dm_log_text = "\nDM: Sent" + end_msg = "" if infraction["actor"] == self.bot.user.id: log.trace( f"Infraction #{id_} actor is bot; including the reason in the confirmation message." ) - - end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" + if reason: + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" elif ctx.channel.id not in STAFF_CHANNELS: log.trace( f"Infraction #{id_} context is not in a staff channel; omitting infraction count." ) - - end_msg = "" else: log.trace(f"Fetching total infraction count for {user}.") @@ -156,7 +154,7 @@ class InfractionScheduler(Scheduler): await action_coro if expiry: # Schedule the expiration of the infraction. - self.schedule_task(infraction["id"], infraction) + self.schedule_expiration(infraction) except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = ":x: failed to apply" @@ -279,7 +277,7 @@ class InfractionScheduler(Scheduler): # Cancel pending expiration task. if infraction["expires_at"] is not None: - self.cancel_task(infraction["id"]) + self.scheduler.cancel(infraction["id"]) # Accordingly display whether the user was successfully notified via DM. dm_emoji = "" @@ -416,7 +414,7 @@ class InfractionScheduler(Scheduler): # Cancel the expiration task. if infraction["expires_at"] is not None: - self.cancel_task(infraction["id"]) + self.scheduler.cancel(infraction["id"]) # Send a log message to the mod log. if send_log: @@ -450,7 +448,7 @@ class InfractionScheduler(Scheduler): """ raise NotImplementedError - async def _scheduled_task(self, infraction: utils.Infraction) -> None: + def schedule_expiration(self, infraction: utils.Infraction) -> None: """ Marks an infraction expired after the delay from time of scheduling to time of expiration. @@ -458,8 +456,4 @@ class InfractionScheduler(Scheduler): expiration task is cancelled. """ expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - await time.wait_until(expiry) - - # Because deactivate_infraction() explicitly cancels this scheduled task, it is shielded - # to avoid prematurely cancelling itself. - await asyncio.shield(self.deactivate_infraction(infraction)) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) |