aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Kingsley McDonald <[email protected]>2018-10-07 13:33:42 +0100
committerGravatar Kingsley McDonald <[email protected]>2018-10-07 13:33:42 +0100
commitd079b3d34ba1fb045e63d332b70bc91940246492 (patch)
tree8e91735ad9b4de23bda71fa7cb10771cc619282b
parentMerge branch 'remind-command' into 'master' (diff)
common scheduling methods have been moved to a separate abstract class.
-rw-r--r--bot/cogs/moderation.py50
-rw-r--r--bot/cogs/reminders.py62
-rw-r--r--bot/utils/scheduling.py55
3 files changed, 78 insertions, 89 deletions
diff --git a/bot/cogs/moderation.py b/bot/cogs/moderation.py
index 4a0e4c0f4..72efee9a5 100644
--- a/bot/cogs/moderation.py
+++ b/bot/cogs/moderation.py
@@ -1,7 +1,6 @@
import asyncio
import logging
import textwrap
-from typing import Dict
from aiohttp import ClientError
from discord import Colour, Embed, Guild, Member, Object, User
@@ -13,7 +12,7 @@ from bot.constants import Colours, Event, Icons, Keys, Roles, URLs
from bot.converters import InfractionSearchQuery
from bot.decorators import with_role
from bot.pagination import LinePaginator
-from bot.utils.scheduling import create_task
+from bot.utils.scheduling import Scheduler
from bot.utils.time import parse_rfc1123, wait_until
log = logging.getLogger(__name__)
@@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
MODERATION_ROLES = Roles.owner, Roles.admin, Roles.moderator
-class Moderation:
+class Moderation(Scheduler):
"""
Rowboat replacement moderation tools.
"""
@@ -29,8 +28,8 @@ class Moderation:
def __init__(self, bot: Bot):
self.bot = bot
self.headers = {"X-API-KEY": Keys.site_api}
- self.expiration_tasks: Dict[str, asyncio.Task] = {}
self._muted_role = Object(constants.Roles.muted)
+ super().__init__()
@property
def mod_log(self) -> ModLog:
@@ -47,7 +46,7 @@ class Moderation:
loop = asyncio.get_event_loop()
for infraction_object in infraction_list:
if infraction_object["expires_at"] is not None:
- self.schedule_expiration(loop, infraction_object)
+ self.schedule_task(loop, infraction_object["id"], infraction_object)
# region: Permanent infractions
@@ -291,7 +290,7 @@ class Moderation:
infraction_expiration = infraction_object["expires_at"]
loop = asyncio.get_event_loop()
- self.schedule_expiration(loop, infraction_object)
+ self.schedule_task(loop, infraction_object["id"], infraction_object)
if reason is None:
result_message = f":ok_hand: muted {user.mention} until {infraction_expiration}."
@@ -356,7 +355,7 @@ class Moderation:
infraction_expiration = infraction_object["expires_at"]
loop = asyncio.get_event_loop()
- self.schedule_expiration(loop, infraction_object)
+ self.schedule_task(loop, infraction_object["id"], infraction_object)
if reason is None:
result_message = f":ok_hand: banned {user.mention} until {infraction_expiration}."
@@ -536,9 +535,9 @@ class Moderation:
infraction_object = response_object["infraction"]
# Re-schedule
- self.cancel_expiration(infraction_id)
+ self.cancel_task(infraction_id)
loop = asyncio.get_event_loop()
- self.schedule_expiration(loop, infraction_object)
+ self.schedule_task(loop, infraction_object["id"], infraction_object)
if duration is None:
await ctx.send(f":ok_hand: Updated infraction: marked as permanent.")
@@ -744,36 +743,7 @@ class Moderation:
max_size=1000
)
- def schedule_expiration(self, loop: asyncio.AbstractEventLoop, infraction_object: dict):
- """
- Schedules a task to expire a temporary infraction.
- :param loop: the asyncio event loop
- :param infraction_object: the infraction object to expire at the end of the task
- """
-
- infraction_id = infraction_object["id"]
- if infraction_id in self.expiration_tasks:
- return
-
- task: asyncio.Task = create_task(loop, self._scheduled_expiration(infraction_object))
-
- self.expiration_tasks[infraction_id] = task
-
- def cancel_expiration(self, infraction_id: str):
- """
- Un-schedules a task set to expire a temporary infraction.
- :param infraction_id: the ID of the infraction in question
- """
-
- task = self.expiration_tasks.get(infraction_id)
- if task is None:
- log.warning(f"Failed to unschedule {infraction_id}: no task found.")
- return
- task.cancel()
- log.debug(f"Unscheduled {infraction_id}.")
- del self.expiration_tasks[infraction_id]
-
- async def _scheduled_expiration(self, infraction_object):
+ async def _scheduled_task(self, infraction_object: dict):
"""
A co-routine which marks an infraction as expired after the delay from the time of scheduling
to the time of expiration. At the time of expiration, the infraction is marked as inactive on the website,
@@ -790,7 +760,7 @@ class Moderation:
log.debug(f"Marking infraction {infraction_id} as inactive (expired).")
await self._deactivate_infraction(infraction_object)
- self.cancel_expiration(infraction_object["id"])
+ self.cancel_task(infraction_object["id"])
async def _deactivate_infraction(self, infraction_object):
"""
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 98d7942b3..f6ed111dc 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -14,7 +14,7 @@ from bot.constants import (
POSITIVE_REPLIES, Roles, URLs
)
from bot.pagination import LinePaginator
-from bot.utils.scheduling import create_task
+from bot.utils.scheduling import Scheduler
from bot.utils.time import humanize_delta, parse_rfc1123, wait_until
log = logging.getLogger(__name__)
@@ -24,16 +24,12 @@ WHITELISTED_CHANNELS = (Channels.bot,)
MAXIMUM_REMINDERS = 5
-# The scheduling parts of this cog are pretty much directly copied
-# from the moderation cog. I'll be working on making it more
-# webscale:tm: as soon as possible, because this is a mess :D
-class Reminders:
+class Reminders(Scheduler):
def __init__(self, bot: Bot):
self.bot = bot
-
self.headers = {"X-API-Key": Keys.site_api}
- self.reminder_tasks = {}
+ super().__init__()
async def on_ready(self):
# Get all the current reminders for re-scheduling
@@ -57,7 +53,7 @@ class Reminders:
await self.send_reminder(reminder, late)
else:
- self.schedule_reminder(loop, reminder)
+ self.schedule_task(loop, reminder["id"], reminder)
@staticmethod
async def _send_confirmation(ctx: Context, response: dict, on_success: str):
@@ -87,24 +83,7 @@ class Reminders:
await ctx.send(embed=embed)
return failed
- def schedule_reminder(self, loop: asyncio.AbstractEventLoop, reminder):
- """
- Schedule a reminder from the bot at the requested time.
-
- :param loop: the asyncio event loop
- :param reminder: the data of the reminder.
- """
-
- # Avoid duplicate schedules, just in case.
- reminder_id = reminder["id"]
- if reminder_id in self.reminder_tasks:
- return
-
- # Make a scheduled task and add it to the list
- task: asyncio.Task = create_task(loop, self._scheduled_reminder(reminder))
- self.reminder_tasks[reminder_id] = task
-
- async def _scheduled_reminder(self, reminder):
+ async def _scheduled_task(self, reminder: dict):
"""
A coroutine which sends the reminder once the time is reached.
@@ -120,27 +99,10 @@ class Reminders:
await self.send_reminder(reminder)
log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).")
- await self._delete_reminder(reminder)
+ await self._delete_reminder(reminder_id)
# Now we can begone with it from our schedule list.
- self.cancel_reminder(reminder_id)
-
- def cancel_reminder(self, reminder_id: str):
- """
- Un-schedules a task to send a reminder.
-
- :param reminder_id: the ID of the reminder in question
- """
-
- task = self.reminder_tasks.get(reminder_id)
-
- if task is None:
- log.warning(f"Failed to unschedule {reminder_id}: no task found.")
- return
-
- task.cancel()
- log.debug(f"Unscheduled {reminder_id}.")
- del self.reminder_tasks[reminder_id]
+ self.cancel_task(reminder_id)
async def _delete_reminder(self, reminder_id: str):
"""
@@ -163,7 +125,7 @@ class Reminders:
)
# Now we can remove it from the schedule list
- self.cancel_reminder(reminder_id)
+ self.cancel_task(reminder_id)
async def _reschedule_reminder(self, reminder):
"""
@@ -174,8 +136,8 @@ class Reminders:
loop = asyncio.get_event_loop()
- self.cancel_reminder(reminder["id"])
- self.schedule_reminder(loop, reminder)
+ self.cancel_task(reminder["id"])
+ self.schedule_task(loop, reminder["id"], reminder)
async def send_reminder(self, reminder, late: relativedelta = None):
"""
@@ -291,7 +253,9 @@ class Reminders:
# If it worked, schedule the reminder.
if not failed:
loop = asyncio.get_event_loop()
- self.schedule_reminder(loop=loop, reminder=response_data["reminder"])
+ reminder = response_data["reminder"]
+
+ self.schedule_task(loop, reminder["id"], reminder)
@remind_group.command(name="list")
async def list_reminders(self, ctx: Context):
diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py
index f9b844046..ded6401b0 100644
--- a/bot/utils/scheduling.py
+++ b/bot/utils/scheduling.py
@@ -1,5 +1,60 @@
import asyncio
import contextlib
+import logging
+from abc import ABC, abstractmethod
+from typing import Dict
+
+log = logging.getLogger(__name__)
+
+
+class Scheduler(ABC):
+
+ def __init__(self):
+
+ self.cog_name = self.__class__.__name__ # keep track of the child cog's name so the logs are clear.
+ self.scheduled_tasks: Dict[str, asyncio.Task] = {}
+
+ @abstractmethod
+ async def _scheduled_task(self, task_object: dict):
+ """
+ A coroutine which handles the scheduling. This is added to the scheduled tasks,
+ and should wait the task duration, execute the desired code, and clean up the task.
+ For example, in Reminders this will wait for the reminder duration, send the reminder,
+ then make a site API request to delete the reminder from the database.
+
+ :param task_object:
+ """
+
+ def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict):
+ """
+ Schedules a task.
+ :param loop: the asyncio event loop
+ :param task_id: the ID of the task.
+ :param task_data: the data of the task, passed to `Scheduler._scheduled_expiration`.
+ """
+
+ if task_id in self.scheduled_tasks:
+ return
+
+ task: asyncio.Task = create_task(loop, self._scheduled_task(task_data))
+
+ self.scheduled_tasks[task_id] = task
+
+ def cancel_task(self, task_id: str):
+ """
+ Un-schedules a task.
+ :param task_id: the ID of the infraction in question
+ """
+
+ task = self.scheduled_tasks.get(task_id)
+
+ if task is None:
+ log.warning(f"{self.cog_name}: Failed to unschedule {task_id} (no task found).")
+ return
+
+ task.cancel()
+ log.debug(f"{self.cog_name}: Unscheduled {task_id}.")
+ del self.scheduled_tasks[task_id]
def create_task(loop: asyncio.AbstractEventLoop, coro_or_future):