aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/decorators.py127
-rw-r--r--bot/utils/extensions.py10
-rw-r--r--bot/utils/time.py84
3 files changed, 172 insertions, 49 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 9cdaad3f..60066dc4 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -11,8 +11,9 @@ from discord import Colour, Embed
from discord.ext import commands
from discord.ext.commands import CheckFailure, Command, Context
-from bot.constants import ERROR_REPLIES, Month
+from bot.constants import Channels, ERROR_REPLIES, Month, WHITELISTED_CHANNELS
from bot.utils import human_months, resolve_current_month
+from bot.utils.checks import in_whitelist_check
ONE_DAY = 24 * 60 * 60
@@ -186,82 +187,110 @@ def without_role(*role_ids: int) -> t.Callable:
return commands.check(predicate)
-def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]:
+def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context], bool]:
"""
- Checks that the message is in a whitelisted channel or optionally has a bypass role.
+ Checks if a message is sent in a whitelisted context.
- If `in_channel_override` is present, check if it contains channels
- and use them in place of the global whitelist.
+ All arguments from `in_whitelist_check` are supported, with the exception of "fail_silently".
+ If `whitelist_override` is present, it is added to the global whitelist.
"""
def predicate(ctx: Context) -> bool:
+ # Skip DM invocations
if not ctx.guild:
log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
return True
- if ctx.channel.id in channels:
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command "
- f"and the command was used in a whitelisted channel."
- )
- return True
- if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):
- log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command and "
- f"had a role to bypass the in_channel check."
- )
- return True
+ kwargs = default_kwargs.copy()
- if hasattr(ctx.command.callback, "in_channel_override"):
- override = ctx.command.callback.in_channel_override
- if override is None:
+ # Update kwargs based on override
+ if hasattr(ctx.command.callback, "override"):
+ # Remove default kwargs if reset is True
+ if ctx.command.callback.override_reset:
+ kwargs = {}
log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command "
- f"and the command was whitelisted to bypass the in_channel check."
+ f"{ctx.author} called the '{ctx.command.name}' command and "
+ f"overrode default checks."
)
- return True
- else:
- if ctx.channel.id in override:
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command "
- f"and the command was used in an overridden whitelisted channel."
- )
- return True
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The overridden in_channel check failed."
- )
- channels_str = ', '.join(f"<#{c_id}>" for c_id in override)
- raise InChannelCheckFailure(
- f"Sorry, but you may only use this command within {channels_str}."
- )
+ # Merge overwrites and defaults
+ for arg in ctx.command.callback.override:
+ default_value = kwargs.get(arg)
+ new_value = ctx.command.callback.override[arg]
+
+ # Skip values that don't need merging, or can't be merged
+ if default_value is None or isinstance(arg, int):
+ kwargs[arg] = new_value
+
+ # Merge containers
+ elif isinstance(default_value, t.Container):
+ if isinstance(new_value, t.Container):
+ kwargs[arg] = (*default_value, *new_value)
+ else:
+ kwargs[arg] = new_value
+
+ log.debug(
+ f"Updated default check arguments for '{ctx.command.name}' "
+ f"invoked by {ctx.author}."
+ )
+
+ log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.")
+ result = in_whitelist_check(ctx, fail_silently=True, **kwargs)
+
+ # Return if check passed
+ if result:
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command "
+ f"and the command was used in an overridden context."
+ )
+ return result
log.debug(
f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The in_channel check failed."
+ f"The whitelist check failed."
)
- channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
- raise InChannelCheckFailure(
- f"Sorry, but you may only use this command within {channels_str}."
- )
+ # Raise error if the check did not pass
+ channels = set(kwargs.get("channels") or {})
+ categories = kwargs.get("categories")
- return predicate
+ # Only output override channels + community_bot_commands
+ if channels:
+ default_whitelist_channels = set(WHITELISTED_CHANNELS)
+ default_whitelist_channels.discard(Channels.community_bot_commands)
+ channels.difference_update(default_whitelist_channels)
+ # Add all whitelisted category channels
+ if categories:
+ for category_id in categories:
+ category = ctx.guild.get_channel(category_id)
+ if category is None:
+ continue
-in_channel = commands.check(in_channel_check)
+ channels.update(channel.id for channel in category.text_channels)
+
+ if channels:
+ channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
+ message = f"Sorry, but you may only use this command within {channels_str}."
+ else:
+ message = "Sorry, but you may not use this command."
+
+ raise InChannelCheckFailure(message)
+
+ return predicate
-def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable:
+def whitelist_override(bypass_defaults: bool = False, **kwargs: t.Container[int]) -> t.Callable:
"""
- Set command callback attribute for detection in `in_channel_check`.
+ Override global whitelist context, with the kwargs specified.
- Override global whitelist if channels are specified.
+ All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`.
+ Set `bypass_defaults` to True if you want to completely bypass global checks.
This decorator has to go before (below) below the `command` decorator.
"""
def inner(func: t.Callable) -> t.Callable:
- func.in_channel_override = channels
+ func.override = kwargs
+ func.override_reset = bypass_defaults
return func
return inner
diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py
index 50350ea8..459588a1 100644
--- a/bot/utils/extensions.py
+++ b/bot/utils/extensions.py
@@ -3,6 +3,8 @@ import inspect
import pkgutil
from typing import Iterator, NoReturn
+from discord.ext.commands import Context
+
from bot import exts
@@ -31,4 +33,12 @@ def walk_extensions() -> Iterator[str]:
yield module.name
+async def invoke_help_command(ctx: Context) -> None:
+ """Invoke the help command or default help command if help extensions is not loaded."""
+ if 'bot.exts.evergreen.help' in ctx.bot.extensions:
+ help_command = ctx.bot.get_command('help')
+ await ctx.invoke(help_command, ctx.command.qualified_name)
+ return
+ await ctx.send_help(ctx.command)
+
EXTENSIONS = frozenset(walk_extensions())
diff --git a/bot/utils/time.py b/bot/utils/time.py
new file mode 100644
index 00000000..fbf2fd21
--- /dev/null
+++ b/bot/utils/time.py
@@ -0,0 +1,84 @@
+import datetime
+
+from dateutil.relativedelta import relativedelta
+
+
+# All these functions are from https://github.com/python-discord/bot/blob/main/bot/utils/time.py
+def _stringify_time_unit(value: int, unit: str) -> str:
+ """
+ Returns a string to represent a value and time unit, ensuring that it uses the right plural form of the unit.
+
+ >>> _stringify_time_unit(1, "seconds")
+ "1 second"
+ >>> _stringify_time_unit(24, "hours")
+ "24 hours"
+ >>> _stringify_time_unit(0, "minutes")
+ "less than a minute"
+ """
+ if unit == "seconds" and value == 0:
+ return "0 seconds"
+ elif value == 1:
+ return f"{value} {unit[:-1]}"
+ elif value == 0:
+ return f"less than a {unit[:-1]}"
+ else:
+ return f"{value} {unit}"
+
+
+def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str:
+ """
+ Returns a human-readable version of the relativedelta.
+
+ precision specifies the smallest unit of time to include (e.g. "seconds", "minutes").
+ max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
+ """
+ if max_units <= 0:
+ raise ValueError("max_units must be positive")
+
+ units = (
+ ("years", delta.years),
+ ("months", delta.months),
+ ("days", delta.days),
+ ("hours", delta.hours),
+ ("minutes", delta.minutes),
+ ("seconds", delta.seconds),
+ )
+
+ # Add the time units that are >0, but stop at accuracy or max_units.
+ time_strings = []
+ unit_count = 0
+ for unit, value in units:
+ if value:
+ time_strings.append(_stringify_time_unit(value, unit))
+ unit_count += 1
+
+ if unit == precision or unit_count >= max_units:
+ break
+
+ # Add the 'and' between the last two units, if necessary
+ if len(time_strings) > 1:
+ time_strings[-1] = f"{time_strings[-2]} and {time_strings[-1]}"
+ del time_strings[-2]
+
+ # If nothing has been found, just make the value 0 precision, e.g. `0 days`.
+ if not time_strings:
+ humanized = _stringify_time_unit(0, precision)
+ else:
+ humanized = ", ".join(time_strings)
+
+ return humanized
+
+
+def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max_units: int = 6) -> str:
+ """
+ Takes a datetime and returns a human-readable string that describes how long ago that datetime was.
+
+ precision specifies the smallest unit of time to include (e.g. "seconds", "minutes").
+ max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
+ """
+ now = datetime.datetime.utcnow()
+ delta = abs(relativedelta(now, past_datetime))
+
+ humanized = humanize_delta(delta, precision, max_units)
+
+ return f"{humanized} ago"