aboutsummaryrefslogtreecommitdiffstats
path: root/bot/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/decorators.py')
-rw-r--r--bot/decorators.py96
1 files changed, 93 insertions, 3 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index 58f67a15..74976cd6 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,7 +1,10 @@
+import asyncio
+import functools
import logging
import random
import typing
from asyncio import Lock
+from datetime import datetime
from functools import wraps
from weakref import WeakValueDictionary
@@ -9,7 +12,10 @@ from discord import Colour, Embed
from discord.ext import commands
from discord.ext.commands import CheckFailure, Context
-from bot.constants import ERROR_REPLIES
+from bot.bot import bot
+from bot.constants import ERROR_REPLIES, Month
+
+ONE_DAY = 24 * 60 * 60
log = logging.getLogger(__name__)
@@ -20,7 +26,91 @@ class InChannelCheckFailure(CheckFailure):
pass
-def with_role(*role_ids: int) -> bool:
+class InMonthCheckFailure(CheckFailure):
+ """Check failure for when a command is invoked outside of its allowed month."""
+
+ pass
+
+
+def seasonal_task(*allowed_months: Month, sleep_time: typing.Union[float, int] = ONE_DAY) -> typing.Callable:
+ """
+ Perform the decorated method periodically in `allowed_months`.
+
+ This provides a convenience wrapper to avoid code repetition where some task shall
+ perform an operation repeatedly in a constant interval, but only in specific months.
+
+ The decorated function will be called once every `sleep_time` seconds while
+ the current UTC month is in `allowed_months`. Sleep time defaults to 24 hours.
+ """
+ def decorator(task_body: typing.Callable) -> typing.Callable:
+ @functools.wraps(task_body)
+ async def decorated_task(*args, **kwargs) -> None:
+ """
+ Call `task_body` once every `sleep_time` seconds in `allowed_months`.
+
+ Wait for bot to be ready before calling `task_body` for the first time.
+ """
+ await bot.wait_until_ready()
+ log.info(f"Starting seasonal task {task_body.__qualname__} ({allowed_months})")
+
+ while True:
+ current_month = Month(datetime.utcnow().month)
+
+ if current_month in allowed_months:
+ await task_body(*args, **kwargs)
+ else:
+ log.debug(f"Seasonal task {task_body.__qualname__} sleeps in {current_month.name}")
+
+ await asyncio.sleep(sleep_time)
+ return decorated_task
+ return decorator
+
+
+def in_month_listener(*allowed_months: Month) -> typing.Callable:
+ """
+ Shield a listener from being invoked outside of `allowed_months`.
+
+ The check is performed against current UTC month.
+ """
+ def decorator(listener: typing.Callable) -> typing.Callable:
+ @functools.wraps(listener)
+ async def guarded_listener(*args, **kwargs) -> None:
+ """Wrapped listener will abort if not in allowed month."""
+ current_month = Month(datetime.utcnow().month)
+
+ if current_month in allowed_months:
+ # Propagate return value although it should always be None
+ return await listener(*args, **kwargs)
+ else:
+ log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month.name}")
+ return guarded_listener
+ return decorator
+
+
+def in_month(*allowed_months: Month) -> typing.Callable:
+ """
+ Check whether the command was invoked in one of `enabled_months`.
+
+ Uses the current UTC month at the time of running the predicate.
+ """
+ async def predicate(ctx: Context) -> bool:
+ current_month = datetime.utcnow().month
+ can_run = current_month in allowed_months
+
+ human_months = ", ".join(m.name for m in allowed_months)
+ log.debug(
+ f"Command '{ctx.command}' is locked to months {human_months}. "
+ f"Invoking it in month {current_month} is {'allowed' if can_run else 'disallowed'}."
+ )
+ if can_run:
+ return True
+ else:
+ raise InMonthCheckFailure(f"Command can only be used in {human_months}")
+
+ return commands.check(predicate)
+
+
+def with_role(*role_ids: int) -> typing.Callable:
"""Check to see whether the invoking user has any of the roles specified in role_ids."""
async def predicate(ctx: Context) -> bool:
if not ctx.guild: # Return False in a DM
@@ -43,7 +133,7 @@ def with_role(*role_ids: int) -> bool:
return commands.check(predicate)
-def without_role(*role_ids: int) -> bool:
+def without_role(*role_ids: int) -> typing.Callable:
"""Check whether the invoking user does not have all of the roles specified in role_ids."""
async def predicate(ctx: Context) -> bool:
if not ctx.guild: # Return False in a DM