diff options
Diffstat (limited to 'bot/decorators.py')
-rw-r--r-- | bot/decorators.py | 97 |
1 files changed, 94 insertions, 3 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index 58f67a15..874c811b 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,7 +1,10 @@ +import asyncio +import functools import logging import random import typing from asyncio import Lock +from datetime import datetime from functools import wraps from weakref import WeakValueDictionary @@ -9,7 +12,9 @@ from discord import Colour, Embed from discord.ext import commands from discord.ext.commands import CheckFailure, Context -from bot.constants import ERROR_REPLIES +from bot.constants import ERROR_REPLIES, Month + +ONE_DAY = 24 * 60 * 60 log = logging.getLogger(__name__) @@ -20,7 +25,93 @@ class InChannelCheckFailure(CheckFailure): pass -def with_role(*role_ids: int) -> bool: +class InMonthCheckFailure(CheckFailure): + """Check failure for when a command is invoked outside of its allowed month.""" + + pass + + +def seasonal_task(*allowed_months: Month, sleep_time: float = ONE_DAY) -> typing.Callable: + """ + Perform the decorated method periodically in `allowed_months`. + + This provides a convenience wrapper to avoid code repetition where some task shall + perform an operation repeatedly in a constant interval, but only in specific months. + + The decorated function will be called once every `sleep_time` seconds while + the current UTC month is in `allowed_months`. Sleep time defaults to 24 hours. + """ + def decorator(task_body: typing.Callable) -> typing.Callable: + @functools.wraps(task_body) + async def decorated_task(self: commands.Cog, *args, **kwargs) -> None: + """ + Call `task_body` once every `sleep_time` seconds in `allowed_months`. + + We assume `self` to be a Cog subclass instance carrying a `bot` attr. + As some tasks may rely on the client's cache to be ready, we delegate + to the bot to wait until it's ready. + """ + await self.bot.wait_until_ready() + log.info(f"Starting seasonal task {task_body.__qualname__} ({allowed_months})") + + while True: + current_month = Month(datetime.utcnow().month) + + if current_month in allowed_months: + await task_body(self, *args, **kwargs) + else: + log.debug(f"Seasonal task {task_body.__qualname__} sleeps in {current_month.name}") + + await asyncio.sleep(sleep_time) + return decorated_task + return decorator + + +def in_month_listener(*allowed_months: Month) -> typing.Callable: + """ + Shield a listener from being invoked outside of `allowed_months`. + + The check is performed against current UTC month. + """ + def decorator(listener: typing.Callable) -> typing.Callable: + @functools.wraps(listener) + async def guarded_listener(*args, **kwargs) -> None: + """Wrapped listener will abort if not in allowed month.""" + current_month = Month(datetime.utcnow().month) + + if current_month in allowed_months: + # Propagate return value although it should always be None + return await listener(*args, **kwargs) + else: + log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month.name}") + return guarded_listener + return decorator + + +def in_month(*allowed_months: Month) -> typing.Callable: + """ + Check whether the command was invoked in one of `enabled_months`. + + Uses the current UTC month at the time of running the predicate. + """ + async def predicate(ctx: Context) -> bool: + current_month = datetime.utcnow().month + can_run = current_month in allowed_months + + human_months = ", ".join(m.name for m in allowed_months) + log.debug( + f"Command '{ctx.command}' is locked to months {human_months}. " + f"Invoking it in month {current_month} is {'allowed' if can_run else 'disallowed'}." + ) + if can_run: + return True + else: + raise InMonthCheckFailure(f"Command can only be used in {human_months}") + + return commands.check(predicate) + + +def with_role(*role_ids: int) -> typing.Callable: """Check to see whether the invoking user has any of the roles specified in role_ids.""" async def predicate(ctx: Context) -> bool: if not ctx.guild: # Return False in a DM @@ -43,7 +134,7 @@ def with_role(*role_ids: int) -> bool: return commands.check(predicate) -def without_role(*role_ids: int) -> bool: +def without_role(*role_ids: int) -> typing.Callable: """Check whether the invoking user does not have all of the roles specified in role_ids.""" async def predicate(ctx: Context) -> bool: if not ctx.guild: # Return False in a DM |