aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bot/utils/__init__.py15
-rw-r--r--bot/utils/decorators.py22
2 files changed, 19 insertions, 18 deletions
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index cb5c9fbe..69bff4de 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -2,14 +2,29 @@ import asyncio
import contextlib
import re
import string
+from datetime import datetime
from typing import List
import discord
from discord.ext.commands import BadArgument, Context
+from bot.constants import Client, Month
from bot.utils.pagination import LinePaginator
+def resolve_current_month() -> Month:
+ """
+ Determine current month w.r.t. `Client.month_override` env var.
+
+ If the env variable was set, current month always resolves to the configured value.
+ Otherwise, the current UTC month is given.
+ """
+ if Client.month_override is not None:
+ return Month(Client.month_override)
+ else:
+ return Month(datetime.utcnow().month)
+
+
async def disambiguate(
ctx: Context, entries: List[str], *, timeout: float = 30,
entries_per_page: int = 20, empty: bool = False, embed: discord.Embed = None
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index a37601be..1e0a1715 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -4,7 +4,6 @@ import logging
import random
import typing as t
from asyncio import Lock
-from datetime import datetime
from functools import wraps
from weakref import WeakValueDictionary
@@ -13,6 +12,7 @@ from discord.ext import commands
from discord.ext.commands import CheckFailure, Command, Context
from bot.constants import Client, ERROR_REPLIES, Month
+from bot.utils import resolve_current_month
ONE_DAY = 24 * 60 * 60
@@ -31,20 +31,6 @@ class InMonthCheckFailure(CheckFailure):
pass
-def _resolve_current_month() -> Month:
- """
- Helper for local decorators to determine the correct Month value.
-
- This interfaces with the `MONTH_OVERRIDE` env var. If tha variable was set,
- current month always resolves to this value. Otherwise, the current utc month
- is given.
- """
- if Client.month_override is not None:
- return Month(Client.month_override)
- else:
- return Month(datetime.utcnow().month)
-
-
def seasonal_task(*allowed_months: Month, sleep_time: t.Union[float, int] = ONE_DAY) -> t.Callable:
"""
Perform the decorated method periodically in `allowed_months`.
@@ -64,7 +50,7 @@ def seasonal_task(*allowed_months: Month, sleep_time: t.Union[float, int] = ONE_
log.info(f"Starting seasonal task {task_body.__qualname__} ({allowed_months})")
while True:
- current_month = _resolve_current_month()
+ current_month = resolve_current_month()
if current_month in allowed_months:
await task_body(*args, **kwargs)
@@ -86,7 +72,7 @@ def in_month_listener(*allowed_months: Month) -> t.Callable:
@functools.wraps(listener)
async def guarded_listener(*args, **kwargs) -> None:
"""Wrapped listener will abort if not in allowed month."""
- current_month = _resolve_current_month()
+ current_month = resolve_current_month()
if current_month in allowed_months:
# Propagate return value although it should always be None
@@ -104,7 +90,7 @@ def in_month_command(*allowed_months: Month) -> t.Callable:
Uses the current UTC month at the time of running the predicate.
"""
async def predicate(ctx: Context) -> bool:
- current_month = _resolve_current_month()
+ current_month = resolve_current_month()
can_run = current_month in allowed_months
human_months = ", ".join(m.name for m in allowed_months)