diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/constants.py | 1 | ||||
| -rw-r--r-- | bot/decorators.py | 47 | ||||
| -rw-r--r-- | bot/seasons/christmas/adventofcode.py | 2 | ||||
| -rw-r--r-- | bot/seasons/evergreen/issues.py | 2 | ||||
| -rw-r--r-- | bot/seasons/halloween/hacktoberstats.py | 9 | 
5 files changed, 47 insertions, 14 deletions
| diff --git a/bot/constants.py b/bot/constants.py index dbf35754..0d4321c8 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -52,6 +52,7 @@ class Channels(NamedTuple):      python_discussion = 267624335836053506      show_your_projects = int(environ.get("CHANNEL_SHOW_YOUR_PROJECTS", 303934982764625920))      show_your_projects_discussion = 360148304664723466 +    hacktoberfest_2019 = 628184417646411776  class Client(NamedTuple): diff --git a/bot/decorators.py b/bot/decorators.py index dbaad4a2..2c042b56 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -64,12 +64,16 @@ def without_role(*role_ids: int) -> bool:  def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]: -    """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" +    """ +    Checks that the message is in a whitelisted channel or optionally has a bypass role. + +    If `in_channel_override` is present, check if it contains channels +    and use them in place of the global whitelist. +    """      def predicate(ctx: Context) -> bool:          if not ctx.guild:              log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")              return True -          if ctx.channel.id in channels:              log.debug(                  f"{ctx.author} tried to call the '{ctx.command.name}' command " @@ -78,11 +82,29 @@ def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None)              return True          if hasattr(ctx.command.callback, "in_channel_override"): -            log.debug( -                f"{ctx.author} called the '{ctx.command.name}' command " -                f"and the command was whitelisted to bypass the in_channel check." -            ) -            return True +            override = ctx.command.callback.in_channel_override +            if override is None: +                log.debug( +                    f"{ctx.author} called the '{ctx.command.name}' command " +                    f"and the command was whitelisted to bypass the in_channel check." +                ) +                return True +            else: +                if ctx.channel.id in override: +                    log.debug( +                        f"{ctx.author} tried to call the '{ctx.command.name}' command " +                        f"and the command was used in an overridden whitelisted channel." +                    ) +                    return True + +                log.debug( +                    f"{ctx.author} tried to call the '{ctx.command.name}' command. " +                    f"The overridden in_channel check failed." +                ) +                channels_str = ', '.join(f"<#{c_id}>" for c_id in override) +                raise InChannelCheckFailure( +                    f"Sorry, but you may only use this command within {channels_str}." +                )          if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):              log.debug( @@ -107,14 +129,19 @@ def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None)  in_channel = commands.check(in_channel_check) -def override_in_channel(func: typing.Callable) -> typing.Callable: +def override_in_channel(channels: typing.Tuple[int] = None) -> typing.Callable:      """      Set command callback attribute for detection in `in_channel_check`. +    Override global whitelist if channels are specified. +      This decorator has to go before (below) below the `command` decorator.      """ -    func.in_channel_override = True -    return func +    def inner(func: typing.Callable) -> typing.Callable: +        func.in_channel_override = channels +        return func + +    return inner  def locked() -> typing.Union[typing.Callable, None]: diff --git a/bot/seasons/christmas/adventofcode.py b/bot/seasons/christmas/adventofcode.py index 6609387e..513c1020 100644 --- a/bot/seasons/christmas/adventofcode.py +++ b/bot/seasons/christmas/adventofcode.py @@ -126,7 +126,7 @@ class AdventOfCode(commands.Cog):          self.status_task = asyncio.ensure_future(self.bot.loop.create_task(status_coro))      @commands.group(name="adventofcode", aliases=("aoc",), invoke_without_command=True) -    @override_in_channel +    @override_in_channel()      async def adventofcode_group(self, ctx: commands.Context) -> None:          """All of the Advent of Code commands."""          await ctx.send_help(ctx.command) diff --git a/bot/seasons/evergreen/issues.py b/bot/seasons/evergreen/issues.py index 0ba74d9c..438ab475 100644 --- a/bot/seasons/evergreen/issues.py +++ b/bot/seasons/evergreen/issues.py @@ -16,7 +16,7 @@ class Issues(commands.Cog):          self.bot = bot      @commands.command(aliases=("issues",)) -    @override_in_channel +    @override_in_channel()      async def issue(          self, ctx: commands.Context, number: int, repository: str = "seasonalbot", user: str = "python-discord"      ) -> None: diff --git a/bot/seasons/halloween/hacktoberstats.py b/bot/seasons/halloween/hacktoberstats.py index 20797037..035eafbc 100644 --- a/bot/seasons/halloween/hacktoberstats.py +++ b/bot/seasons/halloween/hacktoberstats.py @@ -10,12 +10,16 @@ import aiohttp  import discord  from discord.ext import commands +from bot.constants import Channels, WHITELISTED_CHANNELS +from bot.decorators import override_in_channel  from bot.utils.persist import make_persistent +  log = logging.getLogger(__name__)  CURRENT_YEAR = datetime.now().year  # Used to construct GH API query  PRS_FOR_SHIRT = 4  # Minimum number of PRs before a shirt is awarded +HACKTOBER_WHITELIST = WHITELISTED_CHANNELS + (Channels.hacktoberfest_2019,)  class HacktoberStats(commands.Cog): @@ -27,6 +31,7 @@ class HacktoberStats(commands.Cog):          self.linked_accounts = self.load_linked_users()      @commands.group(name="hacktoberstats", aliases=("hackstats",), invoke_without_command=True) +    @override_in_channel(HACKTOBER_WHITELIST)      async def hacktoberstats_group(self, ctx: commands.Context, github_username: str = None) -> None:          """          Display an embed for a user's Hacktoberfest contributions. @@ -220,7 +225,7 @@ class HacktoberStats(commands.Cog):          not_label = "invalid"          action_type = "pr"          is_query = f"public+author:{github_username}" -        date_range = f"{CURRENT_YEAR}-10-01..{CURRENT_YEAR}-10-31" +        date_range = f"{CURRENT_YEAR}-10-01T00:00:00%2B14:00..{CURRENT_YEAR}-10-31T00:00:00-11:00"          per_page = "300"          query_url = (              f"{base_url}" @@ -231,7 +236,7 @@ class HacktoberStats(commands.Cog):              f"&per_page={per_page}"          ) -        headers = {"user-agent": "Discord Python Hactoberbot"} +        headers = {"user-agent": "Discord Python Hacktoberbot"}          async with aiohttp.ClientSession() as session:              async with session.get(query_url, headers=headers) as resp:                  jsonresp = await resp.json() | 
