diff options
| author | 2020-07-13 21:11:09 +0200 | |
|---|---|---|
| committer | 2020-07-13 21:11:09 +0200 | |
| commit | 02f77991647ad896120d31aa16a75f90ad098449 (patch) | |
| tree | 1595d42a05805065eb268bdf52f27547307d7f02 | |
| parent | Store last DM user in RedisCache. (diff) | |
| parent | Merge pull request #1039 from python-discord/1038_allow_role_mentions_in_spec... (diff) | |
Merge branch 'master' into dm_relay
| -rw-r--r-- | bot/cogs/filtering.py | 16 | ||||
| -rw-r--r-- | bot/cogs/help_channels.py | 74 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/modlog.py | 7 | ||||
| -rw-r--r-- | bot/cogs/moderation/scheduler.py | 23 | ||||
| -rw-r--r-- | bot/cogs/moderation/silence.py | 32 | ||||
| -rw-r--r-- | bot/cogs/moderation/superstarify.py | 2 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 29 | ||||
| -rw-r--r-- | bot/cogs/sync/syncers.py | 7 | ||||
| -rw-r--r-- | bot/utils/scheduling.py | 146 | ||||
| -rw-r--r-- | config-default.yml | 2 |
11 files changed, 169 insertions, 173 deletions
diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 76ea68660..099606b82 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -19,7 +19,6 @@ from bot.constants import ( ) from bot.utils.redis_cache import RedisCache from bot.utils.scheduling import Scheduler -from bot.utils.time import wait_until log = logging.getLogger(__name__) @@ -60,7 +59,7 @@ def expand_spoilers(text: str) -> str: OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) -class Filtering(Cog, Scheduler): +class Filtering(Cog): """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent @@ -68,8 +67,7 @@ class Filtering(Cog, Scheduler): def __init__(self, bot: Bot): self.bot = bot - super().__init__() - + self.scheduler = Scheduler(self.__class__.__name__) self.name_lock = asyncio.Lock() staff_mistake_str = "If you believe this was a mistake, please let staff know!" @@ -268,7 +266,7 @@ class Filtering(Cog, Scheduler): } await self.bot.api_client.post('bot/offensive-messages', json=data) - self.schedule_task(msg.id, data) + self.schedule_msg_delete(data) log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") if is_private: @@ -457,12 +455,10 @@ class Filtering(Cog, Scheduler): except discord.errors.Forbidden: await channel.send(f"{filtered_member.mention} {reason}") - async def _scheduled_task(self, msg: dict) -> None: + def schedule_msg_delete(self, msg: dict) -> None: """Delete an offensive message once its deletion date is reached.""" delete_at = dateutil.parser.isoparse(msg['delete_date']).replace(tzinfo=None) - - await wait_until(delete_at) - await self.delete_offensive_msg(msg) + self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) async def reschedule_offensive_msg_deletion(self) -> None: """Get all the pending message deletion from the API and reschedule them.""" @@ -477,7 +473,7 @@ class Filtering(Cog, Scheduler): if delete_at < now: await self.delete_offensive_msg(msg) else: - self.schedule_task(msg['id'], msg) + self.schedule_msg_delete(msg) async def delete_offensive_msg(self, msg: Mapping[str, str]) -> None: """Delete an offensive message, and then delete it from the db.""" diff --git a/bot/cogs/help_channels.py b/bot/cogs/help_channels.py index 187adfe51..4d0c534b0 100644 --- a/bot/cogs/help_channels.py +++ b/bot/cogs/help_channels.py @@ -1,5 +1,4 @@ import asyncio -import inspect import json import logging import random @@ -57,14 +56,7 @@ through our guide for [asking a good question]({ASKING_GUIDE_URL}). CoroutineFunc = t.Callable[..., t.Coroutine] -class TaskData(t.NamedTuple): - """Data for a scheduled task.""" - - wait_time: int - callback: t.Awaitable - - -class HelpChannels(Scheduler, commands.Cog): +class HelpChannels(commands.Cog): """ Manage the help channel system of the guild. @@ -114,9 +106,8 @@ class HelpChannels(Scheduler, commands.Cog): claim_times = RedisCache() def __init__(self, bot: Bot): - super().__init__() - self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) # Categories self.available_category: discord.CategoryChannel = None @@ -145,7 +136,7 @@ class HelpChannels(Scheduler, commands.Cog): for task in self.queue_tasks: task.cancel() - self.cancel_all() + self.scheduler.cancel_all() def create_channel_queue(self) -> asyncio.Queue: """ @@ -229,10 +220,11 @@ class HelpChannels(Scheduler, commands.Cog): await self.remove_cooldown_role(ctx.author) # Ignore missing task when cooldown has passed but the channel still isn't dormant. - self.cancel_task(ctx.author.id, ignore_missing=True) + if ctx.author.id in self.scheduler: + self.scheduler.cancel(ctx.author.id) await self.move_to_dormant(ctx.channel, "command") - self.cancel_task(ctx.channel.id) + self.scheduler.cancel(ctx.channel.id) else: log.debug(f"{ctx.author} invoked command 'dormant' outside an in-use help channel") @@ -474,16 +466,15 @@ class HelpChannels(Scheduler, commands.Cog): else: # Cancel the existing task, if any. if has_task: - self.cancel_task(channel.id) - - data = TaskData(idle_seconds - time_elapsed, self.move_idle_channel(channel)) + self.scheduler.cancel(channel.id) + delay = idle_seconds - time_elapsed log.info( f"#{channel} ({channel.id}) is still active; " - f"scheduling it to be moved after {data.wait_time} seconds." + f"scheduling it to be moved after {delay} seconds." ) - self.schedule_task(channel.id, data) + self.scheduler.schedule_later(delay, channel.id, self.move_idle_channel(channel)) async def move_to_bottom_position(self, channel: discord.TextChannel, category_id: int, **options) -> None: """ @@ -588,8 +579,7 @@ class HelpChannels(Scheduler, commands.Cog): timeout = constants.HelpChannels.idle_minutes * 60 log.trace(f"Scheduling #{channel} ({channel.id}) to become dormant in {timeout} sec.") - data = TaskData(timeout, self.move_idle_channel(channel)) - self.schedule_task(channel.id, data) + self.scheduler.schedule_later(timeout, channel.id, self.move_idle_channel(channel)) self.report_stats() async def notify(self) -> None: @@ -624,11 +614,13 @@ class HelpChannels(Scheduler, commands.Cog): channel = self.bot.get_channel(constants.HelpChannels.notify_channel) mentions = " ".join(f"<@&{role}>" for role in constants.HelpChannels.notify_roles) + allowed_roles = [discord.Object(id_) for id_ in constants.HelpChannels.notify_roles] message = await channel.send( f"{mentions} A new available help channel is needed but there " f"are no more dormant ones. Consider freeing up some in-use channels manually by " - f"using the `{constants.Bot.prefix}dormant` command within the channels." + f"using the `{constants.Bot.prefix}dormant` command within the channels.", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) ) self.bot.stats.incr("help.out_of_channel_alerts") @@ -722,10 +714,10 @@ class HelpChannels(Scheduler, commands.Cog): log.info(f"Claimant of #{msg.channel} ({msg.author}) deleted message, channel is empty now. Rescheduling task.") # Cancel existing dormant task before scheduling new. - self.cancel_task(msg.channel.id) + self.scheduler.cancel(msg.channel.id) - task = TaskData(constants.HelpChannels.deleted_idle_minutes * 60, self.move_idle_channel(msg.channel)) - self.schedule_task(msg.channel.id, task) + delay = constants.HelpChannels.deleted_idle_minutes * 60 + self.scheduler.schedule_later(delay, msg.channel.id, self.move_idle_channel(msg.channel)) async def is_empty(self, channel: discord.TextChannel) -> bool: """Return True if the most recent message in `channel` is the bot's `AVAILABLE_MSG`.""" @@ -752,8 +744,8 @@ class HelpChannels(Scheduler, commands.Cog): await self.remove_cooldown_role(member) else: # The member is still on a cooldown; re-schedule it for the remaining time. - remaining = cooldown - in_use_time.seconds - await self.schedule_cooldown_expiration(member, remaining) + delay = cooldown - in_use_time.seconds + self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) async def add_cooldown_role(self, member: discord.Member) -> None: """Add the help cooldown role to `member`.""" @@ -804,16 +796,11 @@ class HelpChannels(Scheduler, commands.Cog): # Cancel the existing task, if any. # Would mean the user somehow bypassed the lack of permissions (e.g. user is guild owner). - self.cancel_task(member.id, ignore_missing=True) + if member.id in self.scheduler: + self.scheduler.cancel(member.id) - await self.schedule_cooldown_expiration(member, constants.HelpChannels.claim_minutes * 60) - - async def schedule_cooldown_expiration(self, member: discord.Member, seconds: int) -> None: - """Schedule the cooldown role for `member` to be removed after a duration of `seconds`.""" - log.trace(f"Scheduling removal of {member}'s ({member.id}) cooldown.") - - callback = self.remove_cooldown_role(member) - self.schedule_task(member.id, TaskData(seconds, callback)) + delay = constants.HelpChannels.claim_minutes * 60 + self.scheduler.schedule_later(delay, member.id, self.remove_cooldown_role(member)) async def send_available_message(self, channel: discord.TextChannel) -> None: """Send the available message by editing a dormant message or sending a new message.""" @@ -855,21 +842,6 @@ class HelpChannels(Scheduler, commands.Cog): return channel - async def _scheduled_task(self, data: TaskData) -> None: - """Await the `data.callback` coroutine after waiting for `data.wait_time` seconds.""" - try: - log.trace(f"Waiting {data.wait_time} seconds before awaiting callback.") - await asyncio.sleep(data.wait_time) - - # Use asyncio.shield to prevent callback from cancelling itself. - # The parent task (_scheduled_task) will still get cancelled. - log.trace("Done waiting; now awaiting the callback.") - await asyncio.shield(data.callback) - finally: - if inspect.iscoroutine(data.callback): - log.trace("Explicitly closing coroutine.") - data.callback.close() - def validate_config() -> None: """Raise a ValueError if the cog's config is invalid.""" diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 617d957ed..672bb0e9c 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -135,11 +135,11 @@ class ModManagement(commands.Cog): if 'expires_at' in request_data: # A scheduled task should only exist if the old infraction wasn't permanent if old_infraction['expires_at']: - self.infractions_cog.cancel_task(new_infraction['id']) + self.infractions_cog.scheduler.cancel(new_infraction['id']) # If the infraction was not marked as permanent, schedule a new expiration task if request_data['expires_at']: - self.infractions_cog.schedule_task(new_infraction['id'], new_infraction) + self.infractions_cog.schedule_expiration(new_infraction) log_text += f""" Previous expiry: {old_infraction['expires_at'] or "Permanent"} diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index ffbb87bbe..0a63f57b8 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -121,7 +121,12 @@ class ModLog(Cog, name="ModLog"): content = "@everyone" channel = self.bot.get_channel(channel_id) - log_message = await channel.send(content=content, embed=embed, files=files) + log_message = await channel.send( + content=content, + embed=embed, + files=files, + allowed_mentions=discord.AllowedMentions(everyone=True) + ) if additional_embeds: if additional_embeds_msg: diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index d75a72ddb..601e238c9 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -1,4 +1,3 @@ -import asyncio import logging import textwrap import typing as t @@ -23,13 +22,13 @@ from .utils import UserSnowflake log = logging.getLogger(__name__) -class InfractionScheduler(Scheduler): +class InfractionScheduler: """Handles the application, pardoning, and expiration of infractions.""" def __init__(self, bot: Bot, supported_infractions: t.Container[str]): - super().__init__() - self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + self.bot.loop.create_task(self.reschedule_infractions(supported_infractions)) @property @@ -49,7 +48,7 @@ class InfractionScheduler(Scheduler): ) for infraction in infractions: if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_task(infraction["id"], infraction) + self.schedule_expiration(infraction) async def reapply_infraction( self, @@ -155,7 +154,7 @@ class InfractionScheduler(Scheduler): await action_coro if expiry: # Schedule the expiration of the infraction. - self.schedule_task(infraction["id"], infraction) + self.schedule_expiration(infraction) except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = ":x: failed to apply" @@ -278,7 +277,7 @@ class InfractionScheduler(Scheduler): # Cancel pending expiration task. if infraction["expires_at"] is not None: - self.cancel_task(infraction["id"]) + self.scheduler.cancel(infraction["id"]) # Accordingly display whether the user was successfully notified via DM. dm_emoji = "" @@ -415,7 +414,7 @@ class InfractionScheduler(Scheduler): # Cancel the expiration task. if infraction["expires_at"] is not None: - self.cancel_task(infraction["id"]) + self.scheduler.cancel(infraction["id"]) # Send a log message to the mod log. if send_log: @@ -449,7 +448,7 @@ class InfractionScheduler(Scheduler): """ raise NotImplementedError - async def _scheduled_task(self, infraction: utils.Infraction) -> None: + def schedule_expiration(self, infraction: utils.Infraction) -> None: """ Marks an infraction expired after the delay from time of scheduling to time of expiration. @@ -457,8 +456,4 @@ class InfractionScheduler(Scheduler): expiration task is cancelled. """ expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) - await time.wait_until(expiry) - - # Because deactivate_infraction() explicitly cancels this scheduled task, it is shielded - # to avoid prematurely cancelling itself. - await asyncio.shield(self.deactivate_infraction(infraction)) + self.scheduler.schedule_at(expiry, infraction["id"], self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index c8ab6443b..ae4fb7b64 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -1,7 +1,7 @@ import asyncio import logging from contextlib import suppress -from typing import NamedTuple, Optional +from typing import Optional from discord import TextChannel from discord.ext import commands, tasks @@ -16,13 +16,6 @@ from bot.utils.scheduling import Scheduler log = logging.getLogger(__name__) -class TaskData(NamedTuple): - """Data for a scheduled task.""" - - delay: int - ctx: Context - - class SilenceNotifier(tasks.Loop): """Loop notifier for posting notices to `alert_channel` containing added channels.""" @@ -61,25 +54,17 @@ class SilenceNotifier(tasks.Loop): await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") -class Silence(Scheduler, commands.Cog): +class Silence(commands.Cog): """Commands for stopping channel messages for `verified` role in a channel.""" def __init__(self, bot: Bot): - super().__init__() self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) self.muted_channels = set() + self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) self._get_instance_vars_event = asyncio.Event() - async def _scheduled_task(self, task: TaskData) -> None: - """Calls `self.unsilence` on expired silenced channel to unsilence it.""" - await asyncio.sleep(task.delay) - log.info("Unsilencing channel after set delay.") - - # Because `self.unsilence` explicitly cancels this scheduled task, it is shielded - # to avoid prematurely cancelling itself - await asyncio.shield(task.ctx.invoke(self.unsilence)) - async def _get_instance_vars(self) -> None: """Get instance variables after they're available to get from the guild.""" await self.bot.wait_until_guild_available() @@ -109,12 +94,7 @@ class Silence(Scheduler, commands.Cog): await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") - task_data = TaskData( - delay=duration*60, - ctx=ctx - ) - - self.schedule_task(ctx.channel.id, task_data) + self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) @commands.command(aliases=("unhush",)) async def unsilence(self, ctx: Context) -> None: @@ -164,7 +144,7 @@ class Silence(Scheduler, commands.Cog): if current_overwrite.send_messages is False: await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) log.info(f"Unsilenced channel #{channel} ({channel.id}).") - self.cancel_task(channel.id) + self.scheduler.cancel(channel.id) self.notifier.remove_channel(channel) self.muted_channels.discard(channel) return True diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 45a010f00..867de815a 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -146,7 +146,7 @@ class Superstarify(InfractionScheduler, Cog): log.debug(f"Changing nickname of {member} to {forced_nick}.") self.mod_log.ignore(constants.Event.member_update, member.id) await member.edit(nick=forced_nick, reason=reason) - self.schedule_task(id_, infraction) + self.schedule_expiration(infraction) # Send a DM to the user to notify them of their new infraction. await utils.notify_infraction( diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index c242d2920..0d20bdb2b 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -17,7 +17,7 @@ from bot.converters import Duration from bot.pagination import LinePaginator from bot.utils.checks import without_role_check from bot.utils.scheduling import Scheduler -from bot.utils.time import humanize_delta, wait_until +from bot.utils.time import humanize_delta log = logging.getLogger(__name__) @@ -25,12 +25,12 @@ WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 -class Reminders(Scheduler, Cog): +class Reminders(Cog): """Provide in-channel reminder functionality.""" def __init__(self, bot: Bot): self.bot = bot - super().__init__() + self.scheduler = Scheduler(self.__class__.__name__) self.bot.loop.create_task(self.reschedule_reminders()) @@ -56,7 +56,7 @@ class Reminders(Scheduler, Cog): late = relativedelta(now, remind_at) await self.send_reminder(reminder, late) else: - self.schedule_task(reminder["id"], reminder) + self.schedule_reminder(reminder) def ensure_valid_reminder( self, @@ -99,17 +99,18 @@ class Reminders(Scheduler, Cog): await ctx.send(embed=embed) - async def _scheduled_task(self, reminder: dict) -> None: + def schedule_reminder(self, reminder: dict) -> None: """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" reminder_id = reminder["id"] reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) - # Send the reminder message once the desired duration has passed - await wait_until(reminder_datetime) - await self.send_reminder(reminder) + async def _remind() -> None: + await self.send_reminder(reminder) - log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder_id) + log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") + await self._delete_reminder(reminder_id) + + self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: """Delete a reminder from the database, given its ID, and cancel the running task.""" @@ -117,15 +118,15 @@ class Reminders(Scheduler, Cog): if cancel_task: # Now we can remove it from the schedule list - self.cancel_task(reminder_id) + self.scheduler.cancel(reminder_id) async def _reschedule_reminder(self, reminder: dict) -> None: """Reschedule a reminder object.""" log.trace(f"Cancelling old task #{reminder['id']}") - self.cancel_task(reminder["id"]) + self.scheduler.cancel(reminder["id"]) log.trace(f"Scheduling new task #{reminder['id']}") - self.schedule_task(reminder["id"], reminder) + self.schedule_reminder(reminder) async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: """Send the reminder.""" @@ -223,7 +224,7 @@ class Reminders(Scheduler, Cog): delivery_dt=expiration, ) - self.schedule_task(reminder["id"], reminder) + self.schedule_reminder(reminder) @remind_group.command(name="list") async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]: diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 536455668..f7ba811bc 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -5,6 +5,7 @@ import typing as t from collections import namedtuple from functools import partial +import discord from discord import Guild, HTTPException, Member, Message, Reaction, User from discord.ext.commands import Context @@ -68,7 +69,11 @@ class Syncer(abc.ABC): ) return None - message = await channel.send(f"{self._CORE_DEV_MENTION}{msg_content}") + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) else: await message.edit(content=msg_content) diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index 8b778a093..03f31d78f 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -1,81 +1,126 @@ import asyncio import contextlib +import inspect import logging import typing as t -from abc import abstractmethod +from datetime import datetime from functools import partial -from bot.utils import CogABCMeta -log = logging.getLogger(__name__) +class Scheduler: + """ + Schedule the execution of coroutines and keep track of them. + When instantiating a Scheduler, a name must be provided. This name is used to distinguish the + instance's log messages from other instances. Using the name of the class or module containing + the instance is suggested. -class Scheduler(metaclass=CogABCMeta): - """Task scheduler.""" + Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` + or `schedule_later`. A unique ID is required to be given in order to keep track of the + resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing + the same ID used to schedule it. The `in` operator is supported for checking if a task with a + given ID is currently scheduled. - def __init__(self): - # Keep track of the child cog's name so the logs are clear. - self.cog_name = self.__class__.__name__ + Any exception raised in a scheduled task is logged when the task is done. + """ - self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} + def __init__(self, name: str): + self.name = name - @abstractmethod - async def _scheduled_task(self, task_object: t.Any) -> None: - """ - A coroutine which handles the scheduling. + self._log = logging.getLogger(f"{__name__}.{name}") + self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} - This is added to the scheduled tasks, and should wait the task duration, execute the desired - code, then clean up the task. + def __contains__(self, task_id: t.Hashable) -> bool: + """Return True if a task with the given `task_id` is currently scheduled.""" + return task_id in self._scheduled_tasks - For example, in Reminders this will wait for the reminder duration, send the reminder, - then make a site API request to delete the reminder from the database. + def schedule(self, task_id: t.Hashable, coroutine: t.Coroutine) -> None: """ + Schedule the execution of a `coroutine`. - def schedule_task(self, task_id: t.Hashable, task_data: t.Any) -> None: + If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. """ - Schedules a task. + self._log.trace(f"Scheduling task #{task_id}...") - `task_data` is passed to the `Scheduler._scheduled_task()` coroutine. - """ - log.trace(f"{self.cog_name}: scheduling task #{task_id}...") + msg = f"Cannot schedule an already started coroutine for #{task_id}" + assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg if task_id in self._scheduled_tasks: - log.debug( - f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled." - ) + self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") + coroutine.close() return - task = asyncio.create_task(self._scheduled_task(task_data)) + task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") task.add_done_callback(partial(self._task_done_callback, task_id)) self._scheduled_tasks[task_id] = task - log.debug(f"{self.cog_name}: scheduled task #{task_id} {id(task)}.") + self._log.debug(f"Scheduled task #{task_id} {id(task)}.") + + def schedule_at(self, time: datetime, task_id: t.Hashable, coroutine: t.Coroutine) -> None: + """ + Schedule `coroutine` to be executed at the given naïve UTC `time`. + + If `time` is in the past, schedule `coroutine` immediately. + + If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + """ + delay = (time - datetime.utcnow()).total_seconds() + if delay > 0: + coroutine = self._await_later(delay, task_id, coroutine) + + self.schedule(task_id, coroutine) - def cancel_task(self, task_id: t.Hashable, ignore_missing: bool = False) -> None: + def schedule_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: """ - Unschedule the task identified by `task_id`. + Schedule `coroutine` to be executed after the given `delay` number of seconds. - If `ignore_missing` is True, a warning will not be sent if a task isn't found. + If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. """ - log.trace(f"{self.cog_name}: cancelling task #{task_id}...") - task = self._scheduled_tasks.get(task_id) + self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - if not task: - if not ignore_missing: - log.warning(f"{self.cog_name}: failed to unschedule {task_id} (no task found).") - return + def cancel(self, task_id: t.Hashable) -> None: + """Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.""" + self._log.trace(f"Cancelling task #{task_id}...") - del self._scheduled_tasks[task_id] - task.cancel() + try: + task = self._scheduled_tasks.pop(task_id) + except KeyError: + self._log.warning(f"Failed to unschedule {task_id} (no task found).") + else: + task.cancel() - log.debug(f"{self.cog_name}: unscheduled task #{task_id} {id(task)}.") + self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") def cancel_all(self) -> None: """Unschedule all known tasks.""" - log.debug(f"{self.cog_name}: unscheduling all tasks") + self._log.debug("Unscheduling all tasks") for task_id in self._scheduled_tasks.copy(): - self.cancel_task(task_id, ignore_missing=True) + self.cancel(task_id) + + async def _await_later(self, delay: t.Union[int, float], task_id: t.Hashable, coroutine: t.Coroutine) -> None: + """Await `coroutine` after the given `delay` number of seconds.""" + try: + self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") + await asyncio.sleep(delay) + + # Use asyncio.shield to prevent the coroutine from cancelling itself. + self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") + await asyncio.shield(coroutine) + finally: + # Close it to prevent unawaited coroutine warnings, + # which would happen if the task was cancelled during the sleep. + # Only close it if it's not been awaited yet. This check is important because the + # coroutine may cancel this task, which would also trigger the finally block. + state = inspect.getcoroutinestate(coroutine) + if state == "CORO_CREATED": + self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") + coroutine.close() + else: + self._log.debug(f"Finally block reached for #{task_id}; {state=}") def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None: """ @@ -84,24 +129,24 @@ class Scheduler(metaclass=CogABCMeta): If `done_task` and the task associated with `task_id` are different, then the latter will not be deleted. In this case, a new task was likely rescheduled with the same ID. """ - log.trace(f"{self.cog_name}: performing done callback for task #{task_id} {id(done_task)}.") + self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") scheduled_task = self._scheduled_tasks.get(task_id) if scheduled_task and done_task is scheduled_task: - # A task for the ID exists and its the same as the done task. + # A task for the ID exists and is the same as the done task. # Since this is the done callback, the task is already done so no need to cancel it. - log.trace(f"{self.cog_name}: deleting task #{task_id} {id(done_task)}.") + self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") del self._scheduled_tasks[task_id] elif scheduled_task: # A new task was likely rescheduled with the same ID. - log.debug( - f"{self.cog_name}: the scheduled task #{task_id} {id(scheduled_task)} " + self._log.debug( + f"The scheduled task #{task_id} {id(scheduled_task)} " f"and the done task {id(done_task)} differ." ) elif not done_task.cancelled(): - log.warning( - f"{self.cog_name}: task #{task_id} not found while handling task {id(done_task)}! " + self._log.warning( + f"Task #{task_id} not found while handling task {id(done_task)}! " f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." ) @@ -109,7 +154,4 @@ class Scheduler(metaclass=CogABCMeta): exception = done_task.exception() # Log the exception if one exists. if exception: - log.error( - f"{self.cog_name}: error in task #{task_id} {id(done_task)}!", - exc_info=exception - ) + self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) diff --git a/config-default.yml b/config-default.yml index c09902a5d..d12b9be27 100644 --- a/config-default.yml +++ b/config-default.yml @@ -268,7 +268,7 @@ filter: notify_user_domains: false # Filter configuration - ping_everyone: true # Ping @everyone when we send a mod-alert? + ping_everyone: true offensive_msg_delete_days: 7 # How many days before deleting an offensive message? guild_invite_whitelist: |