aboutsummaryrefslogtreecommitdiffstats
path: root/bot/decorators.py
diff options
context:
space:
mode:
authorGravatar kwzrd <[email protected]>2020-03-28 14:51:14 +0100
committerGravatar kwzrd <[email protected]>2020-03-28 14:51:14 +0100
commit99b23ff1500aa15b076ee2e1c5e1b5560ba2c366 (patch)
treed710d125d00d3e8fb6ad325813f6ef44497c8bc5 /bot/decorators.py
parentDeseasonify: move pagination module under utils (diff)
Deseasonify: move decorators module under utils
Diffstat (limited to 'bot/decorators.py')
-rw-r--r--bot/decorators.py321
1 files changed, 0 insertions, 321 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
deleted file mode 100644
index f85996b5..00000000
--- a/bot/decorators.py
+++ /dev/null
@@ -1,321 +0,0 @@
-import asyncio
-import functools
-import logging
-import random
-import typing as t
-from asyncio import Lock
-from datetime import datetime
-from functools import wraps
-from weakref import WeakValueDictionary
-
-from discord import Colour, Embed
-from discord.ext import commands
-from discord.ext.commands import CheckFailure, Command, Context
-
-from bot.constants import Client, ERROR_REPLIES, Month
-
-ONE_DAY = 24 * 60 * 60
-
-log = logging.getLogger(__name__)
-
-
-class InChannelCheckFailure(CheckFailure):
- """Check failure when the user runs a command in a non-whitelisted channel."""
-
- pass
-
-
-class InMonthCheckFailure(CheckFailure):
- """Check failure for when a command is invoked outside of its allowed month."""
-
- pass
-
-
-def seasonal_task(*allowed_months: Month, sleep_time: t.Union[float, int] = ONE_DAY) -> t.Callable:
- """
- Perform the decorated method periodically in `allowed_months`.
-
- This provides a convenience wrapper to avoid code repetition where some task shall
- perform an operation repeatedly in a constant interval, but only in specific months.
-
- The decorated function will be called once every `sleep_time` seconds while
- the current UTC month is in `allowed_months`. Sleep time defaults to 24 hours.
-
- The wrapped task is responsible for waiting for the bot to be ready, if necessary.
- """
- def decorator(task_body: t.Callable) -> t.Callable:
- @functools.wraps(task_body)
- async def decorated_task(*args, **kwargs) -> None:
- """Call `task_body` once every `sleep_time` seconds in `allowed_months`."""
- log.info(f"Starting seasonal task {task_body.__qualname__} ({allowed_months})")
-
- while True:
- current_month = Month(datetime.utcnow().month)
-
- if current_month in allowed_months:
- await task_body(*args, **kwargs)
- else:
- log.debug(f"Seasonal task {task_body.__qualname__} sleeps in {current_month.name}")
-
- await asyncio.sleep(sleep_time)
- return decorated_task
- return decorator
-
-
-def in_month_listener(*allowed_months: Month) -> t.Callable:
- """
- Shield a listener from being invoked outside of `allowed_months`.
-
- The check is performed against current UTC month.
- """
- def decorator(listener: t.Callable) -> t.Callable:
- @functools.wraps(listener)
- async def guarded_listener(*args, **kwargs) -> None:
- """Wrapped listener will abort if not in allowed month."""
- current_month = Month(datetime.utcnow().month)
-
- if current_month in allowed_months:
- # Propagate return value although it should always be None
- return await listener(*args, **kwargs)
- else:
- log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month.name}")
- return guarded_listener
- return decorator
-
-
-def in_month_command(*allowed_months: Month) -> t.Callable:
- """
- Check whether the command was invoked in one of `enabled_months`.
-
- Uses the current UTC month at the time of running the predicate.
- """
- async def predicate(ctx: Context) -> bool:
- current_month = datetime.utcnow().month
- can_run = current_month in allowed_months
-
- human_months = ", ".join(m.name for m in allowed_months)
- log.debug(
- f"Command '{ctx.command}' is locked to months {human_months}. "
- f"Invoking it in month {current_month} is {'allowed' if can_run else 'disallowed'}."
- )
- if can_run:
- return True
- else:
- raise InMonthCheckFailure(f"Command can only be used in {human_months}")
-
- return commands.check(predicate)
-
-
-def in_month(*allowed_months: Month) -> t.Callable:
- """
- Universal decorator for season-locking commands and listeners alike.
-
- This only serves to determine whether the decorated callable is a command,
- a listener, or neither. It then delegates to either `in_month_command`,
- or `in_month_listener`, or raises TypeError, respectively.
-
- Please note that in order for this decorator to correctly determine whether
- the decorated callable is a cmd or listener, it **has** to first be turned
- into one. This means that this decorator should always be placed **above**
- the d.py one that registers it as either.
-
- This will decorate groups as well, as those subclass Command. In order to lock
- all subcommands of a group, its `invoke_without_command` param must **not** be
- manually set to True - this causes a circumvention of the group's callback
- and the seasonal check applied to it.
- """
- def decorator(callable_: t.Callable) -> t.Callable:
- # Functions decorated as commands are turned into instances of `Command`
- if isinstance(callable_, Command):
- logging.debug(f"Command {callable_.qualified_name} will be locked to {allowed_months}")
- actual_deco = in_month_command(*allowed_months)
-
- # D.py will assign this attribute when `callable_` is registered as a listener
- elif hasattr(callable_, "__cog_listener__"):
- logging.debug(f"Listener {callable_.__qualname__} will be locked to {allowed_months}")
- actual_deco = in_month_listener(*allowed_months)
-
- # Otherwise we're unsure exactly what has been decorated
- # This happens before the bot starts, so let's just raise
- else:
- raise TypeError(f"Decorated object {callable_} is neither a command nor a listener")
-
- return actual_deco(callable_)
- return decorator
-
-
-def with_role(*role_ids: int) -> t.Callable:
- """Check to see whether the invoking user has any of the roles specified in role_ids."""
- async def predicate(ctx: Context) -> bool:
- if not ctx.guild: # Return False in a DM
- log.debug(
- f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
- "This command is restricted by the with_role decorator. Rejecting request."
- )
- return False
-
- for role in ctx.author.roles:
- if role.id in role_ids:
- log.debug(f"{ctx.author} has the '{role.name}' role, and passes the check.")
- return True
-
- log.debug(
- f"{ctx.author} does not have the required role to use "
- f"the '{ctx.command.name}' command, so the request is rejected."
- )
- return False
- return commands.check(predicate)
-
-
-def without_role(*role_ids: int) -> t.Callable:
- """Check whether the invoking user does not have all of the roles specified in role_ids."""
- async def predicate(ctx: Context) -> bool:
- if not ctx.guild: # Return False in a DM
- log.debug(
- f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
- "This command is restricted by the without_role decorator. Rejecting request."
- )
- return False
-
- author_roles = [role.id for role in ctx.author.roles]
- check = all(role not in author_roles for role in role_ids)
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the without_role check was {check}."
- )
- return check
- return commands.check(predicate)
-
-
-def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]:
- """
- Checks that the message is in a whitelisted channel or optionally has a bypass role.
-
- If `in_channel_override` is present, check if it contains channels
- and use them in place of the global whitelist.
- """
- def predicate(ctx: Context) -> bool:
- if not ctx.guild:
- log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
- return True
- if ctx.channel.id in channels:
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command "
- f"and the command was used in a whitelisted channel."
- )
- return True
-
- if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):
- log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command and "
- f"had a role to bypass the in_channel check."
- )
- return True
-
- if hasattr(ctx.command.callback, "in_channel_override"):
- override = ctx.command.callback.in_channel_override
- if override is None:
- log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command "
- f"and the command was whitelisted to bypass the in_channel check."
- )
- return True
- else:
- if ctx.channel.id in override:
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command "
- f"and the command was used in an overridden whitelisted channel."
- )
- return True
-
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The overridden in_channel check failed."
- )
- channels_str = ', '.join(f"<#{c_id}>" for c_id in override)
- raise InChannelCheckFailure(
- f"Sorry, but you may only use this command within {channels_str}."
- )
-
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The in_channel check failed."
- )
-
- channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
- raise InChannelCheckFailure(
- f"Sorry, but you may only use this command within {channels_str}."
- )
-
- return predicate
-
-
-in_channel = commands.check(in_channel_check)
-
-
-def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable:
- """
- Set command callback attribute for detection in `in_channel_check`.
-
- Override global whitelist if channels are specified.
-
- This decorator has to go before (below) below the `command` decorator.
- """
- def inner(func: t.Callable) -> t.Callable:
- func.in_channel_override = channels
- return func
-
- return inner
-
-
-def locked() -> t.Union[t.Callable, None]:
- """
- Allows the user to only run one instance of the decorated command at a time.
-
- Subsequent calls to the command from the same author are ignored until the command has completed invocation.
-
- This decorator has to go before (below) the `command` decorator.
- """
- def wrap(func: t.Callable) -> t.Union[t.Callable, None]:
- func.__locks = WeakValueDictionary()
-
- @wraps(func)
- async def inner(self: t.Callable, ctx: Context, *args, **kwargs) -> t.Union[t.Callable, None]:
- lock = func.__locks.setdefault(ctx.author.id, Lock())
- if lock.locked():
- embed = Embed()
- embed.colour = Colour.red()
-
- log.debug(f"User tried to invoke a locked command.")
- embed.description = (
- "You're already using this command. Please wait until "
- "it is done before you use it again."
- )
- embed.title = random.choice(ERROR_REPLIES)
- await ctx.send(embed=embed)
- return
-
- async with func.__locks.setdefault(ctx.author.id, Lock()):
- return await func(self, ctx, *args, **kwargs)
- return inner
- return wrap
-
-
-def mock_in_debug(return_value: t.Any) -> t.Callable:
- """
- Short-circuit function execution if in debug mode and return `return_value`.
-
- The original function name, and the incoming args and kwargs are DEBUG level logged
- upon each call. This is useful for expensive operations, i.e. media asset uploads
- that are prone to rate-limits but need to be tested extensively.
- """
- def decorator(func: t.Callable) -> t.Callable:
- @functools.wraps(func)
- async def wrapped(*args, **kwargs) -> t.Any:
- """Short-circuit and log if in debug mode."""
- if Client.debug:
- log.debug(f"Function {func.__name__} called with args: {args}, kwargs: {kwargs}")
- return return_value
- return await func(*args, **kwargs)
- return wrapped
- return decorator