diff options
| author | 2021-04-15 19:21:14 +0530 | |
|---|---|---|
| committer | 2021-04-15 19:21:14 +0530 | |
| commit | 45fd8b69826e5c39e87177a3113c486ccc51fcdf (patch) | |
| tree | 0b8352a3ad1601ba188bb0eb1ecd0bca5978064e /bot/utils | |
| parent | Update emojis (diff) | |
| parent | Update branch (diff) | |
Update branch
Diffstat (limited to 'bot/utils')
| -rw-r--r-- | bot/utils/decorators.py | 127 | ||||
| -rw-r--r-- | bot/utils/extensions.py | 10 | ||||
| -rw-r--r-- | bot/utils/time.py | 84 | 
3 files changed, 172 insertions, 49 deletions
| diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 9cdaad3f..60066dc4 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -11,8 +11,9 @@ from discord import Colour, Embed  from discord.ext import commands  from discord.ext.commands import CheckFailure, Command, Context -from bot.constants import ERROR_REPLIES, Month +from bot.constants import Channels, ERROR_REPLIES, Month, WHITELISTED_CHANNELS  from bot.utils import human_months, resolve_current_month +from bot.utils.checks import in_whitelist_check  ONE_DAY = 24 * 60 * 60 @@ -186,82 +187,110 @@ def without_role(*role_ids: int) -> t.Callable:      return commands.check(predicate) -def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]: +def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context], bool]:      """ -    Checks that the message is in a whitelisted channel or optionally has a bypass role. +    Checks if a message is sent in a whitelisted context. -    If `in_channel_override` is present, check if it contains channels -    and use them in place of the global whitelist. +    All arguments from `in_whitelist_check` are supported, with the exception of "fail_silently". +    If `whitelist_override` is present, it is added to the global whitelist.      """      def predicate(ctx: Context) -> bool: +        # Skip DM invocations          if not ctx.guild:              log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")              return True -        if ctx.channel.id in channels: -            log.debug( -                f"{ctx.author} tried to call the '{ctx.command.name}' command " -                f"and the command was used in a whitelisted channel." -            ) -            return True -        if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles): -            log.debug( -                f"{ctx.author} called the '{ctx.command.name}' command and " -                f"had a role to bypass the in_channel check." -            ) -            return True +        kwargs = default_kwargs.copy() -        if hasattr(ctx.command.callback, "in_channel_override"): -            override = ctx.command.callback.in_channel_override -            if override is None: +        # Update kwargs based on override +        if hasattr(ctx.command.callback, "override"): +            # Remove default kwargs if reset is True +            if ctx.command.callback.override_reset: +                kwargs = {}                  log.debug( -                    f"{ctx.author} called the '{ctx.command.name}' command " -                    f"and the command was whitelisted to bypass the in_channel check." +                    f"{ctx.author} called the '{ctx.command.name}' command and " +                    f"overrode default checks."                  ) -                return True -            else: -                if ctx.channel.id in override: -                    log.debug( -                        f"{ctx.author} tried to call the '{ctx.command.name}' command " -                        f"and the command was used in an overridden whitelisted channel." -                    ) -                    return True -                log.debug( -                    f"{ctx.author} tried to call the '{ctx.command.name}' command. " -                    f"The overridden in_channel check failed." -                ) -                channels_str = ', '.join(f"<#{c_id}>" for c_id in override) -                raise InChannelCheckFailure( -                    f"Sorry, but you may only use this command within {channels_str}." -                ) +            # Merge overwrites and defaults +            for arg in ctx.command.callback.override: +                default_value = kwargs.get(arg) +                new_value = ctx.command.callback.override[arg] + +                # Skip values that don't need merging, or can't be merged +                if default_value is None or isinstance(arg, int): +                    kwargs[arg] = new_value + +                # Merge containers +                elif isinstance(default_value, t.Container): +                    if isinstance(new_value, t.Container): +                        kwargs[arg] = (*default_value, *new_value) +                    else: +                        kwargs[arg] = new_value + +            log.debug( +                f"Updated default check arguments for '{ctx.command.name}' " +                f"invoked by {ctx.author}." +            ) + +        log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") +        result = in_whitelist_check(ctx, fail_silently=True, **kwargs) + +        # Return if check passed +        if result: +            log.debug( +                f"{ctx.author} tried to call the '{ctx.command.name}' command " +                f"and the command was used in an overridden context." +            ) +            return result          log.debug(              f"{ctx.author} tried to call the '{ctx.command.name}' command. " -            f"The in_channel check failed." +            f"The whitelist check failed."          ) -        channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) -        raise InChannelCheckFailure( -            f"Sorry, but you may only use this command within {channels_str}." -        ) +        # Raise error if the check did not pass +        channels = set(kwargs.get("channels") or {}) +        categories = kwargs.get("categories") -    return predicate +        # Only output override channels + community_bot_commands +        if channels: +            default_whitelist_channels = set(WHITELISTED_CHANNELS) +            default_whitelist_channels.discard(Channels.community_bot_commands) +            channels.difference_update(default_whitelist_channels) +        # Add all whitelisted category channels +        if categories: +            for category_id in categories: +                category = ctx.guild.get_channel(category_id) +                if category is None: +                    continue -in_channel = commands.check(in_channel_check) +                channels.update(channel.id for channel in category.text_channels) + +        if channels: +            channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) +            message = f"Sorry, but you may only use this command within {channels_str}." +        else: +            message = "Sorry, but you may not use this command." + +        raise InChannelCheckFailure(message) + +    return predicate -def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable: +def whitelist_override(bypass_defaults: bool = False, **kwargs: t.Container[int]) -> t.Callable:      """ -    Set command callback attribute for detection in `in_channel_check`. +    Override global whitelist context, with the kwargs specified. -    Override global whitelist if channels are specified. +    All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`. +    Set `bypass_defaults` to True if you want to completely bypass global checks.      This decorator has to go before (below) below the `command` decorator.      """      def inner(func: t.Callable) -> t.Callable: -        func.in_channel_override = channels +        func.override = kwargs +        func.override_reset = bypass_defaults          return func      return inner diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py index 50350ea8..459588a1 100644 --- a/bot/utils/extensions.py +++ b/bot/utils/extensions.py @@ -3,6 +3,8 @@ import inspect  import pkgutil  from typing import Iterator, NoReturn +from discord.ext.commands import Context +  from bot import exts @@ -31,4 +33,12 @@ def walk_extensions() -> Iterator[str]:          yield module.name +async def invoke_help_command(ctx: Context) -> None: +    """Invoke the help command or default help command if help extensions is not loaded.""" +    if 'bot.exts.evergreen.help' in ctx.bot.extensions: +        help_command = ctx.bot.get_command('help') +        await ctx.invoke(help_command, ctx.command.qualified_name) +        return +    await ctx.send_help(ctx.command) +  EXTENSIONS = frozenset(walk_extensions()) diff --git a/bot/utils/time.py b/bot/utils/time.py new file mode 100644 index 00000000..fbf2fd21 --- /dev/null +++ b/bot/utils/time.py @@ -0,0 +1,84 @@ +import datetime + +from dateutil.relativedelta import relativedelta + + +# All these functions are from https://github.com/python-discord/bot/blob/main/bot/utils/time.py +def _stringify_time_unit(value: int, unit: str) -> str: +    """ +    Returns a string to represent a value and time unit, ensuring that it uses the right plural form of the unit. + +    >>> _stringify_time_unit(1, "seconds") +    "1 second" +    >>> _stringify_time_unit(24, "hours") +    "24 hours" +    >>> _stringify_time_unit(0, "minutes") +    "less than a minute" +    """ +    if unit == "seconds" and value == 0: +        return "0 seconds" +    elif value == 1: +        return f"{value} {unit[:-1]}" +    elif value == 0: +        return f"less than a {unit[:-1]}" +    else: +        return f"{value} {unit}" + + +def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str: +    """ +    Returns a human-readable version of the relativedelta. + +    precision specifies the smallest unit of time to include (e.g. "seconds", "minutes"). +    max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). +    """ +    if max_units <= 0: +        raise ValueError("max_units must be positive") + +    units = ( +        ("years", delta.years), +        ("months", delta.months), +        ("days", delta.days), +        ("hours", delta.hours), +        ("minutes", delta.minutes), +        ("seconds", delta.seconds), +    ) + +    # Add the time units that are >0, but stop at accuracy or max_units. +    time_strings = [] +    unit_count = 0 +    for unit, value in units: +        if value: +            time_strings.append(_stringify_time_unit(value, unit)) +            unit_count += 1 + +        if unit == precision or unit_count >= max_units: +            break + +    # Add the 'and' between the last two units, if necessary +    if len(time_strings) > 1: +        time_strings[-1] = f"{time_strings[-2]} and {time_strings[-1]}" +        del time_strings[-2] + +    # If nothing has been found, just make the value 0 precision, e.g. `0 days`. +    if not time_strings: +        humanized = _stringify_time_unit(0, precision) +    else: +        humanized = ", ".join(time_strings) + +    return humanized + + +def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max_units: int = 6) -> str: +    """ +    Takes a datetime and returns a human-readable string that describes how long ago that datetime was. + +    precision specifies the smallest unit of time to include (e.g. "seconds", "minutes"). +    max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). +    """ +    now = datetime.datetime.utcnow() +    delta = abs(relativedelta(now, past_datetime)) + +    humanized = humanize_delta(delta, precision, max_units) + +    return f"{humanized} ago" | 
