diff options
| author | 2019-10-25 23:49:06 +0200 | |
|---|---|---|
| committer | 2019-10-25 23:49:06 +0200 | |
| commit | 509b95e7812ae13d3444b0738ac7a3e1d57ee2b5 (patch) | |
| tree | bba40cdbd90ec7402360eb66f18c963db30739f8 | |
| parent | Use Scheduler instead of a custom async loop (diff) | |
| parent | Merge pull request #527 from kraktus/compact_free (diff) | |
Merge branch 'master' into #364-offensive-msg-autodeletion
| -rw-r--r-- | bot/cogs/filtering.py | 41 | ||||
| -rw-r--r-- | bot/cogs/free.py | 25 | ||||
| -rw-r--r-- | bot/cogs/moderation/infractions.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 10 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 229 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 13 | ||||
| -rw-r--r-- | bot/cogs/snekbox.py | 13 | ||||
| -rw-r--r-- | bot/constants.py | 13 | ||||
| -rw-r--r-- | bot/utils/checks.py | 6 | ||||
| -rw-r--r-- | config-default.yml | 11 | ||||
| -rw-r--r-- | tests/bot/utils/test_checks.py | 8 | 
11 files changed, 182 insertions, 191 deletions
| diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 35c14f101..5bd72a584 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -11,7 +11,7 @@ from discord.ext.commands import Bot, Cog  from bot.cogs.moderation import ModLog  from bot.constants import ( -    Channels, Colours, DEBUG_MODE, +    Channels, Colours,      Filter, Icons, URLs  )  from bot.utils.scheduling import Scheduler @@ -152,10 +152,6 @@ class Filtering(Cog, Scheduler):              and not msg.author.bot                          # Author not a bot          ) -        # If we're running the bot locally, ignore role whitelist and only listen to #dev-test -        if DEBUG_MODE: -            filter_message = not msg.author.bot and msg.channel.id == Channels.devtest -          # If none of the above, we can start filtering.          if filter_message:              for filter_name, _filter in self.filters.items(): @@ -170,11 +166,11 @@ class Filtering(Cog, Scheduler):                      # Does the filter only need the message content or the full message?                      if _filter["content_only"]: -                        triggered = await _filter["function"](msg.content) +                        match = await _filter["function"](msg.content)                      else: -                        triggered = await _filter["function"](msg) +                        match = await _filter["function"](msg) -                    if triggered: +                    if match:                          # If this is a filter (not a watchlist), we should delete the message.                          if _filter["type"] == "filter":                              try: @@ -215,12 +211,23 @@ class Filtering(Cog, Scheduler):                          else:                              channel_str = f"in {msg.channel.mention}" +                        # Word and match stats for watch_words and watch_tokens +                        if filter_name in ("watch_words", "watch_tokens"): +                            surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] +                            message_content = ( +                                f"**Match:** '{match[0]}'\n" +                                f"**Location:** '...{surroundings}...'\n" +                                f"\n**Original Message:**\n{msg.content}" +                            ) +                        else:  # Use content of discord Message +                            message_content = msg.content +                          message = (                              f"The {filter_name} {_filter['type']} was triggered "                              f"by **{msg.author}** "                              f"(`{msg.author.id}`) {channel_str} with [the "                              f"following message]({msg.jump_url}):\n\n" -                            f"{msg.content}" +                            f"{message_content}"                          )                          log.debug(message) @@ -230,7 +237,7 @@ class Filtering(Cog, Scheduler):                          if filter_name == "filter_invites":                              additional_embeds = [] -                            for invite, data in triggered.items(): +                            for invite, data in match.items():                                  embed = discord.Embed(description=(                                      f"**Members:**\n{data['members']}\n"                                      f"**Active:**\n{data['active']}" @@ -261,31 +268,33 @@ class Filtering(Cog, Scheduler):                          break  # We don't want multiple filters to trigger      @staticmethod -    async def _has_watchlist_words(text: str) -> bool: +    async def _has_watchlist_words(text: str) -> Union[bool, re.Match]:          """          Returns True if the text contains one of the regular expressions from the word_watchlist in our filter config.          Only matches words with boundaries before and after the expression.          """          for regex_pattern in WORD_WATCHLIST_PATTERNS: -            if regex_pattern.search(text): -                return True +            match = regex_pattern.search(text) +            if match: +                return match  # match objects always have a boolean value of True          return False      @staticmethod -    async def _has_watchlist_tokens(text: str) -> bool: +    async def _has_watchlist_tokens(text: str) -> Union[bool, re.Match]:          """          Returns True if the text contains one of the regular expressions from the token_watchlist in our filter config.          This will match the expression even if it does not have boundaries before and after.          """          for regex_pattern in TOKEN_WATCHLIST_PATTERNS: -            if regex_pattern.search(text): +            match = regex_pattern.search(text) +            if match:                  # Make sure it's not a URL                  if not URL_RE.search(text): -                    return True +                    return match  # match objects always have a boolean value of True          return False diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 269c5c1b9..82285656b 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -72,30 +72,27 @@ class Free(Cog):          # Display all potentially inactive channels          # in descending order of inactivity          if free_channels: -            embed.description += "**The following channel{0} look{1} free:**\n\n**".format( -                's' if len(free_channels) > 1 else '', -                '' if len(free_channels) > 1 else 's' -            ) -              # Sort channels in descending order by seconds              # Get position in list, inactivity, and channel object              # For each channel, add to embed.description              sorted_channels = sorted(free_channels, key=itemgetter(0), reverse=True) -            for i, (inactive, channel) in enumerate(sorted_channels, 1): + +            for (inactive, channel) in sorted_channels[:3]:                  minutes, seconds = divmod(inactive, 60)                  if minutes > 59:                      hours, minutes = divmod(minutes, 60) -                    embed.description += f"{i}. {channel.mention} inactive for {hours}h{minutes}m{seconds}s\n\n" +                    embed.description += f"{channel.mention} **{hours}h {minutes}m {seconds}s** inactive\n"                  else: -                    embed.description += f"{i}. {channel.mention} inactive for {minutes}m{seconds}s\n\n" +                    embed.description += f"{channel.mention} **{minutes}m {seconds}s** inactive\n" -            embed.description += ("**\nThese channels aren't guaranteed to be free, " -                                  "so use your best judgement and check for yourself.") +            embed.set_footer(text="Please confirm these channels are free before posting")          else: -            embed.description = ("**Doesn't look like any channels are available right now. " -                                 "You're welcome to check for yourself to be sure. " -                                 "If all channels are truly busy, please be patient " -                                 "as one will likely be available soon.**") +            embed.description = ( +                "Doesn't look like any channels are available right now. " +                "You're welcome to check for yourself to be sure. " +                "If all channels are truly busy, please be patient " +                "as one will likely be available soon." +            )          await ctx.send(embed=embed) diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index f2ae7b95d..997ffe524 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -12,7 +12,7 @@ from discord.ext.commands import Context, command  from bot import constants  from bot.api import ResponseCodeError -from bot.constants import Colours, Event +from bot.constants import Colours, Event, STAFF_CHANNELS  from bot.decorators import respect_role_hierarchy  from bot.utils import time  from bot.utils.checks import with_role_check @@ -465,6 +465,8 @@ class Infractions(Scheduler, commands.Cog):          if infraction["actor"] == self.bot.user.id:              end_msg = f" (reason: {infraction['reason']})" +        elif ctx.channel.id not in STAFF_CHANNELS: +            end_msg = ''          else:              infractions = await self.bot.api_client.get(                  "bot/infractions", diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 491f6d400..44a508436 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -11,7 +11,7 @@ from bot import constants  from bot.converters import InfractionSearchQuery  from bot.pagination import LinePaginator  from bot.utils import time -from bot.utils.checks import with_role_check +from bot.utils.checks import in_channel_check, with_role_check  from . import utils  from .infractions import Infractions  from .modlog import ModLog @@ -256,8 +256,12 @@ class ModManagement(commands.Cog):      # This cannot be static (must have a __func__ attribute).      def cog_check(self, ctx: Context) -> bool: -        """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *constants.MODERATION_ROLES) +        """Only allow moderators from moderator channels to invoke the commands in this cog.""" +        checks = [ +            with_role_check(ctx, *constants.MODERATION_ROLES), +            in_channel_check(ctx, *constants.MODERATION_CHANNELS) +        ] +        return all(checks)      # This cannot be static (must have a __func__ attribute).      async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..7749d237f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,13 +2,14 @@ import asyncio  import logging  import random  import textwrap -from datetime import datetime, timedelta +from datetime import datetime  from typing import List -from discord import Colour, Embed, Message, TextChannel +from discord import Colour, Embed, TextChannel  from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.tasks import loop -from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES +from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES, Webhooks  from bot.converters import Subreddit  from bot.decorators import with_role  from bot.pagination import LinePaginator @@ -26,15 +27,25 @@ class Reddit(Cog):      def __init__(self, bot: Bot):          self.bot = bot -        self.reddit_channel = None +        self.webhook = None  # set in on_ready +        bot.loop.create_task(self.init_reddit_ready()) -        self.prev_lengths = {} -        self.last_ids = {} +        self.auto_poster_loop.start() -        self.new_posts_task = None -        self.top_weekly_posts_task = None +    def cog_unload(self) -> None: +        """Stops the loops when the cog is unloaded.""" +        self.auto_poster_loop.cancel() -        self.bot.loop.create_task(self.init_reddit_polling()) +    async def init_reddit_ready(self) -> None: +        """Sets the reddit webhook when the cog is loaded.""" +        await self.bot.wait_until_ready() +        if not self.webhook: +            self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) + +    @property +    def channel(self) -> TextChannel: +        """Get the #reddit channel object from the bot's cache.""" +        return self.bot.get_channel(Channels.reddit)      async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:          """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -63,23 +74,22 @@ class Reddit(Cog):          log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")          return list()  # Failed to get appropriate response within allowed number of retries. -    async def send_top_posts( -        self, channel: TextChannel, subreddit: Subreddit, content: str = None, time: str = "all" -    ) -> Message: -        """Create an embed for the top posts, then send it in a given TextChannel.""" -        # Create the new spicy embed. -        embed = Embed() -        embed.description = "" - -        # Get the posts -        async with channel.typing(): -            posts = await self.fetch_posts( -                route=f"{subreddit}/top", -                amount=5, -                params={ -                    "t": time -                } -            ) +    async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: +        """ +        Get the top amount of posts for a given subreddit within a specified timeframe. + +        A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top +        weekly posts. + +        The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. +        """ +        embed = Embed(description="") + +        posts = await self.fetch_posts( +            route=f"{subreddit}/top", +            amount=amount, +            params={"t": time} +        )          if not posts:              embed.title = random.choice(ERROR_REPLIES) @@ -89,9 +99,7 @@ class Reddit(Cog):                  "If this problem persists, please let us know."              ) -            return await channel.send( -                embed=embed -            ) +            return embed          for post in posts:              data = post["data"] @@ -115,103 +123,51 @@ class Reddit(Cog):              )          embed.colour = Colour.blurple() +        return embed -        return await channel.send( -            content=content, -            embed=embed -        ) - -    async def poll_new_posts(self) -> None: -        """Periodically search for new subreddit posts.""" -        while True: -            await asyncio.sleep(RedditConfig.request_delay) - -            for subreddit in RedditConfig.subreddits: -                # Make a HEAD request to the subreddit -                head_response = await self.bot.http_session.head( -                    url=f"{self.URL}/{subreddit}/new.rss", -                    headers=self.HEADERS -                ) - -                content_length = head_response.headers["content-length"] - -                # If the content is the same size as before, assume there's no new posts. -                if content_length == self.prev_lengths.get(subreddit, None): -                    continue - -                self.prev_lengths[subreddit] = content_length - -                # Now we can actually fetch the new data -                posts = await self.fetch_posts(f"{subreddit}/new") -                new_posts = [] +    @loop() +    async def auto_poster_loop(self) -> None: +        """Post the top 5 posts daily, and the top 5 posts weekly.""" +        # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter +        now = datetime.utcnow() +        midnight_tomorrow = now.replace(day=now.day + 1, hour=0, minute=0, second=0) +        seconds_until = (midnight_tomorrow - now).total_seconds() -                # Only show new posts if we've checked before. -                if subreddit in self.last_ids: -                    for post in posts: -                        data = post["data"] +        await asyncio.sleep(seconds_until) -                        # Convert the ID to an integer for easy comparison. -                        int_id = int(data["id"], 36) - -                        # If we've already seen this post, finish checking -                        if int_id <= self.last_ids[subreddit]: -                            break - -                        embed_data = { -                            "title": textwrap.shorten(data["title"], width=64, placeholder="..."), -                            "text": textwrap.shorten(data["selftext"], width=128, placeholder="..."), -                            "url": self.URL + data["permalink"], -                            "author": data["author"] -                        } - -                        new_posts.append(embed_data) - -                self.last_ids[subreddit] = int(posts[0]["data"]["id"], 36) - -                # Send all of the new posts as spicy embeds -                for data in new_posts: -                    embed = Embed() - -                    embed.title = data["title"] -                    embed.url = data["url"] -                    embed.description = data["text"] -                    embed.set_footer(text=f"Posted by u/{data['author']} in {subreddit}") -                    embed.colour = Colour.blurple() - -                    await self.reddit_channel.send(embed=embed) +        await self.bot.wait_until_ready() +        if not self.webhook: +            await self.bot.fetch_webhook(Webhooks.reddit) -                log.trace(f"Sent {len(new_posts)} new {subreddit} posts to channel {self.reddit_channel.id}.") +        if datetime.utcnow().weekday() == 0: +            await self.top_weekly_posts() +            # if it's a monday send the top weekly posts -    async def poll_top_weekly_posts(self) -> None: -        """Post a summary of the top posts every week.""" -        while True: -            now = datetime.utcnow() +        for subreddit in RedditConfig.subreddits: +            top_posts = await self.get_top_posts(subreddit=subreddit, time="day") +            await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) -            # Calculate the amount of seconds until midnight next monday. -            monday = now + timedelta(days=7 - now.weekday()) -            monday = monday.replace(hour=0, minute=0, second=0) -            until_monday = (monday - now).total_seconds() +    async def top_weekly_posts(self) -> None: +        """Post a summary of the top posts.""" +        for subreddit in RedditConfig.subreddits: +            # Send and pin the new weekly posts. +            top_posts = await self.get_top_posts(subreddit=subreddit, time="week") -            await asyncio.sleep(until_monday) +            message = await self.webhook.send(wait=True, username=f"{subreddit} Top Weekly Posts", embed=top_posts) -            for subreddit in RedditConfig.subreddits: -                # Send and pin the new weekly posts. -                message = await self.send_top_posts( -                    channel=self.reddit_channel, -                    subreddit=subreddit, -                    content=f"This week's top {subreddit} posts have arrived!", -                    time="week" -                ) +            if subreddit.lower() == "r/python": +                if not self.channel: +                    log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") +                    return -                if subreddit.lower() == "r/python": -                    # Remove the oldest pins so that only 5 remain at most. -                    pins = await self.reddit_channel.pins() +                # Remove the oldest pins so that only 12 remain at most. +                pins = await self.channel.pins() -                    while len(pins) >= 5: -                        await pins[-1].unpin() -                        del pins[-1] +                while len(pins) >= 12: +                    await pins[-1].unpin() +                    del pins[-1] -                    await message.pin() +                await message.pin()      @group(name="reddit", invoke_without_command=True)      async def reddit_group(self, ctx: Context) -> None: @@ -221,32 +177,26 @@ class Reddit(Cog):      @reddit_group.command(name="top")      async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:          """Send the top posts of all time from a given subreddit.""" -        await self.send_top_posts( -            channel=ctx.channel, -            subreddit=subreddit, -            content=f"Here are the top {subreddit} posts of all time!", -            time="all" -        ) +        async with ctx.typing(): +            embed = await self.get_top_posts(subreddit=subreddit, time="all") + +        await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed)      @reddit_group.command(name="daily")      async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:          """Send the top posts of today from a given subreddit.""" -        await self.send_top_posts( -            channel=ctx.channel, -            subreddit=subreddit, -            content=f"Here are today's top {subreddit} posts!", -            time="day" -        ) +        async with ctx.typing(): +            embed = await self.get_top_posts(subreddit=subreddit, time="day") + +        await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed)      @reddit_group.command(name="weekly")      async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:          """Send the top posts of this week from a given subreddit.""" -        await self.send_top_posts( -            channel=ctx.channel, -            subreddit=subreddit, -            content=f"Here are this week's top {subreddit} posts!", -            time="week" -        ) +        async with ctx.typing(): +            embed = await self.get_top_posts(subreddit=subreddit, time="week") + +        await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed)      @with_role(*STAFF_ROLES)      @reddit_group.command(name="subreddits", aliases=("subs",)) @@ -264,19 +214,6 @@ class Reddit(Cog):              max_lines=15          ) -    async def init_reddit_polling(self) -> None: -        """Initiate reddit post event loop.""" -        await self.bot.wait_until_ready() -        self.reddit_channel = await self.bot.fetch_channel(Channels.reddit) - -        if self.reddit_channel is not None: -            if self.new_posts_task is None: -                self.new_posts_task = self.bot.loop.create_task(self.poll_new_posts()) -            if self.top_weekly_posts_task is None: -                self.top_weekly_posts_task = self.bot.loop.create_task(self.poll_top_weekly_posts()) -        else: -            log.warning("Couldn't locate a channel for subreddit relaying.") -  def setup(bot: Bot) -> None:      """Reddit cog load.""" diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index b54622306..81990704b 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -2,7 +2,7 @@ import asyncio  import logging  import random  import textwrap -from datetime import datetime +from datetime import datetime, timedelta  from operator import itemgetter  from typing import Optional @@ -104,7 +104,10 @@ class Reminders(Scheduler, Cog):              name="It has arrived!"          ) -        embed.description = f"Here's your reminder: `{reminder['content']}`" +        embed.description = f"Here's your reminder: `{reminder['content']}`." + +        if reminder.get("jump_url"):  # keep backward compatibility +            embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})"          if late:              embed.colour = Colour.red() @@ -167,14 +170,18 @@ class Reminders(Scheduler, Cog):              json={                  'author': ctx.author.id,                  'channel_id': ctx.message.channel.id, +                'jump_url': ctx.message.jump_url,                  'content': content,                  'expiration': expiration.isoformat()              }          ) +        now = datetime.utcnow() - timedelta(seconds=1) +          # Confirm to the user that it worked.          await self._send_confirmation( -            ctx, on_success="Your reminder has been created successfully!" +            ctx, +            on_success=f"Your reminder will arrive in {humanize_delta(relativedelta(expiration, now))}!"          )          loop = asyncio.get_event_loop() diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index c0390cb1e..362968bd0 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -115,6 +115,16 @@ class Snekbox(Cog):          return msg, error +    @staticmethod +    def get_status_emoji(results: dict) -> str: +        """Return an emoji corresponding to the status code or lack of output in result.""" +        if not results["stdout"].strip():  # No output +            return ":warning:" +        elif results["returncode"] == 0:  # No error +            return ":white_check_mark:" +        else:  # Exception +            return ":x:" +      async def format_output(self, output: str) -> Tuple[str, Optional[str]]:          """          Format the output and return a tuple of the formatted output and a URL to the full output. @@ -201,7 +211,8 @@ class Snekbox(Cog):                  else:                      output, paste_link = await self.format_output(results["stdout"]) -                msg = f"{ctx.author.mention} {msg}.\n\n```py\n{output}\n```" +                icon = self.get_status_emoji(results) +                msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```"                  if paste_link:                      msg = f"{msg}\nFull output: {paste_link}" diff --git a/bot/constants.py b/bot/constants.py index 6106d911c..fd95712e7 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -329,6 +329,7 @@ class Channels(metaclass=YAMLGetter):      subsection = "channels"      admins: int +    admin_spam: int      announcements: int      big_brother_logs: int      bot: int @@ -347,11 +348,14 @@ class Channels(metaclass=YAMLGetter):      helpers: int      message_log: int      meta: int +    mod_spam: int +    mods: int      mod_alerts: int      modlog: int      off_topic_0: int      off_topic_1: int      off_topic_2: int +    organisation: int      python: int      reddit: int      talent_pool: int @@ -366,6 +370,7 @@ class Webhooks(metaclass=YAMLGetter):      talent_pool: int      big_brother: int +    reddit: int  class Roles(metaclass=YAMLGetter): @@ -393,6 +398,7 @@ class Guild(metaclass=YAMLGetter):      id: int      ignored: List[int] +    staff_channels: List[int]  class Keys(metaclass=YAMLGetter): @@ -440,7 +446,6 @@ class URLs(metaclass=YAMLGetter):  class Reddit(metaclass=YAMLGetter):      section = "reddit" -    request_delay: int      subreddits: list @@ -508,6 +513,12 @@ PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir))  MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner  STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner +# Roles combinations +STAFF_CHANNELS = Guild.staff_channels + +# Default Channel combinations +MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam +  # Bot replies  NEGATIVE_REPLIES = [ diff --git a/bot/utils/checks.py b/bot/utils/checks.py index ad892e512..db56c347c 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -38,9 +38,9 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:      return check -def in_channel_check(ctx: Context, channel_id: int) -> bool: -    """Checks if the command was executed inside of the specified channel.""" -    check = ctx.channel.id == channel_id +def in_channel_check(ctx: Context, *channel_ids: int) -> bool: +    """Checks if the command was executed inside the list of specified channels.""" +    check = ctx.channel.id in channel_ids      log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "                f"The result of the in_channel check was {check}.")      return check diff --git a/config-default.yml b/config-default.yml index fc702e991..98638a3e1 100644 --- a/config-default.yml +++ b/config-default.yml @@ -90,11 +90,12 @@ guild:      channels:          admins:            &ADMINS        365960823622991872 +        admin_spam:        &ADMIN_SPAM    563594791770914816          announcements:                    354619224620138496          big_brother_logs:  &BBLOGS        468507907357409333          bot:                              267659945086812160          checkpoint_test:                  422077681434099723 -        defcon:                           464469101889454091 +        defcon:            &DEFCON        464469101889454091          devlog:            &DEVLOG        622895325144940554          devtest:           &DEVTEST       414574275865870337          help_0:                           303906576991780866 @@ -105,14 +106,17 @@ guild:          help_5:                           454941769734422538          help_6:                           587375753306570782          help_7:                           587375768556797982 -        helpers:                          385474242440986624 +        helpers:           &HELPERS       385474242440986624          message_log:       &MESSAGE_LOG   467752170159079424          meta:                             429409067623251969 +        mod_spam:          &MOD_SPAM      620607373828030464 +        mods:              &MODS          305126844661760000          mod_alerts:                       473092532147060736          modlog:            &MODLOG        282638479504965634          off_topic_0:                      291284109232308226          off_topic_1:                      463035241142026251          off_topic_2:                      463035268514185226 +        organisation:      &ORGANISATION  551789653284356126          python:                           267624335836053506          reddit:                           458224812528238616          staff_lounge:      &STAFF_LOUNGE  464905259261755392 @@ -121,6 +125,7 @@ guild:          user_event_a:      &USER_EVENT_A  592000283102674944          verification:                     352442727016693763 +    staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON]      ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG]      roles: @@ -142,6 +147,7 @@ guild:      webhooks:          talent_pool:                        569145364800602132          big_brother:                        569133704568373283 +        reddit:                             635408384794951680  filter: @@ -346,7 +352,6 @@ anti_malware:  reddit: -    request_delay: 60      subreddits:          - 'r/Python' diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 22dc93073..19b758336 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -41,3 +41,11 @@ class ChecksTests(unittest.TestCase):          role_id = 42          self.ctx.author.roles.append(MockRole(role_id=role_id))          self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) + +    def test_in_channel_check_for_correct_channel(self): +        self.ctx.channel.id = 42 +        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) + +    def test_in_channel_check_for_incorrect_channel(self): +        self.ctx.channel.id = 42 + 10 +        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) | 
