diff options
| author | 2021-05-05 17:58:56 +0100 | |
|---|---|---|
| committer | 2021-05-05 17:58:56 +0100 | |
| commit | f1e26927636be455ad33ffbea5e5b5d24b3c5505 (patch) | |
| tree | 5e52c7106ee820d4f178c669f8b96fc47e2ed501 | |
| parent | Merge branch 'eval-redirect-output' of https://github.com/ToxicKidz/bot into ... (diff) | |
| parent | Merge pull request #1441 from asleep-cult/master (diff) | |
Merge branch 'main' into eval-redirect-output
| -rw-r--r-- | bot/constants.py | 11 | ||||
| -rw-r--r-- | bot/exts/backend/error_handler.py | 22 | ||||
| -rw-r--r-- | bot/exts/info/code_snippets.py | 265 | ||||
| -rw-r--r-- | bot/exts/info/information.py | 5 | ||||
| -rw-r--r-- | bot/exts/info/reddit.py | 6 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 18 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 8 | ||||
| -rw-r--r-- | bot/exts/moderation/modlog.py | 9 | ||||
| -rw-r--r-- | bot/exts/moderation/modpings.py | 138 | ||||
| -rw-r--r-- | bot/exts/moderation/stream.py | 80 | ||||
| -rw-r--r-- | bot/exts/utils/reminders.py | 13 | ||||
| -rw-r--r-- | bot/exts/utils/utils.py | 2 | ||||
| -rw-r--r-- | bot/log.py | 37 | ||||
| -rw-r--r-- | bot/resources/tags/identity.md | 24 | ||||
| -rw-r--r-- | bot/resources/tags/str-join.md | 28 | ||||
| -rw-r--r-- | config-default.yml | 15 | ||||
| -rw-r--r-- | tests/README.md | 2 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 550 | ||||
| -rw-r--r-- | tests/bot/exts/info/test_information.py | 2 | ||||
| -rw-r--r-- | tests/helpers.py | 2 | 
20 files changed, 1183 insertions, 54 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 6d14bbb3a..7b2a38079 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -175,13 +175,14 @@ class YAMLGetter(type):              if cls.subsection is not None:                  return _CONFIG_YAML[cls.section][cls.subsection][name]              return _CONFIG_YAML[cls.section][name] -        except KeyError: +        except KeyError as e:              dotted_path = '.'.join(                  (cls.section, cls.subsection, name)                  if cls.subsection is not None else (cls.section, name)              ) -            log.critical(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.") -            raise +            # Only an INFO log since this can be caught through `hasattr` or `getattr`. +            log.info(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.") +            raise AttributeError(repr(name)) from e      def __getitem__(cls, name):          return cls.__getattr__(name) @@ -199,6 +200,7 @@ class Bot(metaclass=YAMLGetter):      prefix: str      sentry_dsn: Optional[str]      token: str +    trace_loggers: Optional[str]  class Redis(metaclass=YAMLGetter): @@ -279,6 +281,8 @@ class Emojis(metaclass=YAMLGetter):      badge_partner: str      badge_staff: str      badge_verified_bot_developer: str +    verified_bot: str +    bot: str      defcon_shutdown: str  # noqa: E704      defcon_unshutdown: str  # noqa: E704 @@ -491,6 +495,7 @@ class Roles(metaclass=YAMLGetter):      domain_leads: int      helpers: int      moderators: int +    mod_team: int      owners: int      project_leads: int diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index da0e94a7e..d8de177f5 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,4 +1,3 @@ -import contextlib  import difflib  import logging  import typing as t @@ -60,7 +59,7 @@ class ErrorHandler(Cog):              log.trace(f"Command {command} had its error already handled locally; ignoring.")              return -        if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): +        if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False):              if await self.try_silence(ctx):                  return              # Try to look for a tag with the command's name @@ -162,9 +161,8 @@ class ErrorHandler(Cog):                  f"and the fallback tag failed validation in TagNameConverter."              )          else: -            with contextlib.suppress(ResponseCodeError): -                if await ctx.invoke(tags_get_command, tag_name=tag_name): -                    return +            if await ctx.invoke(tags_get_command, tag_name=tag_name): +                return          if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):              await self.send_command_suggestion(ctx, ctx.invoked_with) @@ -214,32 +212,30 @@ class ErrorHandler(Cog):          * ArgumentParsingError: send an error message          * Other: send an error message and the help command          """ -        prepared_help_command = self.get_help_command(ctx) -          if isinstance(e, errors.MissingRequiredArgument):              embed = self._get_error_embed("Missing required argument", e.param.name)              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.missing_required_argument")          elif isinstance(e, errors.TooManyArguments):              embed = self._get_error_embed("Too many arguments", str(e))              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.too_many_arguments")          elif isinstance(e, errors.BadArgument):              embed = self._get_error_embed("Bad argument", str(e))              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.bad_argument")          elif isinstance(e, errors.BadUnionArgument):              embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}")              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.bad_union_argument")          elif isinstance(e, errors.ArgumentParsingError):              embed = self._get_error_embed("Argument parsing error", str(e))              await ctx.send(embed=embed) -            prepared_help_command.close() +            self.get_help_command(ctx).close()              self.bot.stats.incr("errors.argument_parsing_error")          else:              embed = self._get_error_embed( @@ -247,7 +243,7 @@ class ErrorHandler(Cog):                  "Something about your input seems off. Check the arguments and try again."              )              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.other_user_input_error")      @staticmethod diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py new file mode 100644 index 000000000..06885410b --- /dev/null +++ b/bot/exts/info/code_snippets.py @@ -0,0 +1,265 @@ +import logging +import re +import textwrap +from typing import Any +from urllib.parse import quote_plus + +from aiohttp import ClientResponseError +from discord import Message +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +GITHUB_RE = re.compile( +    r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/' +    r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_GIST_RE = re.compile( +    r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*' +    r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?' +    r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'} + +GITLAB_RE = re.compile( +    r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)' +    r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)' +) + +BITBUCKET_RE = re.compile( +    r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)' +    r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)' +) + + +class CodeSnippets(Cog): +    """ +    Cog that parses and sends code snippets to Discord. + +    Matches each message against a regex and prints the contents of all matched snippets. +    """ + +    async def _fetch_response(self, url: str, response_format: str, **kwargs) -> Any: +        """Makes http requests using aiohttp.""" +        async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response: +            if response_format == 'text': +                return await response.text() +            elif response_format == 'json': +                return await response.json() + +    def _find_ref(self, path: str, refs: tuple) -> tuple: +        """Loops through all branches and tags to find the required ref.""" +        # Base case: there is no slash in the branch name +        ref, file_path = path.split('/', 1) +        # In case there are slashes in the branch name, we loop through all branches and tags +        for possible_ref in refs: +            if path.startswith(possible_ref['name'] + '/'): +                ref = possible_ref['name'] +                file_path = path[len(ref) + 1:] +                break +        return ref, file_path + +    async def _fetch_github_snippet( +        self, +        repo: str, +        path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitHub repo.""" +        # Search the GitHub API for the specified branch +        branches = await self._fetch_response( +            f'https://api.github.com/repos/{repo}/branches', +            'json', +            headers=GITHUB_HEADERS +        ) +        tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS) +        refs = branches + tags +        ref, file_path = self._find_ref(path, refs) + +        file_contents = await self._fetch_response( +            f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', +            'text', +            headers=GITHUB_HEADERS, +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    async def _fetch_github_gist_snippet( +        self, +        gist_id: str, +        revision: str, +        file_path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitHub gist.""" +        gist_json = await self._fetch_response( +            f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', +            'json', +            headers=GITHUB_HEADERS, +        ) + +        # Check each file in the gist for the specified file +        for gist_file in gist_json['files']: +            if file_path == gist_file.lower().replace('.', '-'): +                file_contents = await self._fetch_response( +                    gist_json['files'][gist_file]['raw_url'], +                    'text', +                ) +                return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line) +        return '' + +    async def _fetch_gitlab_snippet( +        self, +        repo: str, +        path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitLab repo.""" +        enc_repo = quote_plus(repo) + +        # Searches the GitLab API for the specified branch +        branches = await self._fetch_response( +            f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches', +            'json' +        ) +        tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json') +        refs = branches + tags +        ref, file_path = self._find_ref(path, refs) +        enc_ref = quote_plus(ref) +        enc_file_path = quote_plus(file_path) + +        file_contents = await self._fetch_response( +            f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', +            'text', +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    async def _fetch_bitbucket_snippet( +        self, +        repo: str, +        ref: str, +        file_path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a BitBucket repo.""" +        file_contents = await self._fetch_response( +            f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', +            'text', +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str: +        """ +        Given the entire file contents and target lines, creates a code block. + +        First, we split the file contents into a list of lines and then keep and join only the required +        ones together. + +        We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent +        markdown injection. + +        Finally, we surround the code with ``` characters. +        """ +        # Parse start_line and end_line into integers +        if end_line is None: +            start_line = end_line = int(start_line) +        else: +            start_line = int(start_line) +            end_line = int(end_line) + +        split_file_contents = file_contents.splitlines() + +        # Make sure that the specified lines are in range +        if start_line > end_line: +            start_line, end_line = end_line, start_line +        if start_line > len(split_file_contents) or end_line < 1: +            return '' +        start_line = max(1, start_line) +        end_line = min(len(split_file_contents), end_line) + +        # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection +        required = '\n'.join(split_file_contents[start_line - 1:end_line]) +        required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') + +        # Extracts the code language and checks whether it's a "valid" language +        language = file_path.split('/')[-1].split('.')[-1] +        trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') +        is_valid_language = trimmed_language.isalnum() +        if not is_valid_language: +            language = '' + +        # Adds a label showing the file path to the snippet +        if start_line == end_line: +            ret = f'`{file_path}` line {start_line}\n' +        else: +            ret = f'`{file_path}` lines {start_line} to {end_line}\n' + +        if len(required) != 0: +            return f'{ret}```{language}\n{required}```' +        # Returns an empty codeblock if the snippet is empty +        return f'{ret}``` ```' + +    def __init__(self, bot: Bot): +        """Initializes the cog's bot.""" +        self.bot = bot + +        self.pattern_handlers = [ +            (GITHUB_RE, self._fetch_github_snippet), +            (GITHUB_GIST_RE, self._fetch_github_gist_snippet), +            (GITLAB_RE, self._fetch_gitlab_snippet), +            (BITBUCKET_RE, self._fetch_bitbucket_snippet) +        ] + +    @Cog.listener() +    async def on_message(self, message: Message) -> None: +        """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" +        if not message.author.bot: +            all_snippets = [] + +            for pattern, handler in self.pattern_handlers: +                for match in pattern.finditer(message.content): +                    try: +                        snippet = await handler(**match.groupdict()) +                        all_snippets.append((match.start(), snippet)) +                    except ClientResponseError as error: +                        error_message = error.message  # noqa: B306 +                        log.log( +                            logging.DEBUG if error.status == 404 else logging.ERROR, +                            f'Failed to fetch code snippet from {match[0]!r}: {error.status} ' +                            f'{error_message} for GET {error.request_info.real_url.human_repr()}' +                        ) + +            # Sorts the list of snippets by their match index and joins them into a single message +            message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets))) + +            if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: +                await message.edit(suppress=True) +                if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands: +                    # Redirects to #bot-commands if the snippet contents are too long +                    await self.bot.wait_until_guild_available() +                    await message.channel.send(('The snippet you tried to send was too long. Please ' +                                                f'see <#{Channels.bot_commands}> for the full snippet.')) +                    bot_commands_channel = self.bot.get_channel(Channels.bot_commands) +                    await wait_for_deletion( +                        await bot_commands_channel.send(message_to_send), +                        (message.author.id,) +                    ) +                else: +                    await wait_for_deletion( +                        await message.channel.send(message_to_send), +                        (message.author.id,) +                    ) + + +def setup(bot: Bot) -> None: +    """Load the CodeSnippets cog.""" +    bot.add_cog(CodeSnippets(bot)) diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 5e2c4b417..834fee1b4 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -230,6 +230,11 @@ class Information(Cog):          if on_server and user.nick:              name = f"{user.nick} ({name})" +        if user.public_flags.verified_bot: +            name += f" {constants.Emojis.verified_bot}" +        elif user.bot: +            name += f" {constants.Emojis.bot}" +          badges = []          for badge, is_set in user.public_flags: diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py index 6790be762..e1f0c5f9f 100644 --- a/bot/exts/info/reddit.py +++ b/bot/exts/info/reddit.py @@ -4,6 +4,7 @@ import random  import textwrap  from collections import namedtuple  from datetime import datetime, timedelta +from html import unescape  from typing import List  from aiohttp import BasicAuth, ClientError @@ -179,8 +180,7 @@ class Reddit(Cog):          for post in posts:              data = post["data"] -            text = data["selftext"] -            if text: +            if text := unescape(data["selftext"]):                  text = textwrap.shorten(text, width=128, placeholder="...")                  text += "\n"  # Add newline to separate embed info @@ -188,7 +188,7 @@ class Reddit(Cog):              comments = data["num_comments"]              author = data["author"] -            title = textwrap.shorten(data["title"], width=64, placeholder="...") +            title = textwrap.shorten(unescape(data["title"]), width=64, placeholder="...")              # Normal brackets interfere with Markdown.              title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌")              link = self.URL + data["permalink"] diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index d89e80acc..38d1ffc0e 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -54,8 +54,12 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent infractions      @command() -    async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: +    async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:          """Warn a user for the given reason.""" +        if not isinstance(user, Member): +            await ctx.send(":x: The user doesn't appear to be on the server.") +            return +          infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False)          if infraction is None:              return @@ -63,8 +67,12 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user)      @command() -    async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: +    async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:          """Kick a user for the given reason.""" +        if not isinstance(user, Member): +            await ctx.send(":x: The user doesn't appear to be on the server.") +            return +          await self.apply_kick(ctx, user, reason)      @command() @@ -100,7 +108,7 @@ class Infractions(InfractionScheduler, commands.Cog):      @command(aliases=["mute"])      async def tempmute(          self, ctx: Context, -        user: Member, +        user: FetchedMember,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -122,6 +130,10 @@ class Infractions(InfractionScheduler, commands.Cog):          If no duration is given, a one hour duration is used by default.          """ +        if not isinstance(user, Member): +            await ctx.send(":x: The user doesn't appear to be on the server.") +            return +          if duration is None:              duration = await Duration().convert(ctx, "1h")          await self.apply_mute(ctx, user, reason, expires_at=duration) diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 704dddf9c..07e79b9fe 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -11,7 +11,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry +from bot.converters import Duration, Expiry  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.utils.messages import format_user @@ -19,6 +19,7 @@ from bot.utils.time import format_infraction  log = logging.getLogger(__name__)  NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" +SUPERSTARIFY_DEFAULT_DURATION = "1h"  with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file:      STAR_NAMES = json.load(stars_file) @@ -109,7 +110,7 @@ class Superstarify(InfractionScheduler, Cog):          self,          ctx: Context,          member: Member, -        duration: Expiry, +        duration: t.Optional[Expiry],          *,          reason: str = '',      ) -> None: @@ -134,6 +135,9 @@ class Superstarify(InfractionScheduler, Cog):          if await _utils.get_active_infraction(ctx, member, "superstar"):              return +        # Set to default duration if none was provided. +        duration = duration or await Duration().convert(ctx, SUPERSTARIFY_DEFAULT_DURATION) +          # Post the infraction to the API          old_nick = member.display_name          infraction_reason = f'Old nickname: {old_nick}. {reason}' diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index 2dae9d268..e92f76c9a 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -14,7 +14,7 @@ from discord.abc import GuildChannel  from discord.ext.commands import Cog, Context  from bot.bot import Bot -from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs +from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs  from bot.utils.messages import format_user  from bot.utils.time import humanize_delta @@ -115,9 +115,9 @@ class ModLog(Cog, name="ModLog"):          if ping_everyone:              if content: -                content = f"@everyone\n{content}" +                content = f"<@&{Roles.moderators}>\n{content}"              else: -                content = "@everyone" +                content = f"<@&{Roles.moderators}>"          # Truncate content to 2000 characters and append an ellipsis.          if content and len(content) > 2000: @@ -127,8 +127,7 @@ class ModLog(Cog, name="ModLog"):          log_message = await channel.send(              content=content,              embed=embed, -            files=files, -            allowed_mentions=discord.AllowedMentions(everyone=True) +            files=files          )          if additional_embeds: diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py new file mode 100644 index 000000000..1ad5005de --- /dev/null +++ b/bot/exts/moderation/modpings.py @@ -0,0 +1,138 @@ +import datetime +import logging + +from async_rediscache import RedisCache +from dateutil.parser import isoparse +from discord import Embed, Member +from discord.ext.commands import Cog, Context, group, has_any_role + +from bot.bot import Bot +from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles +from bot.converters import Expiry +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + + +class ModPings(Cog): +    """Commands for a moderator to turn moderator pings on and off.""" + +    # RedisCache[discord.Member.id, 'Naïve ISO 8601 string'] +    # The cache's keys are mods who have pings off. +    # The cache's values are the times when the role should be re-applied to them, stored in ISO format. +    pings_off_mods = RedisCache() + +    def __init__(self, bot: Bot): +        self.bot = bot +        self._role_scheduler = Scheduler(self.__class__.__name__) + +        self.guild = None +        self.moderators_role = None + +        self.reschedule_task = self.bot.loop.create_task(self.reschedule_roles(), name="mod-pings-reschedule") + +    async def reschedule_roles(self) -> None: +        """Reschedule moderators role re-apply times.""" +        await self.bot.wait_until_guild_available() +        self.guild = self.bot.get_guild(Guild.id) +        self.moderators_role = self.guild.get_role(Roles.moderators) + +        mod_team = self.guild.get_role(Roles.mod_team) +        pings_on = self.moderators_role.members +        pings_off = await self.pings_off_mods.to_dict() + +        log.trace("Applying the moderators role to the mod team where necessary.") +        for mod in mod_team.members: +            if mod in pings_on:  # Make sure that on-duty mods aren't in the cache. +                if mod in pings_off: +                    await self.pings_off_mods.delete(mod.id) +                continue + +            # Keep the role off only for those in the cache. +            if mod.id not in pings_off: +                await self.reapply_role(mod) +            else: +                expiry = isoparse(pings_off[mod.id]).replace(tzinfo=None) +                self._role_scheduler.schedule_at(expiry, mod.id, self.reapply_role(mod)) + +    async def reapply_role(self, mod: Member) -> None: +        """Reapply the moderator's role to the given moderator.""" +        log.trace(f"Re-applying role to mod with ID {mod.id}.") +        await mod.add_roles(self.moderators_role, reason="Pings off period expired.") + +    @group(name='modpings', aliases=('modping',), invoke_without_command=True) +    @has_any_role(*MODERATION_ROLES) +    async def modpings_group(self, ctx: Context) -> None: +        """Allow the removal and re-addition of the pingable moderators role.""" +        await ctx.send_help(ctx.command) + +    @modpings_group.command(name='off') +    @has_any_role(*MODERATION_ROLES) +    async def off_command(self, ctx: Context, duration: Expiry) -> None: +        """ +        Temporarily removes the pingable moderators role for a set amount of time. + +        A unit of time should be appended to the duration. +        Units (∗case-sensitive): +        \u2003`y` - years +        \u2003`m` - months∗ +        \u2003`w` - weeks +        \u2003`d` - days +        \u2003`h` - hours +        \u2003`M` - minutes∗ +        \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration. + +        The duration cannot be longer than 30 days. +        """ +        duration: datetime.datetime +        delta = duration - datetime.datetime.utcnow() +        if delta > datetime.timedelta(days=30): +            await ctx.send(":x: Cannot remove the role for longer than 30 days.") +            return + +        mod = ctx.author + +        until_date = duration.replace(microsecond=0).isoformat()  # Looks noisy with microseconds. +        await mod.remove_roles(self.moderators_role, reason=f"Turned pings off until {until_date}.") + +        await self.pings_off_mods.set(mod.id, duration.isoformat()) + +        # Allow rescheduling the task without cancelling it separately via the `on` command. +        if mod.id in self._role_scheduler: +            self._role_scheduler.cancel(mod.id) +        self._role_scheduler.schedule_at(duration, mod.id, self.reapply_role(mod)) + +        embed = Embed(timestamp=duration, colour=Colours.bright_green) +        embed.set_footer(text="Moderators role has been removed until", icon_url=Icons.green_checkmark) +        await ctx.send(embed=embed) + +    @modpings_group.command(name='on') +    @has_any_role(*MODERATION_ROLES) +    async def on_command(self, ctx: Context) -> None: +        """Re-apply the pingable moderators role.""" +        mod = ctx.author +        if mod in self.moderators_role.members: +            await ctx.send(":question: You already have the role.") +            return + +        await mod.add_roles(self.moderators_role, reason="Pings off period canceled.") + +        await self.pings_off_mods.delete(mod.id) + +        # We assume the task exists. Lack of it may indicate a bug. +        self._role_scheduler.cancel(mod.id) + +        await ctx.send(f"{Emojis.check_mark} Moderators role has been re-applied.") + +    def cog_unload(self) -> None: +        """Cancel role tasks when the cog unloads.""" +        log.trace("Cog unload: canceling role tasks.") +        self.reschedule_task.cancel() +        self._role_scheduler.cancel_all() + + +def setup(bot: Bot) -> None: +    """Load the ModPings cog.""" +    bot.add_cog(ModPings(bot)) diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 12e195172..fd856a7f4 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -1,5 +1,6 @@  import logging  from datetime import timedelta, timezone +from operator import itemgetter  import arrow  import discord @@ -8,8 +9,9 @@ from async_rediscache import RedisCache  from discord.ext import commands  from bot.bot import Bot -from bot.constants import Colours, Emojis, Guild, Roles, STAFF_ROLES, VideoPermission +from bot.constants import Colours, Emojis, Guild, MODERATION_ROLES, Roles, STAFF_ROLES, VideoPermission  from bot.converters import Expiry +from bot.pagination import LinePaginator  from bot.utils.scheduling import Scheduler  from bot.utils.time import format_infraction_with_duration @@ -68,8 +70,30 @@ class Stream(commands.Cog):                  self._revoke_streaming_permission(member)              ) +    async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None: +        """Suspend a member's stream.""" +        await self.bot.wait_until_guild_available() +        voice_state = member.voice + +        if not voice_state: +            return + +        # If the user is streaming. +        if voice_state.self_stream: +            # End user's stream by moving them to AFK voice channel and back. +            original_vc = voice_state.channel +            await member.move_to(ctx.guild.afk_channel) +            await member.move_to(original_vc) + +            # Notify. +            await ctx.send(f"{member.mention}'s stream has been suspended!") +            log.debug(f"Successfully suspended stream from {member} ({member.id}).") +            return + +        log.debug(f"No stream found to suspend from {member} ({member.id}).") +      @commands.command(aliases=("streaming",)) -    @commands.has_any_role(*STAFF_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None:          """          Temporarily grant streaming permissions to a member for a given duration. @@ -126,7 +150,7 @@ class Stream(commands.Cog):          log.debug(f"Successfully gave {member} ({member.id}) permission to stream until {revoke_time}.")      @commands.command(aliases=("pstream",)) -    @commands.has_any_role(*STAFF_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def permanentstream(self, ctx: commands.Context, member: discord.Member) -> None:          """Permanently grants the given member the permission to stream."""          log.trace(f"Attempting to give permanent streaming permission to {member} ({member.id}).") @@ -153,7 +177,7 @@ class Stream(commands.Cog):          log.debug(f"Successfully gave {member} ({member.id}) permanent streaming permission.")      @commands.command(aliases=("unstream", "rstream")) -    @commands.has_any_role(*STAFF_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def revokestream(self, ctx: commands.Context, member: discord.Member) -> None:          """Revoke the permission to stream from the given member."""          log.trace(f"Attempting to remove streaming permission from {member} ({member.id}).") @@ -168,10 +192,52 @@ class Stream(commands.Cog):              await ctx.send(f"{Emojis.check_mark} Revoked the permission to stream from {member.mention}.")              log.debug(f"Successfully revoked streaming permission from {member} ({member.id}).") -            return -        await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!") -        log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!") +        else: +            await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!") +            log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!") + +        await self._suspend_stream(ctx, member) + +    @commands.command(aliases=('lstream',)) +    @commands.has_any_role(*MODERATION_ROLES) +    async def liststream(self, ctx: commands.Context) -> None: +        """Lists all non-staff users who have permission to stream.""" +        non_staff_members_with_stream = [ +            member +            for member in ctx.guild.get_role(Roles.video).members +            if not any(role.id in STAFF_ROLES for role in member.roles) +        ] + +        # List of tuples (UtcPosixTimestamp, str) +        # So that the list can be sorted on the UtcPosixTimestamp before the message is passed to the paginator. +        streamer_info = [] +        for member in non_staff_members_with_stream: +            if revoke_time := await self.task_cache.get(member.id): +                # Member only has temporary streaming perms +                revoke_delta = Arrow.utcfromtimestamp(revoke_time).humanize() +                message = f"{member.mention} will have stream permissions revoked {revoke_delta}." +            else: +                message = f"{member.mention} has permanent streaming permissions." + +            # If revoke_time is None use max timestamp to force sort to put them at the end +            streamer_info.append( +                (revoke_time or Arrow.max.timestamp(), message) +            ) + +        if streamer_info: +            # Sort based on duration left of streaming perms +            streamer_info.sort(key=itemgetter(0)) + +            # Only output the message in the pagination +            lines = [line[1] for line in streamer_info] +            embed = discord.Embed( +                title=f"Members with streaming permission (`{len(lines)}` total)", +                colour=Colours.soft_green +            ) +            await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) +        else: +            await ctx.send("No members with stream permissions found.")  def setup(bot: Bot) -> None: diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 3113a1149..6c21920a1 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -90,15 +90,18 @@ class Reminders(Cog):          delivery_dt: t.Optional[datetime],      ) -> None:          """Send an embed confirming the reminder change was made successfully.""" -        embed = discord.Embed() -        embed.colour = discord.Colour.green() -        embed.title = random.choice(POSITIVE_REPLIES) -        embed.description = on_success +        embed = discord.Embed( +            description=on_success, +            colour=discord.Colour.green(), +            title=random.choice(POSITIVE_REPLIES) +        )          footer_str = f"ID: {reminder_id}" +          if delivery_dt:              # Reminder deletion will have a `None` `delivery_dt` -            footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" +            footer_str += ', Due' +            embed.timestamp = delivery_dt          embed.set_footer(text=footer_str) diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index 8d9d27c64..4c39a7c2a 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -109,7 +109,7 @@ class Utils(Cog):          # handle if it's an index int          if isinstance(search_value, int):              upper_bound = len(zen_lines) - 1 -            lower_bound = -1 * upper_bound +            lower_bound = -1 * len(zen_lines)              if not (lower_bound <= search_value <= upper_bound):                  raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") diff --git a/bot/log.py b/bot/log.py index e92233a33..4e20c005e 100644 --- a/bot/log.py +++ b/bot/log.py @@ -20,7 +20,6 @@ def setup() -> None:      logging.addLevelName(TRACE_LEVEL, "TRACE")      Logger.trace = _monkeypatch_trace -    log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO      format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"      log_format = logging.Formatter(format_string) @@ -30,7 +29,6 @@ def setup() -> None:      file_handler.setFormatter(log_format)      root_log = logging.getLogger() -    root_log.setLevel(log_level)      root_log.addHandler(file_handler)      if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: @@ -44,11 +42,9 @@ def setup() -> None:      if "COLOREDLOGS_LOG_FORMAT" not in os.environ:          coloredlogs.DEFAULT_LOG_FORMAT = format_string -    if "COLOREDLOGS_LOG_LEVEL" not in os.environ: -        coloredlogs.DEFAULT_LOG_LEVEL = log_level - -    coloredlogs.install(logger=root_log, stream=sys.stdout) +    coloredlogs.install(level=logging.TRACE, logger=root_log, stream=sys.stdout) +    root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO)      logging.getLogger("discord").setLevel(logging.WARNING)      logging.getLogger("websockets").setLevel(logging.WARNING)      logging.getLogger("chardet").setLevel(logging.WARNING) @@ -57,6 +53,8 @@ def setup() -> None:      # Set back to the default of INFO even if asyncio's debug mode is enabled.      logging.getLogger("asyncio").setLevel(logging.INFO) +    _set_trace_loggers() +  def setup_sentry() -> None:      """Set up the Sentry logging integrations.""" @@ -86,3 +84,30 @@ def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:      """      if self.isEnabledFor(TRACE_LEVEL):          self._log(TRACE_LEVEL, msg, args, **kwargs) + + +def _set_trace_loggers() -> None: +    """ +    Set loggers to the trace level according to the value from the BOT_TRACE_LOGGERS env var. + +    When the env var is a list of logger names delimited by a comma, +    each of the listed loggers will be set to the trace level. + +    If this list is prefixed with a "!", all of the loggers except the listed ones will be set to the trace level. + +    Otherwise if the env var begins with a "*", +    the root logger is set to the trace level and other contents are ignored. +    """ +    level_filter = constants.Bot.trace_loggers +    if level_filter: +        if level_filter.startswith("*"): +            logging.getLogger().setLevel(logging.TRACE) + +        elif level_filter.startswith("!"): +            logging.getLogger().setLevel(logging.TRACE) +            for logger_name in level_filter.strip("!,").split(","): +                logging.getLogger(logger_name).setLevel(logging.DEBUG) + +        else: +            for logger_name in level_filter.strip(",").split(","): +                logging.getLogger(logger_name).setLevel(logging.TRACE) diff --git a/bot/resources/tags/identity.md b/bot/resources/tags/identity.md new file mode 100644 index 000000000..fb2010759 --- /dev/null +++ b/bot/resources/tags/identity.md @@ -0,0 +1,24 @@ +**Identity vs. Equality** + +Should I be using `is` or `==`? + +To check if two objects are equal, use the equality operator (`==`). +```py +x = 5 +if x == 5: +    print("x equals 5") +if x == 3: +    print("x equals 3") +# Prints 'x equals 5' +``` +To check if two objects are actually the same thing in memory, use the identity comparison operator (`is`). +```py +list_1 = [1, 2, 3] +list_2 = [1, 2, 3] +if list_1 is [1, 2, 3]: +    print("list_1 is list_2") +reference_to_list_1 = list_1 +if list_1 is reference_to_list_1: +    print("list_1 is reference_to_list_1") +# Prints 'list_1 is reference_to_list_1' +``` diff --git a/bot/resources/tags/str-join.md b/bot/resources/tags/str-join.md new file mode 100644 index 000000000..c835f9313 --- /dev/null +++ b/bot/resources/tags/str-join.md @@ -0,0 +1,28 @@ +**Joining Iterables** + +If you want to display a list (or some other iterable), you can write: +```py +colors = ['red', 'green', 'blue', 'yellow'] +output = "" +separator = ", " +for color in colors: +    output += color + separator +print(output) +# Prints 'red, green, blue, yellow, ' +``` +However, the separator is still added to the last element, and it is relatively slow. + +A better solution is to use `str.join`. +```py +colors = ['red', 'green', 'blue', 'yellow'] +separator = ", " +print(separator.join(colors)) +# Prints 'red, green, blue, yellow' +``` +An important thing to note is that you can only `str.join` strings. For a list of ints, +you must convert each element to a string before joining. +```py +integers = [1, 3, 6, 10, 15] +print(", ".join(str(e) for e in integers)) +# Prints '1, 3, 6, 10, 15' +``` diff --git a/config-default.yml b/config-default.yml index 8c6e18470..46475f845 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,7 +1,8 @@  bot: -    prefix:      "!" -    sentry_dsn:  !ENV "BOT_SENTRY_DSN" -    token:       !ENV "BOT_TOKEN" +    prefix:         "!" +    sentry_dsn:     !ENV "BOT_SENTRY_DSN" +    token:          !ENV "BOT_TOKEN" +    trace_loggers:  !ENV "BOT_TRACE_LOGGERS"      clean:          # Maximum number of messages to traverse for clean commands @@ -46,6 +47,8 @@ style:          badge_partner: "<:partner:748666453242413136>"          badge_staff: "<:discord_staff:743882896498098226>"          badge_verified_bot_developer: "<:verified_bot_dev:743882897299210310>" +        bot: "<:bot:812712599464443914>" +        verified_bot: "<:verified_bot:811645219220750347>"          defcon_shutdown:    "<:defcondisabled:470326273952972810>"          defcon_unshutdown:  "<:defconenabled:470326274213150730>" @@ -260,7 +263,8 @@ guild:          devops:                             409416496733880320          domain_leads:                       807415650778742785          helpers:            &HELPERS_ROLE   267630620367257601 -        moderators:         &MODS_ROLE      267629731250176001 +        moderators:         &MODS_ROLE      831776746206265384 +        mod_team:           &MOD_TEAM_ROLE  267629731250176001          owners:             &OWNERS_ROLE    267627879762755584          project_leads:                      815701647526330398 @@ -273,13 +277,14 @@ guild:      moderation_roles:          - *ADMINS_ROLE +        - *MOD_TEAM_ROLE          - *MODS_ROLE          - *OWNERS_ROLE      staff_roles:          - *ADMINS_ROLE          - *HELPERS_ROLE -        - *MODS_ROLE +        - *MOD_TEAM_ROLE          - *OWNERS_ROLE      webhooks: diff --git a/tests/README.md b/tests/README.md index 4f62edd68..092324123 100644 --- a/tests/README.md +++ b/tests/README.md @@ -114,7 +114,7 @@ class BotCogTests(unittest.TestCase):  ### Mocking coroutines -By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8. +By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected.  ### Special mocks for some `discord.py` types diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py new file mode 100644 index 000000000..bd4fb5942 --- /dev/null +++ b/tests/bot/exts/backend/test_error_handler.py @@ -0,0 +1,550 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +from discord.ext.commands import errors + +from bot.api import ResponseCodeError +from bot.errors import InvalidInfractedUser, LockedResourceError +from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.info.tags import Tags +from bot.exts.moderation.silence import Silence +from bot.utils.checks import InWhitelistCheckFailure +from tests.helpers import MockBot, MockContext, MockGuild, MockRole + + +class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): +    """Tests for error handler functionality.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext(bot=self.bot) + +    async def test_error_handler_already_handled(self): +        """Should not do anything when error is already handled by local error handler.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        error = errors.CommandError() +        error.handled = "foo" +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.ctx.send.assert_not_awaited() + +    async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): +        """Should try first (un)silence channel, when fail, try to get tag.""" +        error = errors.CommandNotFound() +        test_cases = ( +            { +                "try_silence_return": True, +                "called_try_get_tag": False +            }, +            { +                "try_silence_return": False, +                "called_try_get_tag": False +            }, +            { +                "try_silence_return": False, +                "called_try_get_tag": True +            } +        ) +        cog = ErrorHandler(self.bot) +        cog.try_silence = AsyncMock() +        cog.try_get_tag = AsyncMock() + +        for case in test_cases: +            with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): +                self.ctx.reset_mock() +                cog.try_silence.reset_mock(return_value=True) +                cog.try_get_tag.reset_mock() + +                cog.try_silence.return_value = case["try_silence_return"] +                self.ctx.channel.id = 1234 + +                self.assertIsNone(await cog.on_command_error(self.ctx, error)) + +                if case["try_silence_return"]: +                    cog.try_get_tag.assert_not_awaited() +                    cog.try_silence.assert_awaited_once() +                else: +                    cog.try_silence.assert_awaited_once() +                    cog.try_get_tag.assert_awaited_once() + +                self.ctx.send.assert_not_awaited() + +    async def test_error_handler_command_not_found_error_invoked_by_handler(self): +        """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" +        ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) + +        cog = ErrorHandler(self.bot) +        cog.try_silence = AsyncMock() +        cog.try_get_tag = AsyncMock() + +        error = errors.CommandNotFound() + +        self.assertIsNone(await cog.on_command_error(ctx, error)) + +        cog.try_silence.assert_not_awaited() +        cog.try_get_tag.assert_not_awaited() +        self.ctx.send.assert_not_awaited() + +    async def test_error_handler_user_input_error(self): +        """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        cog.handle_user_input_error = AsyncMock() +        error = errors.UserInputError() +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + +    async def test_error_handler_check_failure(self): +        """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        cog.handle_check_failure = AsyncMock() +        error = errors.CheckFailure() +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + +    async def test_error_handler_command_on_cooldown(self): +        """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        error = errors.CommandOnCooldown(10, 9) +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.ctx.send.assert_awaited_once_with(error) + +    async def test_error_handler_command_invoke_error(self): +        """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" +        cog = ErrorHandler(self.bot) +        cog.handle_api_error = AsyncMock() +        cog.handle_unexpected_error = AsyncMock() +        test_cases = ( +            { +                "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), +                "expect_mock_call": cog.handle_api_error +            }, +            { +                "args": (self.ctx, errors.CommandInvokeError(TypeError)), +                "expect_mock_call": cog.handle_unexpected_error +            }, +            { +                "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), +                "expect_mock_call": "send" +            }, +            { +                "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))), +                "expect_mock_call": "send" +            } +        ) + +        for case in test_cases: +            with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): +                self.ctx.send.reset_mock() +                self.assertIsNone(await cog.on_command_error(*case["args"])) +                if case["expect_mock_call"] == "send": +                    self.ctx.send.assert_awaited_once() +                else: +                    case["expect_mock_call"].assert_awaited_once_with( +                        self.ctx, case["args"][1].original +                    ) + +    async def test_error_handler_conversion_error(self): +        """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" +        cog = ErrorHandler(self.bot) +        cog.handle_api_error = AsyncMock() +        cog.handle_unexpected_error = AsyncMock() +        cases = ( +            { +                "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), +                "mock_function_to_call": cog.handle_api_error +            }, +            { +                "error": errors.ConversionError(AsyncMock(), TypeError), +                "mock_function_to_call": cog.handle_unexpected_error +            } +        ) + +        for case in cases: +            with self.subTest(**case): +                self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) +                case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) + +    async def test_error_handler_two_other_errors(self): +        """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" +        cog = ErrorHandler(self.bot) +        cog.handle_unexpected_error = AsyncMock() +        errs = ( +            errors.MaxConcurrencyReached(1, MagicMock()), +            errors.ExtensionError(name="foo") +        ) + +        for err in errs: +            with self.subTest(error=err): +                cog.handle_unexpected_error.reset_mock() +                self.assertIsNone(await cog.on_command_error(self.ctx, err)) +                cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + +    @patch("bot.exts.backend.error_handler.log") +    async def test_error_handler_other_errors(self, log_mock): +        """Should `log.debug` other errors.""" +        cog = ErrorHandler(self.bot) +        error = errors.DisabledCommand()  # Use this just as a other error +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        log_mock.debug.assert_called_once() + + +class TrySilenceTests(unittest.IsolatedAsyncioTestCase): +    """Test for helper functions that handle `CommandNotFound` error.""" + +    def setUp(self): +        self.bot = MockBot() +        self.silence = Silence(self.bot) +        self.bot.get_command.return_value = self.silence.silence +        self.ctx = MockContext(bot=self.bot) +        self.cog = ErrorHandler(self.bot) + +    async def test_try_silence_context_invoked_from_error_handler(self): +        """Should set `Context.invoked_from_error_handler` to `True`.""" +        self.ctx.invoked_with = "foo" +        await self.cog.try_silence(self.ctx) +        self.assertTrue(hasattr(self.ctx, "invoked_from_error_handler")) +        self.assertTrue(self.ctx.invoked_from_error_handler) + +    async def test_try_silence_get_command(self): +        """Should call `get_command` with `silence`.""" +        self.ctx.invoked_with = "foo" +        await self.cog.try_silence(self.ctx) +        self.bot.get_command.assert_called_once_with("silence") + +    async def test_try_silence_no_permissions_to_run(self): +        """Should return `False` because missing permissions.""" +        self.ctx.invoked_with = "foo" +        self.bot.get_command.return_value.can_run = AsyncMock(return_value=False) +        self.assertFalse(await self.cog.try_silence(self.ctx)) + +    async def test_try_silence_no_permissions_to_run_command_error(self): +        """Should return `False` because `CommandError` raised (no permissions).""" +        self.ctx.invoked_with = "foo" +        self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError()) +        self.assertFalse(await self.cog.try_silence(self.ctx)) + +    async def test_try_silence_silencing(self): +        """Should run silence command with correct arguments.""" +        self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) +        test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh") + +        for case in test_cases: +            with self.subTest(message=case): +                self.ctx.reset_mock() +                self.ctx.invoked_with = case +                self.assertTrue(await self.cog.try_silence(self.ctx)) +                self.ctx.invoke.assert_awaited_once_with( +                    self.bot.get_command.return_value, +                    duration=min(case.count("h")*2, 15) +                ) + +    async def test_try_silence_unsilence(self): +        """Should call unsilence command.""" +        self.silence.silence.can_run = AsyncMock(return_value=True) +        test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh") + +        for case in test_cases: +            with self.subTest(message=case): +                self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) +                self.ctx.reset_mock() +                self.ctx.invoked_with = case +                self.assertTrue(await self.cog.try_silence(self.ctx)) +                self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence) + +    async def test_try_silence_no_match(self): +        """Should return `False` when message don't match.""" +        self.ctx.invoked_with = "foo" +        self.assertFalse(await self.cog.try_silence(self.ctx)) + + +class TryGetTagTests(unittest.IsolatedAsyncioTestCase): +    """Tests for `try_get_tag` function.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext() +        self.tag = Tags(self.bot) +        self.cog = ErrorHandler(self.bot) +        self.bot.get_command.return_value = self.tag.get_command + +    async def test_try_get_tag_get_command(self): +        """Should call `Bot.get_command` with `tags get` argument.""" +        self.bot.get_command.reset_mock() +        self.ctx.invoked_with = "foo" +        await self.cog.try_get_tag(self.ctx) +        self.bot.get_command.assert_called_once_with("tags get") + +    async def test_try_get_tag_invoked_from_error_handler(self): +        """`self.ctx` should have `invoked_from_error_handler` `True`.""" +        self.ctx.invoked_from_error_handler = False +        self.ctx.invoked_with = "foo" +        await self.cog.try_get_tag(self.ctx) +        self.assertTrue(self.ctx.invoked_from_error_handler) + +    async def test_try_get_tag_no_permissions(self): +        """Test how to handle checks failing.""" +        self.tag.get_command.can_run = AsyncMock(return_value=False) +        self.ctx.invoked_with = "foo" +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + +    async def test_try_get_tag_command_error(self): +        """Should call `on_command_error` when `CommandError` raised.""" +        err = errors.CommandError() +        self.tag.get_command.can_run = AsyncMock(side_effect=err) +        self.cog.on_command_error = AsyncMock() +        self.ctx.invoked_with = "foo" +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) + +    @patch("bot.exts.backend.error_handler.TagNameConverter") +    async def test_try_get_tag_convert_success(self, tag_converter): +        """Converting tag should successful.""" +        self.ctx.invoked_with = "foo" +        tag_converter.convert = AsyncMock(return_value="foo") +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        tag_converter.convert.assert_awaited_once_with(self.ctx, "foo") +        self.ctx.invoke.assert_awaited_once() + +    @patch("bot.exts.backend.error_handler.TagNameConverter") +    async def test_try_get_tag_convert_fail(self, tag_converter): +        """Converting tag should raise `BadArgument`.""" +        self.ctx.reset_mock() +        self.ctx.invoked_with = "bar" +        tag_converter.convert = AsyncMock(side_effect=errors.BadArgument()) +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        self.ctx.invoke.assert_not_awaited() + +    async def test_try_get_tag_ctx_invoke(self): +        """Should call `ctx.invoke` with proper args/kwargs.""" +        self.ctx.reset_mock() +        self.ctx.invoked_with = "foo" +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo") + +    async def test_dont_call_suggestion_tag_sent(self): +        """Should never call command suggestion if tag is already sent.""" +        self.ctx.invoked_with = "foo" +        self.ctx.invoke = AsyncMock(return_value=True) +        self.cog.send_command_suggestion = AsyncMock() + +        await self.cog.try_get_tag(self.ctx) +        self.cog.send_command_suggestion.assert_not_awaited() + +    @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) +    async def test_dont_call_suggestion_if_user_mod(self): +        """Should not call command suggestion if user is a mod.""" +        self.ctx.invoked_with = "foo" +        self.ctx.invoke = AsyncMock(return_value=False) +        self.ctx.author.roles = [MockRole(id=1234)] +        self.cog.send_command_suggestion = AsyncMock() + +        await self.cog.try_get_tag(self.ctx) +        self.cog.send_command_suggestion.assert_not_awaited() + +    async def test_call_suggestion(self): +        """Should call command suggestion if user is not a mod.""" +        self.ctx.invoked_with = "foo" +        self.ctx.invoke = AsyncMock(return_value=False) +        self.cog.send_command_suggestion = AsyncMock() + +        await self.cog.try_get_tag(self.ctx) +        self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") + + +class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): +    """Individual error categories handler tests.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext(bot=self.bot) +        self.cog = ErrorHandler(self.bot) + +    async def test_handle_input_error_handler_errors(self): +        """Should handle each error probably.""" +        test_cases = ( +            { +                "error": errors.MissingRequiredArgument(MagicMock()), +                "call_prepared": True +            }, +            { +                "error": errors.TooManyArguments(), +                "call_prepared": True +            }, +            { +                "error": errors.BadArgument(), +                "call_prepared": True +            }, +            { +                "error": errors.BadUnionArgument(MagicMock(), MagicMock(), MagicMock()), +                "call_prepared": True +            }, +            { +                "error": errors.ArgumentParsingError(), +                "call_prepared": False +            }, +            { +                "error": errors.UserInputError(), +                "call_prepared": True +            } +        ) + +        for case in test_cases: +            with self.subTest(error=case["error"], call_prepared=case["call_prepared"]): +                self.ctx.reset_mock() +                self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"])) +                self.ctx.send.assert_awaited_once() +                if case["call_prepared"]: +                    self.ctx.send_help.assert_awaited_once() +                else: +                    self.ctx.send_help.assert_not_awaited() + +    async def test_handle_check_failure_errors(self): +        """Should await `ctx.send` when error is check failure.""" +        test_cases = ( +            { +                "error": errors.BotMissingPermissions(MagicMock()), +                "call_ctx_send": True +            }, +            { +                "error": errors.BotMissingRole(MagicMock()), +                "call_ctx_send": True +            }, +            { +                "error": errors.BotMissingAnyRole(MagicMock()), +                "call_ctx_send": True +            }, +            { +                "error": errors.NoPrivateMessage(), +                "call_ctx_send": True +            }, +            { +                "error": InWhitelistCheckFailure(1234), +                "call_ctx_send": True +            }, +            { +                "error": ResponseCodeError(MagicMock()), +                "call_ctx_send": False +            } +        ) + +        for case in test_cases: +            with self.subTest(error=case["error"], call_ctx_send=case["call_ctx_send"]): +                self.ctx.reset_mock() +                await self.cog.handle_check_failure(self.ctx, case["error"]) +                if case["call_ctx_send"]: +                    self.ctx.send.assert_awaited_once() +                else: +                    self.ctx.send.assert_not_awaited() + +    @patch("bot.exts.backend.error_handler.log") +    async def test_handle_api_error(self, log_mock): +        """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" +        test_cases = ( +            { +                "error": ResponseCodeError(AsyncMock(status=400)), +                "log_level": "debug" +            }, +            { +                "error": ResponseCodeError(AsyncMock(status=404)), +                "log_level": "debug" +            }, +            { +                "error": ResponseCodeError(AsyncMock(status=550)), +                "log_level": "warning" +            }, +            { +                "error": ResponseCodeError(AsyncMock(status=1000)), +                "log_level": "warning" +            } +        ) + +        for case in test_cases: +            with self.subTest(error=case["error"], log_level=case["log_level"]): +                self.ctx.reset_mock() +                log_mock.reset_mock() +                await self.cog.handle_api_error(self.ctx, case["error"]) +                self.ctx.send.assert_awaited_once() +                if case["log_level"] == "warning": +                    log_mock.warning.assert_called_once() +                else: +                    log_mock.debug.assert_called_once() + +    @patch("bot.exts.backend.error_handler.push_scope") +    @patch("bot.exts.backend.error_handler.log") +    async def test_handle_unexpected_error(self, log_mock, push_scope_mock): +        """Should `ctx.send` this error, error log this and sent to Sentry.""" +        for case in (None, MockGuild()): +            with self.subTest(guild=case): +                self.ctx.reset_mock() +                log_mock.reset_mock() +                push_scope_mock.reset_mock() + +                self.ctx.guild = case +                await self.cog.handle_unexpected_error(self.ctx, errors.CommandError()) + +                self.ctx.send.assert_awaited_once() +                log_mock.error.assert_called_once() +                push_scope_mock.assert_called_once() + +                set_tag_calls = [ +                    call("command", self.ctx.command.qualified_name), +                    call("message_id", self.ctx.message.id), +                    call("channel_id", self.ctx.channel.id), +                ] +                set_extra_calls = [ +                    call("full_message", self.ctx.message.content) +                ] +                if case: +                    url = ( +                        f"https://discordapp.com/channels/" +                        f"{self.ctx.guild.id}/{self.ctx.channel.id}/{self.ctx.message.id}" +                    ) +                    set_extra_calls.append(call("jump_to", url)) + +                push_scope_mock.set_tag.has_calls(set_tag_calls) +                push_scope_mock.set_extra.has_calls(set_extra_calls) + + +class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase): +    """Other `ErrorHandler` tests.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext() + +    async def test_get_help_command_command_specified(self): +        """Should return coroutine of help command of specified command.""" +        self.ctx.command = "foo" +        result = ErrorHandler.get_help_command(self.ctx) +        expected = self.ctx.send_help("foo") +        self.assertEqual(result.__qualname__, expected.__qualname__) +        self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals) + +        # Await coroutines to avoid warnings +        await result +        await expected + +    async def test_get_help_command_no_command_specified(self): +        """Should return coroutine of help command.""" +        self.ctx.command = None +        result = ErrorHandler.get_help_command(self.ctx) +        expected = self.ctx.send_help() +        self.assertEqual(result.__qualname__, expected.__qualname__) +        self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals) + +        # Await coroutines to avoid warnings +        await result +        await expected + + +class ErrorHandlerSetupTests(unittest.TestCase): +    """Tests for `ErrorHandler` `setup` function.""" + +    def test_setup(self): +        """Should call `bot.add_cog` with `ErrorHandler`.""" +        bot = MockBot() +        setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index a996ce477..770660fe3 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -281,6 +281,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          user = helpers.MockMember() +        user.public_flags = unittest.mock.MagicMock(verified_bot=False)          user.nick = None          user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")          user.colour = 0 @@ -297,6 +298,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          user = helpers.MockMember() +        user.public_flags = unittest.mock.MagicMock(verified_bot=False)          user.nick = "Cat lover"          user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")          user.colour = 0 diff --git a/tests/helpers.py b/tests/helpers.py index 496363ae3..e3dc5fe5b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -385,6 +385,7 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da  # Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`  context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock()) +context_instance.invoked_from_error_handler = None  class MockContext(CustomMockMixin, unittest.mock.MagicMock): @@ -402,6 +403,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) +        self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)  attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) | 
