diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/__main__.py | 5 | ||||
| -rw-r--r-- | bot/exts/christmas/advent_of_code/_cog.py | 20 | ||||
| -rw-r--r-- | bot/exts/evergreen/cheatsheet.py | 12 | ||||
| -rw-r--r-- | bot/exts/evergreen/conversationstarters.py | 4 | ||||
| -rw-r--r-- | bot/exts/halloween/hacktoberstats.py | 8 | ||||
| -rw-r--r-- | bot/utils/decorators.py | 119 | 
6 files changed, 92 insertions, 76 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index e9b14a53..c6e5fa57 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -6,10 +6,9 @@ from sentry_sdk.integrations.redis import RedisIntegration  from bot.bot import bot  from bot.constants import Client, GIT_SHA, STAFF_ROLES, WHITELISTED_CHANNELS -from bot.utils.decorators import in_channel_check +from bot.utils.decorators import whitelist_check  from bot.utils.extensions import walk_extensions -  sentry_logging = LoggingIntegration(      level=logging.DEBUG,      event_level=logging.WARNING @@ -26,7 +25,7 @@ sentry_sdk.init(  log = logging.getLogger(__name__) -bot.add_check(in_channel_check(*WHITELISTED_CHANNELS, bypass_roles=STAFF_ROLES)) +bot.add_check(whitelist_check(channels=WHITELISTED_CHANNELS, roles=STAFF_ROLES))  for ext in walk_extensions():      bot.load_extension(ext) diff --git a/bot/exts/christmas/advent_of_code/_cog.py b/bot/exts/christmas/advent_of_code/_cog.py index c3b87f96..466edd48 100644 --- a/bot/exts/christmas/advent_of_code/_cog.py +++ b/bot/exts/christmas/advent_of_code/_cog.py @@ -11,7 +11,7 @@ from bot.constants import (      AdventOfCode as AocConfig, Channels, Colours, Emojis, Month, Roles, WHITELISTED_CHANNELS,  )  from bot.exts.christmas.advent_of_code import _helpers -from bot.utils.decorators import InChannelCheckFailure, in_month, override_in_channel, with_role +from bot.utils.decorators import InChannelCheckFailure, in_month, whitelist_override, with_role  log = logging.getLogger(__name__) @@ -50,7 +50,7 @@ class AdventOfCode(commands.Cog):          self.status_task.add_done_callback(_helpers.background_task_callback)      @commands.group(name="adventofcode", aliases=("aoc",)) -    @override_in_channel(AOC_WHITELIST) +    @whitelist_override(channels=AOC_WHITELIST)      async def adventofcode_group(self, ctx: commands.Context) -> None:          """All of the Advent of Code commands."""          if not ctx.invoked_subcommand: @@ -61,7 +61,7 @@ class AdventOfCode(commands.Cog):          aliases=("sub", "notifications", "notify", "notifs"),          brief="Notifications for new days"      ) -    @override_in_channel(AOC_WHITELIST) +    @whitelist_override(channels=AOC_WHITELIST)      async def aoc_subscribe(self, ctx: commands.Context) -> None:          """Assign the role for notifications about new days being ready."""          current_year = datetime.now().year @@ -82,7 +82,7 @@ class AdventOfCode(commands.Cog):      @in_month(Month.DECEMBER)      @adventofcode_group.command(name="unsubscribe", aliases=("unsub",), brief="Notifications for new days") -    @override_in_channel(AOC_WHITELIST) +    @whitelist_override(channels=AOC_WHITELIST)      async def aoc_unsubscribe(self, ctx: commands.Context) -> None:          """Remove the role for notifications about new days being ready."""          role = ctx.guild.get_role(AocConfig.role_id) @@ -94,7 +94,7 @@ class AdventOfCode(commands.Cog):              await ctx.send("Hey, you don't even get any notifications about new Advent of Code tasks currently anyway.")      @adventofcode_group.command(name="countdown", aliases=("count", "c"), brief="Return time left until next day") -    @override_in_channel(AOC_WHITELIST) +    @whitelist_override(channels=AOC_WHITELIST)      async def aoc_countdown(self, ctx: commands.Context) -> None:          """Return time left until next day."""          if not _helpers.is_in_advent(): @@ -123,13 +123,13 @@ class AdventOfCode(commands.Cog):          await ctx.send(f"There are {hours} hours and {minutes} minutes left until day {tomorrow.day}.")      @adventofcode_group.command(name="about", aliases=("ab", "info"), brief="Learn about Advent of Code") -    @override_in_channel(AOC_WHITELIST) +    @whitelist_override(channels=AOC_WHITELIST)      async def about_aoc(self, ctx: commands.Context) -> None:          """Respond with an explanation of all things Advent of Code."""          await ctx.send("", embed=self.cached_about_aoc)      @adventofcode_group.command(name="join", aliases=("j",), brief="Learn how to join the leaderboard (via DM)") -    @override_in_channel(AOC_WHITELIST) +    @whitelist_override(channels=AOC_WHITELIST)      async def join_leaderboard(self, ctx: commands.Context) -> None:          """DM the user the information for joining the Python Discord leaderboard."""          current_year = datetime.now().year @@ -178,7 +178,7 @@ class AdventOfCode(commands.Cog):          aliases=("board", "lb"),          brief="Get a snapshot of the PyDis private AoC leaderboard",      ) -    @override_in_channel(AOC_WHITELIST_RESTRICTED) +    @whitelist_override(channels=AOC_WHITELIST_RESTRICTED)      async def aoc_leaderboard(self, ctx: commands.Context) -> None:          """Get the current top scorers of the Python Discord Leaderboard."""          async with ctx.typing(): @@ -203,7 +203,7 @@ class AdventOfCode(commands.Cog):          aliases=("globalboard", "gb"),          brief="Get a link to the global leaderboard",      ) -    @override_in_channel(AOC_WHITELIST_RESTRICTED) +    @whitelist_override(channels=AOC_WHITELIST_RESTRICTED)      async def aoc_global_leaderboard(self, ctx: commands.Context) -> None:          """Get a link to the global Advent of Code leaderboard."""          url = self.global_leaderboard_url @@ -219,7 +219,7 @@ class AdventOfCode(commands.Cog):          aliases=("dailystats", "ds"),          brief="Get daily statistics for the Python Discord leaderboard"      ) -    @override_in_channel(AOC_WHITELIST_RESTRICTED) +    @whitelist_override(channels=AOC_WHITELIST_RESTRICTED)      async def private_leaderboard_daily_stats(self, ctx: commands.Context) -> None:          """Send an embed with daily completion statistics for the Python Discord leaderboard."""          try: diff --git a/bot/exts/evergreen/cheatsheet.py b/bot/exts/evergreen/cheatsheet.py index a64ddd69..3fe709d5 100644 --- a/bot/exts/evergreen/cheatsheet.py +++ b/bot/exts/evergreen/cheatsheet.py @@ -8,8 +8,8 @@ from discord.ext import commands  from discord.ext.commands import BucketType, Context  from bot import constants -from bot.constants import Categories, Channels, Colours, ERROR_REPLIES, Roles, WHITELISTED_CHANNELS -from bot.utils.decorators import with_role +from bot.constants import Categories, Channels, Colours, ERROR_REPLIES +from bot.utils.decorators import whitelist_override  ERROR_MESSAGE = f"""  Unknown cheat sheet. Please try to reformulate your query. @@ -75,7 +75,7 @@ class CheatSheet(commands.Cog):          aliases=("cht.sh", "cheatsheet", "cheat-sheet", "cht"),      )      @commands.cooldown(1, 10, BucketType.user) -    @with_role(Roles.everyone_role) +    @whitelist_override(categories=[Categories.help_in_use])      async def cheat_sheet(self, ctx: Context, *search_terms: str) -> None:          """          Search cheat.sh. @@ -84,12 +84,6 @@ class CheatSheet(commands.Cog):          Usage:          --> .cht read json          """ -        if not ( -                ctx.channel.category.id == Categories.help_in_use -                or ctx.channel.id in WHITELISTED_CHANNELS -        ): -            return -          async with ctx.typing():              search_string = quote_plus(" ".join(search_terms)) diff --git a/bot/exts/evergreen/conversationstarters.py b/bot/exts/evergreen/conversationstarters.py index 576b8d76..e7058961 100644 --- a/bot/exts/evergreen/conversationstarters.py +++ b/bot/exts/evergreen/conversationstarters.py @@ -5,7 +5,7 @@ from discord import Color, Embed  from discord.ext import commands  from bot.constants import WHITELISTED_CHANNELS -from bot.utils.decorators import override_in_channel +from bot.utils.decorators import whitelist_override  from bot.utils.randomization import RandomCycle  SUGGESTION_FORM = 'https://forms.gle/zw6kkJqv8U43Nfjg9' @@ -38,7 +38,7 @@ class ConvoStarters(commands.Cog):          self.bot = bot      @commands.command() -    @override_in_channel(ALL_ALLOWED_CHANNELS) +    @whitelist_override(channels=ALL_ALLOWED_CHANNELS)      async def topic(self, ctx: commands.Context) -> None:          """          Responds with a random topic to start a conversation. diff --git a/bot/exts/halloween/hacktoberstats.py b/bot/exts/halloween/hacktoberstats.py index a1c55922..d9fc0e8a 100644 --- a/bot/exts/halloween/hacktoberstats.py +++ b/bot/exts/halloween/hacktoberstats.py @@ -11,7 +11,7 @@ from async_rediscache import RedisCache  from discord.ext import commands  from bot.constants import Channels, Month, NEGATIVE_REPLIES, Tokens, WHITELISTED_CHANNELS -from bot.utils.decorators import in_month, override_in_channel +from bot.utils.decorators import in_month, whitelist_override  log = logging.getLogger(__name__) @@ -44,7 +44,7 @@ class HacktoberStats(commands.Cog):      @in_month(Month.SEPTEMBER, Month.OCTOBER, Month.NOVEMBER)      @commands.group(name="hacktoberstats", aliases=("hackstats",), invoke_without_command=True) -    @override_in_channel(HACKTOBER_WHITELIST) +    @whitelist_override(channels=HACKTOBER_WHITELIST)      async def hacktoberstats_group(self, ctx: commands.Context, github_username: str = None) -> None:          """          Display an embed for a user's Hacktoberfest contributions. @@ -72,7 +72,7 @@ class HacktoberStats(commands.Cog):      @in_month(Month.SEPTEMBER, Month.OCTOBER, Month.NOVEMBER)      @hacktoberstats_group.command(name="link") -    @override_in_channel(HACKTOBER_WHITELIST) +    @whitelist_override(channels=HACKTOBER_WHITELIST)      async def link_user(self, ctx: commands.Context, github_username: str = None) -> None:          """          Link the invoking user's Github github_username to their Discord ID. @@ -96,7 +96,7 @@ class HacktoberStats(commands.Cog):      @in_month(Month.SEPTEMBER, Month.OCTOBER, Month.NOVEMBER)      @hacktoberstats_group.command(name="unlink") -    @override_in_channel(HACKTOBER_WHITELIST) +    @whitelist_override(channels=HACKTOBER_WHITELIST)      async def unlink_user(self, ctx: commands.Context) -> None:          """Remove the invoking user's account link from the log."""          author_id, author_mention = self._author_mention_from_context(ctx) diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 9cdaad3f..c12a15ff 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -13,6 +13,7 @@ from discord.ext.commands import CheckFailure, Command, Context  from bot.constants import ERROR_REPLIES, Month  from bot.utils import human_months, resolve_current_month +from bot.utils.checks import in_whitelist_check  ONE_DAY = 24 * 60 * 60 @@ -186,82 +187,104 @@ def without_role(*role_ids: int) -> t.Callable:      return commands.check(predicate) -def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]: +def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context], bool]:      """ -    Checks that the message is in a whitelisted channel or optionally has a bypass role. +    Checks if a message is sent in a whitelisted context. -    If `in_channel_override` is present, check if it contains channels -    and use them in place of the global whitelist. +    All arguments from `in_whitelist_check` are supported, with the exception of "fail_silently". +    If `whitelist_override` is present, it is added to the global whitelist.      """      def predicate(ctx: Context) -> bool: +        # Skip DM invocations          if not ctx.guild:              log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")              return True -        if ctx.channel.id in channels: -            log.debug( -                f"{ctx.author} tried to call the '{ctx.command.name}' command " -                f"and the command was used in a whitelisted channel." -            ) -            return True -        if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles): -            log.debug( -                f"{ctx.author} called the '{ctx.command.name}' command and " -                f"had a role to bypass the in_channel check." -            ) -            return True +        kwargs = default_kwargs.copy() -        if hasattr(ctx.command.callback, "in_channel_override"): -            override = ctx.command.callback.in_channel_override -            if override is None: +        # Update kwargs based on override +        if hasattr(ctx.command.callback, "override"): +            # Remove default kwargs if reset is True +            if ctx.command.callback.override_reset: +                kwargs = {}                  log.debug( -                    f"{ctx.author} called the '{ctx.command.name}' command " -                    f"and the command was whitelisted to bypass the in_channel check." +                    f"{ctx.author} called the '{ctx.command.name}' command and " +                    f"overrode default checks."                  ) -                return True -            else: -                if ctx.channel.id in override: -                    log.debug( -                        f"{ctx.author} tried to call the '{ctx.command.name}' command " -                        f"and the command was used in an overridden whitelisted channel." -                    ) -                    return True -                log.debug( -                    f"{ctx.author} tried to call the '{ctx.command.name}' command. " -                    f"The overridden in_channel check failed." -                ) -                channels_str = ', '.join(f"<#{c_id}>" for c_id in override) -                raise InChannelCheckFailure( -                    f"Sorry, but you may only use this command within {channels_str}." -                ) +            # Merge overwrites and defaults +            for arg in ctx.command.callback.override: +                default_value = kwargs.get(arg) +                new_value = ctx.command.callback.override[arg] + +                # Skip values that don't need merging, or can't be merged +                if default_value is None or isinstance(arg, int): +                    kwargs[arg] = new_value + +                # Merge containers +                elif isinstance(default_value, t.Container): +                    if isinstance(new_value, t.Container): +                        kwargs[arg] = (*default_value, *new_value) +                    else: +                        kwargs[arg] = new_value + +            log.debug( +                f"Updated default check arguments for '{ctx.command.name}' " +                f"invoked by {ctx.author}." +            ) + +        log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") +        result = in_whitelist_check(ctx, fail_silently=True, **kwargs) + +        # Return if check passed +        if result: +            log.debug( +                f"{ctx.author} tried to call the '{ctx.command.name}' command " +                f"and the command was used in an overridden context." +            ) +            return result          log.debug(              f"{ctx.author} tried to call the '{ctx.command.name}' command. " -            f"The in_channel check failed." +            f"The whitelist check failed."          ) -        channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) -        raise InChannelCheckFailure( -            f"Sorry, but you may only use this command within {channels_str}." -        ) +        # Raise error if the check did not pass +        channels = set(kwargs.get("channels") or {}) +        categories = kwargs.get("categories") -    return predicate +        # Add all whitelisted category channels +        if categories: +            for category_id in categories: +                category = ctx.guild.get_channel(category_id) +                if category is None: +                    continue +                [channels.add(channel.id) for channel in category.text_channels] -in_channel = commands.check(in_channel_check) +        if channels: +            channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) +            message = f"Sorry, but you may only use this command within {channels_str}." +        else: +            message = "Sorry, but you may not use this command." + +        raise InChannelCheckFailure(message) + +    return predicate -def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable: +def whitelist_override(bypass_defaults: bool = False, **kwargs: t.Container[int]) -> t.Callable:      """ -    Set command callback attribute for detection in `in_channel_check`. +    Override global whitelist context, with the kwargs specified. -    Override global whitelist if channels are specified. +    All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`. +    Set `bypass_defaults` to True if you want to completely bypass global checks.      This decorator has to go before (below) below the `command` decorator.      """      def inner(func: t.Callable) -> t.Callable: -        func.in_channel_override = channels +        func.override = kwargs +        func.override_reset = bypass_defaults          return func      return inner | 
