diff options
| author | 2018-11-13 19:04:05 +0000 | |
|---|---|---|
| committer | 2018-11-13 19:04:05 +0000 | |
| commit | a3a2b555cacc6b10938a5469bb9e56be55c63a23 (patch) | |
| tree | 7ad7d0e2b94af82e7750f7e140f62bc6b7a8427f | |
| parent | Merge branch 'superstarify' into 'master' (diff) | |
| parent | common scheduling methods have been moved to a separate abstract class. (diff) | |
Merge branch 'scheduling-cleanup' into 'master'
Move common scheduling methods to an abstract class.
See merge request python-discord/projects/bot!73
| -rw-r--r-- | bot/cogs/moderation.py | 50 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 62 | ||||
| -rw-r--r-- | bot/utils/scheduling.py | 55 |
3 files changed, 78 insertions, 89 deletions
diff --git a/bot/cogs/moderation.py b/bot/cogs/moderation.py index 588962e29..9165fe654 100644 --- a/bot/cogs/moderation.py +++ b/bot/cogs/moderation.py @@ -1,7 +1,6 @@ import asyncio import logging import textwrap -from typing import Dict from aiohttp import ClientError from discord import Colour, Embed, Guild, Member, Object, User @@ -13,7 +12,7 @@ from bot.constants import Colours, Event, Icons, Keys, Roles, URLs from bot.converters import InfractionSearchQuery from bot.decorators import with_role from bot.pagination import LinePaginator -from bot.utils.scheduling import create_task +from bot.utils.scheduling import Scheduler from bot.utils.time import parse_rfc1123, wait_until log = logging.getLogger(__name__) @@ -21,7 +20,7 @@ log = logging.getLogger(__name__) MODERATION_ROLES = Roles.owner, Roles.admin, Roles.moderator -class Moderation: +class Moderation(Scheduler): """ Rowboat replacement moderation tools. """ @@ -29,8 +28,8 @@ class Moderation: def __init__(self, bot: Bot): self.bot = bot self.headers = {"X-API-KEY": Keys.site_api} - self.expiration_tasks: Dict[str, asyncio.Task] = {} self._muted_role = Object(constants.Roles.muted) + super().__init__() @property def mod_log(self) -> ModLog: @@ -47,7 +46,7 @@ class Moderation: loop = asyncio.get_event_loop() for infraction_object in infraction_list: if infraction_object["expires_at"] is not None: - self.schedule_expiration(loop, infraction_object) + self.schedule_task(loop, infraction_object["id"], infraction_object) # region: Permanent infractions @@ -291,7 +290,7 @@ class Moderation: infraction_expiration = infraction_object["expires_at"] loop = asyncio.get_event_loop() - self.schedule_expiration(loop, infraction_object) + self.schedule_task(loop, infraction_object["id"], infraction_object) if reason is None: result_message = f":ok_hand: muted {user.mention} until {infraction_expiration}." @@ -356,7 +355,7 @@ class Moderation: infraction_expiration = infraction_object["expires_at"] loop = asyncio.get_event_loop() - self.schedule_expiration(loop, infraction_object) + self.schedule_task(loop, infraction_object["id"], infraction_object) if reason is None: result_message = f":ok_hand: banned {user.mention} until {infraction_expiration}." @@ -540,9 +539,9 @@ class Moderation: infraction_object = response_object["infraction"] # Re-schedule - self.cancel_expiration(infraction_id) + self.cancel_task(infraction_id) loop = asyncio.get_event_loop() - self.schedule_expiration(loop, infraction_object) + self.schedule_task(loop, infraction_object["id"], infraction_object) if duration is None: await ctx.send(f":ok_hand: Updated infraction: marked as permanent.") @@ -748,36 +747,7 @@ class Moderation: max_size=1000 ) - def schedule_expiration(self, loop: asyncio.AbstractEventLoop, infraction_object: dict): - """ - Schedules a task to expire a temporary infraction. - :param loop: the asyncio event loop - :param infraction_object: the infraction object to expire at the end of the task - """ - - infraction_id = infraction_object["id"] - if infraction_id in self.expiration_tasks: - return - - task: asyncio.Task = create_task(loop, self._scheduled_expiration(infraction_object)) - - self.expiration_tasks[infraction_id] = task - - def cancel_expiration(self, infraction_id: str): - """ - Un-schedules a task set to expire a temporary infraction. - :param infraction_id: the ID of the infraction in question - """ - - task = self.expiration_tasks.get(infraction_id) - if task is None: - log.warning(f"Failed to unschedule {infraction_id}: no task found.") - return - task.cancel() - log.debug(f"Unscheduled {infraction_id}.") - del self.expiration_tasks[infraction_id] - - async def _scheduled_expiration(self, infraction_object): + async def _scheduled_task(self, infraction_object: dict): """ A co-routine which marks an infraction as expired after the delay from the time of scheduling to the time of expiration. At the time of expiration, the infraction is marked as inactive on the website, @@ -794,7 +764,7 @@ class Moderation: log.debug(f"Marking infraction {infraction_id} as inactive (expired).") await self._deactivate_infraction(infraction_object) - self.cancel_expiration(infraction_object["id"]) + self.cancel_task(infraction_object["id"]) async def _deactivate_infraction(self, infraction_object): """ diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 98d7942b3..f6ed111dc 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -14,7 +14,7 @@ from bot.constants import ( POSITIVE_REPLIES, Roles, URLs ) from bot.pagination import LinePaginator -from bot.utils.scheduling import create_task +from bot.utils.scheduling import Scheduler from bot.utils.time import humanize_delta, parse_rfc1123, wait_until log = logging.getLogger(__name__) @@ -24,16 +24,12 @@ WHITELISTED_CHANNELS = (Channels.bot,) MAXIMUM_REMINDERS = 5 -# The scheduling parts of this cog are pretty much directly copied -# from the moderation cog. I'll be working on making it more -# webscale:tm: as soon as possible, because this is a mess :D -class Reminders: +class Reminders(Scheduler): def __init__(self, bot: Bot): self.bot = bot - self.headers = {"X-API-Key": Keys.site_api} - self.reminder_tasks = {} + super().__init__() async def on_ready(self): # Get all the current reminders for re-scheduling @@ -57,7 +53,7 @@ class Reminders: await self.send_reminder(reminder, late) else: - self.schedule_reminder(loop, reminder) + self.schedule_task(loop, reminder["id"], reminder) @staticmethod async def _send_confirmation(ctx: Context, response: dict, on_success: str): @@ -87,24 +83,7 @@ class Reminders: await ctx.send(embed=embed) return failed - def schedule_reminder(self, loop: asyncio.AbstractEventLoop, reminder): - """ - Schedule a reminder from the bot at the requested time. - - :param loop: the asyncio event loop - :param reminder: the data of the reminder. - """ - - # Avoid duplicate schedules, just in case. - reminder_id = reminder["id"] - if reminder_id in self.reminder_tasks: - return - - # Make a scheduled task and add it to the list - task: asyncio.Task = create_task(loop, self._scheduled_reminder(reminder)) - self.reminder_tasks[reminder_id] = task - - async def _scheduled_reminder(self, reminder): + async def _scheduled_task(self, reminder: dict): """ A coroutine which sends the reminder once the time is reached. @@ -120,27 +99,10 @@ class Reminders: await self.send_reminder(reminder) log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder) + await self._delete_reminder(reminder_id) # Now we can begone with it from our schedule list. - self.cancel_reminder(reminder_id) - - def cancel_reminder(self, reminder_id: str): - """ - Un-schedules a task to send a reminder. - - :param reminder_id: the ID of the reminder in question - """ - - task = self.reminder_tasks.get(reminder_id) - - if task is None: - log.warning(f"Failed to unschedule {reminder_id}: no task found.") - return - - task.cancel() - log.debug(f"Unscheduled {reminder_id}.") - del self.reminder_tasks[reminder_id] + self.cancel_task(reminder_id) async def _delete_reminder(self, reminder_id: str): """ @@ -163,7 +125,7 @@ class Reminders: ) # Now we can remove it from the schedule list - self.cancel_reminder(reminder_id) + self.cancel_task(reminder_id) async def _reschedule_reminder(self, reminder): """ @@ -174,8 +136,8 @@ class Reminders: loop = asyncio.get_event_loop() - self.cancel_reminder(reminder["id"]) - self.schedule_reminder(loop, reminder) + self.cancel_task(reminder["id"]) + self.schedule_task(loop, reminder["id"], reminder) async def send_reminder(self, reminder, late: relativedelta = None): """ @@ -291,7 +253,9 @@ class Reminders: # If it worked, schedule the reminder. if not failed: loop = asyncio.get_event_loop() - self.schedule_reminder(loop=loop, reminder=response_data["reminder"]) + reminder = response_data["reminder"] + + self.schedule_task(loop, reminder["id"], reminder) @remind_group.command(name="list") async def list_reminders(self, ctx: Context): diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index f9b844046..ded6401b0 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -1,5 +1,60 @@ import asyncio import contextlib +import logging +from abc import ABC, abstractmethod +from typing import Dict + +log = logging.getLogger(__name__) + + +class Scheduler(ABC): + + def __init__(self): + + self.cog_name = self.__class__.__name__ # keep track of the child cog's name so the logs are clear. + self.scheduled_tasks: Dict[str, asyncio.Task] = {} + + @abstractmethod + async def _scheduled_task(self, task_object: dict): + """ + A coroutine which handles the scheduling. This is added to the scheduled tasks, + and should wait the task duration, execute the desired code, and clean up the task. + For example, in Reminders this will wait for the reminder duration, send the reminder, + then make a site API request to delete the reminder from the database. + + :param task_object: + """ + + def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict): + """ + Schedules a task. + :param loop: the asyncio event loop + :param task_id: the ID of the task. + :param task_data: the data of the task, passed to `Scheduler._scheduled_expiration`. + """ + + if task_id in self.scheduled_tasks: + return + + task: asyncio.Task = create_task(loop, self._scheduled_task(task_data)) + + self.scheduled_tasks[task_id] = task + + def cancel_task(self, task_id: str): + """ + Un-schedules a task. + :param task_id: the ID of the infraction in question + """ + + task = self.scheduled_tasks.get(task_id) + + if task is None: + log.warning(f"{self.cog_name}: Failed to unschedule {task_id} (no task found).") + return + + task.cancel() + log.debug(f"{self.cog_name}: Unscheduled {task_id}.") + del self.scheduled_tasks[task_id] def create_task(loop: asyncio.AbstractEventLoop, coro_or_future): |