diff options
| author | 2022-01-01 23:46:09 +0000 | |
|---|---|---|
| committer | 2022-01-01 23:46:09 +0000 | |
| commit | 0c7443d2ce8b44706d98837830586feaa6016985 (patch) | |
| tree | ee77a61d57b4d46f0e3ebdfadab760896986db2d | |
| parent | Merge pull request #997 from Sn4u/fix-995 (diff) | |
| parent | Merge branch 'main' into dm-check (diff) | |
Merge pull request #953 from python-discord/dm-check
Add Better Support For Whitelisting DM Commands
| -rw-r--r-- | bot/exts/fun/anagram.py | 1 | ||||
| -rw-r--r-- | bot/exts/fun/battleship.py | 1 | ||||
| -rw-r--r-- | bot/exts/fun/connect_four.py | 3 | ||||
| -rw-r--r-- | bot/exts/fun/tic_tac_toe.py | 3 | ||||
| -rw-r--r-- | bot/utils/decorators.py | 27 | 
5 files changed, 18 insertions, 17 deletions
| diff --git a/bot/exts/fun/anagram.py b/bot/exts/fun/anagram.py index 9aee5f18..79280fa9 100644 --- a/bot/exts/fun/anagram.py +++ b/bot/exts/fun/anagram.py @@ -49,7 +49,6 @@ class Anagram(commands.Cog):          self.games: dict[int, AnagramGame] = {}      @commands.command(name="anagram", aliases=("anag", "gram", "ag")) -    @commands.guild_only()      async def anagram_command(self, ctx: commands.Context) -> None:          """          Given shuffled letters, rearrange them into anagrams. diff --git a/bot/exts/fun/battleship.py b/bot/exts/fun/battleship.py index f4351954..beff196f 100644 --- a/bot/exts/fun/battleship.py +++ b/bot/exts/fun/battleship.py @@ -369,7 +369,6 @@ class Battleship(commands.Cog):          return any(player in (game.p1.user, game.p2.user) for game in self.games)      @commands.group(invoke_without_command=True) -    @commands.guild_only()      async def battleship(self, ctx: commands.Context) -> None:          """          Play a game of Battleship with someone else! diff --git a/bot/exts/fun/connect_four.py b/bot/exts/fun/connect_four.py index 647bb2b7..f53695d5 100644 --- a/bot/exts/fun/connect_four.py +++ b/bot/exts/fun/connect_four.py @@ -6,7 +6,6 @@ from typing import Optional, Union  import discord  import emojis  from discord.ext import commands -from discord.ext.commands import guild_only  from bot.bot import Bot  from bot.constants import Emojis @@ -361,7 +360,6 @@ class ConnectFour(commands.Cog):                  self.games.remove(game)              raise -    @guild_only()      @commands.group(          invoke_without_command=True,          aliases=("4inarow", "connect4", "connectfour", "c4"), @@ -426,7 +424,6 @@ class ConnectFour(commands.Cog):          await self._play_game(ctx, user, board_size, str(emoji1), str(emoji2)) -    @guild_only()      @connect_four.command(aliases=("bot", "computer", "cpu"))      async def ai(          self, diff --git a/bot/exts/fun/tic_tac_toe.py b/bot/exts/fun/tic_tac_toe.py index 946b6f7b..5dd38a81 100644 --- a/bot/exts/fun/tic_tac_toe.py +++ b/bot/exts/fun/tic_tac_toe.py @@ -3,7 +3,7 @@ import random  from typing import Callable, Optional, Union  import discord -from discord.ext.commands import Cog, Context, check, group, guild_only +from discord.ext.commands import Cog, Context, check, group  from bot.bot import Bot  from bot.constants import Emojis @@ -253,7 +253,6 @@ class TicTacToe(Cog):      def __init__(self):          self.games: list[Game] = [] -    @guild_only()      @is_channel_free()      @is_requester_free()      @group(name="tictactoe", aliases=("ttt", "tic"), invoke_without_command=True) diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 132aaa87..7a3b14ad 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -196,15 +196,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo      If `whitelist_override` is present, it is added to the global whitelist.      """      def predicate(ctx: Context) -> bool: -        # Skip DM invocations -        if not ctx.guild: -            log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") -            return True -          kwargs = default_kwargs.copy() +        allow_dms = False          # Update kwargs based on override          if hasattr(ctx.command.callback, "override"): +            # Handle DM invocations +            allow_dms = ctx.command.callback.override_dm +              # Remove default kwargs if reset is True              if ctx.command.callback.override_reset:                  kwargs = {} @@ -234,8 +233,12 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo                  f"invoked by {ctx.author}."              ) -        log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") -        result = in_whitelist_check(ctx, fail_silently=True, **kwargs) +        if ctx.guild is None: +            log.debug(f"{ctx.author} tried using the '{ctx.command.name}' command from a DM.") +            result = allow_dms +        else: +            log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") +            result = in_whitelist_check(ctx, fail_silently=True, **kwargs)          # Return if check passed          if result: @@ -260,8 +263,8 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo              default_whitelist_channels.discard(Channels.community_bot_commands)              channels.difference_update(default_whitelist_channels) -        # Add all whitelisted category channels -        if categories: +        # Add all whitelisted category channels, but skip if we're in DMs +        if categories and ctx.guild is not None:              for category_id in categories:                  category = ctx.guild.get_channel(category_id)                  if category is None: @@ -280,18 +283,22 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo      return predicate -def whitelist_override(bypass_defaults: bool = False, **kwargs: Container[int]) -> Callable: +def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **kwargs: Container[int]) -> Callable:      """      Override global whitelist context, with the kwargs specified.      All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`.      Set `bypass_defaults` to True if you want to completely bypass global checks. +    Set `allow_dm` to True if you want to allow the command to be invoked from within direct messages. +    Note that you have to be careful with any references to the guild. +      This decorator has to go before (below) below the `command` decorator.      """      def inner(func: Callable) -> Callable:          func.override = kwargs          func.override_reset = bypass_defaults +        func.override_dm = allow_dm          return func      return inner | 
