diff options
-rw-r--r-- | bot/constants.py | 10 | ||||
-rw-r--r-- | bot/exts/backend/error_handler.py | 22 | ||||
-rw-r--r-- | bot/exts/info/code_snippets.py | 265 | ||||
-rw-r--r-- | bot/exts/info/information.py | 5 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 18 | ||||
-rw-r--r-- | bot/exts/moderation/modpings.py | 8 | ||||
-rw-r--r-- | bot/exts/moderation/stream.py | 30 | ||||
-rw-r--r-- | bot/exts/utils/reminders.py | 13 | ||||
-rw-r--r-- | bot/log.py | 37 | ||||
-rw-r--r-- | config-default.yml | 9 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 550 | ||||
-rw-r--r-- | tests/bot/exts/info/test_information.py | 2 | ||||
-rw-r--r-- | tests/helpers.py | 2 |
13 files changed, 932 insertions, 39 deletions
diff --git a/bot/constants.py b/bot/constants.py index 916ae77e6..0c602f19b 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -175,13 +175,14 @@ class YAMLGetter(type): if cls.subsection is not None: return _CONFIG_YAML[cls.section][cls.subsection][name] return _CONFIG_YAML[cls.section][name] - except KeyError: + except KeyError as e: dotted_path = '.'.join( (cls.section, cls.subsection, name) if cls.subsection is not None else (cls.section, name) ) - log.critical(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.") - raise + # Only an INFO log since this can be caught through `hasattr` or `getattr`. + log.info(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.") + raise AttributeError(repr(name)) from e def __getitem__(cls, name): return cls.__getattr__(name) @@ -199,6 +200,7 @@ class Bot(metaclass=YAMLGetter): prefix: str sentry_dsn: Optional[str] token: str + trace_loggers: Optional[str] class Redis(metaclass=YAMLGetter): @@ -279,6 +281,8 @@ class Emojis(metaclass=YAMLGetter): badge_partner: str badge_staff: str badge_verified_bot_developer: str + verified_bot: str + bot: str defcon_shutdown: str # noqa: E704 defcon_unshutdown: str # noqa: E704 diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index da0e94a7e..d8de177f5 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,4 +1,3 @@ -import contextlib import difflib import logging import typing as t @@ -60,7 +59,7 @@ class ErrorHandler(Cog): log.trace(f"Command {command} had its error already handled locally; ignoring.") return - if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False): if await self.try_silence(ctx): return # Try to look for a tag with the command's name @@ -162,9 +161,8 @@ class ErrorHandler(Cog): f"and the fallback tag failed validation in TagNameConverter." ) else: - with contextlib.suppress(ResponseCodeError): - if await ctx.invoke(tags_get_command, tag_name=tag_name): - return + if await ctx.invoke(tags_get_command, tag_name=tag_name): + return if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): await self.send_command_suggestion(ctx, ctx.invoked_with) @@ -214,32 +212,30 @@ class ErrorHandler(Cog): * ArgumentParsingError: send an error message * Other: send an error message and the help command """ - prepared_help_command = self.get_help_command(ctx) - if isinstance(e, errors.MissingRequiredArgument): embed = self._get_error_embed("Missing required argument", e.param.name) await ctx.send(embed=embed) - await prepared_help_command + await self.get_help_command(ctx) self.bot.stats.incr("errors.missing_required_argument") elif isinstance(e, errors.TooManyArguments): embed = self._get_error_embed("Too many arguments", str(e)) await ctx.send(embed=embed) - await prepared_help_command + await self.get_help_command(ctx) self.bot.stats.incr("errors.too_many_arguments") elif isinstance(e, errors.BadArgument): embed = self._get_error_embed("Bad argument", str(e)) await ctx.send(embed=embed) - await prepared_help_command + await self.get_help_command(ctx) self.bot.stats.incr("errors.bad_argument") elif isinstance(e, errors.BadUnionArgument): embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") await ctx.send(embed=embed) - await prepared_help_command + await self.get_help_command(ctx) self.bot.stats.incr("errors.bad_union_argument") elif isinstance(e, errors.ArgumentParsingError): embed = self._get_error_embed("Argument parsing error", str(e)) await ctx.send(embed=embed) - prepared_help_command.close() + self.get_help_command(ctx).close() self.bot.stats.incr("errors.argument_parsing_error") else: embed = self._get_error_embed( @@ -247,7 +243,7 @@ class ErrorHandler(Cog): "Something about your input seems off. Check the arguments and try again." ) await ctx.send(embed=embed) - await prepared_help_command + await self.get_help_command(ctx) self.bot.stats.incr("errors.other_user_input_error") @staticmethod diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py new file mode 100644 index 000000000..06885410b --- /dev/null +++ b/bot/exts/info/code_snippets.py @@ -0,0 +1,265 @@ +import logging +import re +import textwrap +from typing import Any +from urllib.parse import quote_plus + +from aiohttp import ClientResponseError +from discord import Message +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +GITHUB_RE = re.compile( + r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/' + r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_GIST_RE = re.compile( + r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*' + r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?' + r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'} + +GITLAB_RE = re.compile( + r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)' + r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)' +) + +BITBUCKET_RE = re.compile( + r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)' + r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)' +) + + +class CodeSnippets(Cog): + """ + Cog that parses and sends code snippets to Discord. + + Matches each message against a regex and prints the contents of all matched snippets. + """ + + async def _fetch_response(self, url: str, response_format: str, **kwargs) -> Any: + """Makes http requests using aiohttp.""" + async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response: + if response_format == 'text': + return await response.text() + elif response_format == 'json': + return await response.json() + + def _find_ref(self, path: str, refs: tuple) -> tuple: + """Loops through all branches and tags to find the required ref.""" + # Base case: there is no slash in the branch name + ref, file_path = path.split('/', 1) + # In case there are slashes in the branch name, we loop through all branches and tags + for possible_ref in refs: + if path.startswith(possible_ref['name'] + '/'): + ref = possible_ref['name'] + file_path = path[len(ref) + 1:] + break + return ref, file_path + + async def _fetch_github_snippet( + self, + repo: str, + path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitHub repo.""" + # Search the GitHub API for the specified branch + branches = await self._fetch_response( + f'https://api.github.com/repos/{repo}/branches', + 'json', + headers=GITHUB_HEADERS + ) + tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS) + refs = branches + tags + ref, file_path = self._find_ref(path, refs) + + file_contents = await self._fetch_response( + f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', + 'text', + headers=GITHUB_HEADERS, + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + async def _fetch_github_gist_snippet( + self, + gist_id: str, + revision: str, + file_path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitHub gist.""" + gist_json = await self._fetch_response( + f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', + 'json', + headers=GITHUB_HEADERS, + ) + + # Check each file in the gist for the specified file + for gist_file in gist_json['files']: + if file_path == gist_file.lower().replace('.', '-'): + file_contents = await self._fetch_response( + gist_json['files'][gist_file]['raw_url'], + 'text', + ) + return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line) + return '' + + async def _fetch_gitlab_snippet( + self, + repo: str, + path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitLab repo.""" + enc_repo = quote_plus(repo) + + # Searches the GitLab API for the specified branch + branches = await self._fetch_response( + f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches', + 'json' + ) + tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json') + refs = branches + tags + ref, file_path = self._find_ref(path, refs) + enc_ref = quote_plus(ref) + enc_file_path = quote_plus(file_path) + + file_contents = await self._fetch_response( + f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', + 'text', + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + async def _fetch_bitbucket_snippet( + self, + repo: str, + ref: str, + file_path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a BitBucket repo.""" + file_contents = await self._fetch_response( + f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', + 'text', + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str: + """ + Given the entire file contents and target lines, creates a code block. + + First, we split the file contents into a list of lines and then keep and join only the required + ones together. + + We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent + markdown injection. + + Finally, we surround the code with ``` characters. + """ + # Parse start_line and end_line into integers + if end_line is None: + start_line = end_line = int(start_line) + else: + start_line = int(start_line) + end_line = int(end_line) + + split_file_contents = file_contents.splitlines() + + # Make sure that the specified lines are in range + if start_line > end_line: + start_line, end_line = end_line, start_line + if start_line > len(split_file_contents) or end_line < 1: + return '' + start_line = max(1, start_line) + end_line = min(len(split_file_contents), end_line) + + # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection + required = '\n'.join(split_file_contents[start_line - 1:end_line]) + required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') + + # Extracts the code language and checks whether it's a "valid" language + language = file_path.split('/')[-1].split('.')[-1] + trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') + is_valid_language = trimmed_language.isalnum() + if not is_valid_language: + language = '' + + # Adds a label showing the file path to the snippet + if start_line == end_line: + ret = f'`{file_path}` line {start_line}\n' + else: + ret = f'`{file_path}` lines {start_line} to {end_line}\n' + + if len(required) != 0: + return f'{ret}```{language}\n{required}```' + # Returns an empty codeblock if the snippet is empty + return f'{ret}``` ```' + + def __init__(self, bot: Bot): + """Initializes the cog's bot.""" + self.bot = bot + + self.pattern_handlers = [ + (GITHUB_RE, self._fetch_github_snippet), + (GITHUB_GIST_RE, self._fetch_github_gist_snippet), + (GITLAB_RE, self._fetch_gitlab_snippet), + (BITBUCKET_RE, self._fetch_bitbucket_snippet) + ] + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" + if not message.author.bot: + all_snippets = [] + + for pattern, handler in self.pattern_handlers: + for match in pattern.finditer(message.content): + try: + snippet = await handler(**match.groupdict()) + all_snippets.append((match.start(), snippet)) + except ClientResponseError as error: + error_message = error.message # noqa: B306 + log.log( + logging.DEBUG if error.status == 404 else logging.ERROR, + f'Failed to fetch code snippet from {match[0]!r}: {error.status} ' + f'{error_message} for GET {error.request_info.real_url.human_repr()}' + ) + + # Sorts the list of snippets by their match index and joins them into a single message + message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets))) + + if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: + await message.edit(suppress=True) + if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands: + # Redirects to #bot-commands if the snippet contents are too long + await self.bot.wait_until_guild_available() + await message.channel.send(('The snippet you tried to send was too long. Please ' + f'see <#{Channels.bot_commands}> for the full snippet.')) + bot_commands_channel = self.bot.get_channel(Channels.bot_commands) + await wait_for_deletion( + await bot_commands_channel.send(message_to_send), + (message.author.id,) + ) + else: + await wait_for_deletion( + await message.channel.send(message_to_send), + (message.author.id,) + ) + + +def setup(bot: Bot) -> None: + """Load the CodeSnippets cog.""" + bot.add_cog(CodeSnippets(bot)) diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 5e2c4b417..834fee1b4 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -230,6 +230,11 @@ class Information(Cog): if on_server and user.nick: name = f"{user.nick} ({name})" + if user.public_flags.verified_bot: + name += f" {constants.Emojis.verified_bot}" + elif user.bot: + name += f" {constants.Emojis.bot}" + badges = [] for badge, is_set in user.public_flags: diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index d89e80acc..38d1ffc0e 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -54,8 +54,12 @@ class Infractions(InfractionScheduler, commands.Cog): # region: Permanent infractions @command() - async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: """Warn a user for the given reason.""" + if not isinstance(user, Member): + await ctx.send(":x: The user doesn't appear to be on the server.") + return + infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False) if infraction is None: return @@ -63,8 +67,12 @@ class Infractions(InfractionScheduler, commands.Cog): await self.apply_infraction(ctx, infraction, user) @command() - async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: + async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: """Kick a user for the given reason.""" + if not isinstance(user, Member): + await ctx.send(":x: The user doesn't appear to be on the server.") + return + await self.apply_kick(ctx, user, reason) @command() @@ -100,7 +108,7 @@ class Infractions(InfractionScheduler, commands.Cog): @command(aliases=["mute"]) async def tempmute( self, ctx: Context, - user: Member, + user: FetchedMember, duration: t.Optional[Expiry] = None, *, reason: t.Optional[str] = None @@ -122,6 +130,10 @@ class Infractions(InfractionScheduler, commands.Cog): If no duration is given, a one hour duration is used by default. """ + if not isinstance(user, Member): + await ctx.send(":x: The user doesn't appear to be on the server.") + return + if duration is None: duration = await Duration().convert(ctx, "1h") await self.apply_mute(ctx, user, reason, expires_at=duration) diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index 2f180e594..1ad5005de 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -3,11 +3,11 @@ import logging from async_rediscache import RedisCache from dateutil.parser import isoparse -from discord import Member +from discord import Embed, Member from discord.ext.commands import Cog, Context, group, has_any_role from bot.bot import Bot -from bot.constants import Emojis, Guild, MODERATION_ROLES, Roles +from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles from bot.converters import Expiry from bot.utils.scheduling import Scheduler @@ -104,7 +104,9 @@ class ModPings(Cog): self._role_scheduler.cancel(mod.id) self._role_scheduler.schedule_at(duration, mod.id, self.reapply_role(mod)) - await ctx.send(f"{Emojis.check_mark} Moderators role has been removed until {until_date}.") + embed = Embed(timestamp=duration, colour=Colours.bright_green) + embed.set_footer(text="Moderators role has been removed until", icon_url=Icons.green_checkmark) + await ctx.send(embed=embed) @modpings_group.command(name='on') @has_any_role(*MODERATION_ROLES) diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 1dbb2a46b..fd856a7f4 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -70,6 +70,28 @@ class Stream(commands.Cog): self._revoke_streaming_permission(member) ) + async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None: + """Suspend a member's stream.""" + await self.bot.wait_until_guild_available() + voice_state = member.voice + + if not voice_state: + return + + # If the user is streaming. + if voice_state.self_stream: + # End user's stream by moving them to AFK voice channel and back. + original_vc = voice_state.channel + await member.move_to(ctx.guild.afk_channel) + await member.move_to(original_vc) + + # Notify. + await ctx.send(f"{member.mention}'s stream has been suspended!") + log.debug(f"Successfully suspended stream from {member} ({member.id}).") + return + + log.debug(f"No stream found to suspend from {member} ({member.id}).") + @commands.command(aliases=("streaming",)) @commands.has_any_role(*MODERATION_ROLES) async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None: @@ -170,10 +192,12 @@ class Stream(commands.Cog): await ctx.send(f"{Emojis.check_mark} Revoked the permission to stream from {member.mention}.") log.debug(f"Successfully revoked streaming permission from {member} ({member.id}).") - return - await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!") - log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!") + else: + await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!") + log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!") + + await self._suspend_stream(ctx, member) @commands.command(aliases=('lstream',)) @commands.has_any_role(*MODERATION_ROLES) diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 3113a1149..6c21920a1 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -90,15 +90,18 @@ class Reminders(Cog): delivery_dt: t.Optional[datetime], ) -> None: """Send an embed confirming the reminder change was made successfully.""" - embed = discord.Embed() - embed.colour = discord.Colour.green() - embed.title = random.choice(POSITIVE_REPLIES) - embed.description = on_success + embed = discord.Embed( + description=on_success, + colour=discord.Colour.green(), + title=random.choice(POSITIVE_REPLIES) + ) footer_str = f"ID: {reminder_id}" + if delivery_dt: # Reminder deletion will have a `None` `delivery_dt` - footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + footer_str += ', Due' + embed.timestamp = delivery_dt embed.set_footer(text=footer_str) diff --git a/bot/log.py b/bot/log.py index e92233a33..4e20c005e 100644 --- a/bot/log.py +++ b/bot/log.py @@ -20,7 +20,6 @@ def setup() -> None: logging.addLevelName(TRACE_LEVEL, "TRACE") Logger.trace = _monkeypatch_trace - log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" log_format = logging.Formatter(format_string) @@ -30,7 +29,6 @@ def setup() -> None: file_handler.setFormatter(log_format) root_log = logging.getLogger() - root_log.setLevel(log_level) root_log.addHandler(file_handler) if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: @@ -44,11 +42,9 @@ def setup() -> None: if "COLOREDLOGS_LOG_FORMAT" not in os.environ: coloredlogs.DEFAULT_LOG_FORMAT = format_string - if "COLOREDLOGS_LOG_LEVEL" not in os.environ: - coloredlogs.DEFAULT_LOG_LEVEL = log_level - - coloredlogs.install(logger=root_log, stream=sys.stdout) + coloredlogs.install(level=logging.TRACE, logger=root_log, stream=sys.stdout) + root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO) logging.getLogger("discord").setLevel(logging.WARNING) logging.getLogger("websockets").setLevel(logging.WARNING) logging.getLogger("chardet").setLevel(logging.WARNING) @@ -57,6 +53,8 @@ def setup() -> None: # Set back to the default of INFO even if asyncio's debug mode is enabled. logging.getLogger("asyncio").setLevel(logging.INFO) + _set_trace_loggers() + def setup_sentry() -> None: """Set up the Sentry logging integrations.""" @@ -86,3 +84,30 @@ def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: """ if self.isEnabledFor(TRACE_LEVEL): self._log(TRACE_LEVEL, msg, args, **kwargs) + + +def _set_trace_loggers() -> None: + """ + Set loggers to the trace level according to the value from the BOT_TRACE_LOGGERS env var. + + When the env var is a list of logger names delimited by a comma, + each of the listed loggers will be set to the trace level. + + If this list is prefixed with a "!", all of the loggers except the listed ones will be set to the trace level. + + Otherwise if the env var begins with a "*", + the root logger is set to the trace level and other contents are ignored. + """ + level_filter = constants.Bot.trace_loggers + if level_filter: + if level_filter.startswith("*"): + logging.getLogger().setLevel(logging.TRACE) + + elif level_filter.startswith("!"): + logging.getLogger().setLevel(logging.TRACE) + for logger_name in level_filter.strip("!,").split(","): + logging.getLogger(logger_name).setLevel(logging.DEBUG) + + else: + for logger_name in level_filter.strip(",").split(","): + logging.getLogger(logger_name).setLevel(logging.TRACE) diff --git a/config-default.yml b/config-default.yml index 9f4c9b80b..b9f6b40ac 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,7 +1,8 @@ bot: - prefix: "!" - sentry_dsn: !ENV "BOT_SENTRY_DSN" - token: !ENV "BOT_TOKEN" + prefix: "!" + sentry_dsn: !ENV "BOT_SENTRY_DSN" + token: !ENV "BOT_TOKEN" + trace_loggers: !ENV "BOT_TRACE_LOGGERS" clean: # Maximum number of messages to traverse for clean commands @@ -46,6 +47,8 @@ style: badge_partner: "<:partner:748666453242413136>" badge_staff: "<:discord_staff:743882896498098226>" badge_verified_bot_developer: "<:verified_bot_dev:743882897299210310>" + bot: "<:bot:812712599464443914>" + verified_bot: "<:verified_bot:811645219220750347>" defcon_shutdown: "<:defcondisabled:470326273952972810>" defcon_unshutdown: "<:defconenabled:470326274213150730>" diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py new file mode 100644 index 000000000..bd4fb5942 --- /dev/null +++ b/tests/bot/exts/backend/test_error_handler.py @@ -0,0 +1,550 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +from discord.ext.commands import errors + +from bot.api import ResponseCodeError +from bot.errors import InvalidInfractedUser, LockedResourceError +from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.info.tags import Tags +from bot.exts.moderation.silence import Silence +from bot.utils.checks import InWhitelistCheckFailure +from tests.helpers import MockBot, MockContext, MockGuild, MockRole + + +class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): + """Tests for error handler functionality.""" + + def setUp(self): + self.bot = MockBot() + self.ctx = MockContext(bot=self.bot) + + async def test_error_handler_already_handled(self): + """Should not do anything when error is already handled by local error handler.""" + self.ctx.reset_mock() + cog = ErrorHandler(self.bot) + error = errors.CommandError() + error.handled = "foo" + self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.ctx.send.assert_not_awaited() + + async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): + """Should try first (un)silence channel, when fail, try to get tag.""" + error = errors.CommandNotFound() + test_cases = ( + { + "try_silence_return": True, + "called_try_get_tag": False + }, + { + "try_silence_return": False, + "called_try_get_tag": False + }, + { + "try_silence_return": False, + "called_try_get_tag": True + } + ) + cog = ErrorHandler(self.bot) + cog.try_silence = AsyncMock() + cog.try_get_tag = AsyncMock() + + for case in test_cases: + with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): + self.ctx.reset_mock() + cog.try_silence.reset_mock(return_value=True) + cog.try_get_tag.reset_mock() + + cog.try_silence.return_value = case["try_silence_return"] + self.ctx.channel.id = 1234 + + self.assertIsNone(await cog.on_command_error(self.ctx, error)) + + if case["try_silence_return"]: + cog.try_get_tag.assert_not_awaited() + cog.try_silence.assert_awaited_once() + else: + cog.try_silence.assert_awaited_once() + cog.try_get_tag.assert_awaited_once() + + self.ctx.send.assert_not_awaited() + + async def test_error_handler_command_not_found_error_invoked_by_handler(self): + """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" + ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) + + cog = ErrorHandler(self.bot) + cog.try_silence = AsyncMock() + cog.try_get_tag = AsyncMock() + + error = errors.CommandNotFound() + + self.assertIsNone(await cog.on_command_error(ctx, error)) + + cog.try_silence.assert_not_awaited() + cog.try_get_tag.assert_not_awaited() + self.ctx.send.assert_not_awaited() + + async def test_error_handler_user_input_error(self): + """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" + self.ctx.reset_mock() + cog = ErrorHandler(self.bot) + cog.handle_user_input_error = AsyncMock() + error = errors.UserInputError() + self.assertIsNone(await cog.on_command_error(self.ctx, error)) + cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + + async def test_error_handler_check_failure(self): + """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" + self.ctx.reset_mock() + cog = ErrorHandler(self.bot) + cog.handle_check_failure = AsyncMock() + error = errors.CheckFailure() + self.assertIsNone(await cog.on_command_error(self.ctx, error)) + cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + + async def test_error_handler_command_on_cooldown(self): + """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" + self.ctx.reset_mock() + cog = ErrorHandler(self.bot) + error = errors.CommandOnCooldown(10, 9) + self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.ctx.send.assert_awaited_once_with(error) + + async def test_error_handler_command_invoke_error(self): + """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" + cog = ErrorHandler(self.bot) + cog.handle_api_error = AsyncMock() + cog.handle_unexpected_error = AsyncMock() + test_cases = ( + { + "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), + "expect_mock_call": cog.handle_api_error + }, + { + "args": (self.ctx, errors.CommandInvokeError(TypeError)), + "expect_mock_call": cog.handle_unexpected_error + }, + { + "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), + "expect_mock_call": "send" + }, + { + "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))), + "expect_mock_call": "send" + } + ) + + for case in test_cases: + with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): + self.ctx.send.reset_mock() + self.assertIsNone(await cog.on_command_error(*case["args"])) + if case["expect_mock_call"] == "send": + self.ctx.send.assert_awaited_once() + else: + case["expect_mock_call"].assert_awaited_once_with( + self.ctx, case["args"][1].original + ) + + async def test_error_handler_conversion_error(self): + """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" + cog = ErrorHandler(self.bot) + cog.handle_api_error = AsyncMock() + cog.handle_unexpected_error = AsyncMock() + cases = ( + { + "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), + "mock_function_to_call": cog.handle_api_error + }, + { + "error": errors.ConversionError(AsyncMock(), TypeError), + "mock_function_to_call": cog.handle_unexpected_error + } + ) + + for case in cases: + with self.subTest(**case): + self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) + + async def test_error_handler_two_other_errors(self): + """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" + cog = ErrorHandler(self.bot) + cog.handle_unexpected_error = AsyncMock() + errs = ( + errors.MaxConcurrencyReached(1, MagicMock()), + errors.ExtensionError(name="foo") + ) + + for err in errs: + with self.subTest(error=err): + cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await cog.on_command_error(self.ctx, err)) + cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + + @patch("bot.exts.backend.error_handler.log") + async def test_error_handler_other_errors(self, log_mock): + """Should `log.debug` other errors.""" + cog = ErrorHandler(self.bot) + error = errors.DisabledCommand() # Use this just as a other error + self.assertIsNone(await cog.on_command_error(self.ctx, error)) + log_mock.debug.assert_called_once() + + +class TrySilenceTests(unittest.IsolatedAsyncioTestCase): + """Test for helper functions that handle `CommandNotFound` error.""" + + def setUp(self): + self.bot = MockBot() + self.silence = Silence(self.bot) + self.bot.get_command.return_value = self.silence.silence + self.ctx = MockContext(bot=self.bot) + self.cog = ErrorHandler(self.bot) + + async def test_try_silence_context_invoked_from_error_handler(self): + """Should set `Context.invoked_from_error_handler` to `True`.""" + self.ctx.invoked_with = "foo" + await self.cog.try_silence(self.ctx) + self.assertTrue(hasattr(self.ctx, "invoked_from_error_handler")) + self.assertTrue(self.ctx.invoked_from_error_handler) + + async def test_try_silence_get_command(self): + """Should call `get_command` with `silence`.""" + self.ctx.invoked_with = "foo" + await self.cog.try_silence(self.ctx) + self.bot.get_command.assert_called_once_with("silence") + + async def test_try_silence_no_permissions_to_run(self): + """Should return `False` because missing permissions.""" + self.ctx.invoked_with = "foo" + self.bot.get_command.return_value.can_run = AsyncMock(return_value=False) + self.assertFalse(await self.cog.try_silence(self.ctx)) + + async def test_try_silence_no_permissions_to_run_command_error(self): + """Should return `False` because `CommandError` raised (no permissions).""" + self.ctx.invoked_with = "foo" + self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError()) + self.assertFalse(await self.cog.try_silence(self.ctx)) + + async def test_try_silence_silencing(self): + """Should run silence command with correct arguments.""" + self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) + test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh") + + for case in test_cases: + with self.subTest(message=case): + self.ctx.reset_mock() + self.ctx.invoked_with = case + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with( + self.bot.get_command.return_value, + duration=min(case.count("h")*2, 15) + ) + + async def test_try_silence_unsilence(self): + """Should call unsilence command.""" + self.silence.silence.can_run = AsyncMock(return_value=True) + test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh") + + for case in test_cases: + with self.subTest(message=case): + self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) + self.ctx.reset_mock() + self.ctx.invoked_with = case + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence) + + async def test_try_silence_no_match(self): + """Should return `False` when message don't match.""" + self.ctx.invoked_with = "foo" + self.assertFalse(await self.cog.try_silence(self.ctx)) + + +class TryGetTagTests(unittest.IsolatedAsyncioTestCase): + """Tests for `try_get_tag` function.""" + + def setUp(self): + self.bot = MockBot() + self.ctx = MockContext() + self.tag = Tags(self.bot) + self.cog = ErrorHandler(self.bot) + self.bot.get_command.return_value = self.tag.get_command + + async def test_try_get_tag_get_command(self): + """Should call `Bot.get_command` with `tags get` argument.""" + self.bot.get_command.reset_mock() + self.ctx.invoked_with = "foo" + await self.cog.try_get_tag(self.ctx) + self.bot.get_command.assert_called_once_with("tags get") + + async def test_try_get_tag_invoked_from_error_handler(self): + """`self.ctx` should have `invoked_from_error_handler` `True`.""" + self.ctx.invoked_from_error_handler = False + self.ctx.invoked_with = "foo" + await self.cog.try_get_tag(self.ctx) + self.assertTrue(self.ctx.invoked_from_error_handler) + + async def test_try_get_tag_no_permissions(self): + """Test how to handle checks failing.""" + self.tag.get_command.can_run = AsyncMock(return_value=False) + self.ctx.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + + async def test_try_get_tag_command_error(self): + """Should call `on_command_error` when `CommandError` raised.""" + err = errors.CommandError() + self.tag.get_command.can_run = AsyncMock(side_effect=err) + self.cog.on_command_error = AsyncMock() + self.ctx.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) + + @patch("bot.exts.backend.error_handler.TagNameConverter") + async def test_try_get_tag_convert_success(self, tag_converter): + """Converting tag should successful.""" + self.ctx.invoked_with = "foo" + tag_converter.convert = AsyncMock(return_value="foo") + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + tag_converter.convert.assert_awaited_once_with(self.ctx, "foo") + self.ctx.invoke.assert_awaited_once() + + @patch("bot.exts.backend.error_handler.TagNameConverter") + async def test_try_get_tag_convert_fail(self, tag_converter): + """Converting tag should raise `BadArgument`.""" + self.ctx.reset_mock() + self.ctx.invoked_with = "bar" + tag_converter.convert = AsyncMock(side_effect=errors.BadArgument()) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.ctx.invoke.assert_not_awaited() + + async def test_try_get_tag_ctx_invoke(self): + """Should call `ctx.invoke` with proper args/kwargs.""" + self.ctx.reset_mock() + self.ctx.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo") + + async def test_dont_call_suggestion_tag_sent(self): + """Should never call command suggestion if tag is already sent.""" + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=True) + self.cog.send_command_suggestion = AsyncMock() + + await self.cog.try_get_tag(self.ctx) + self.cog.send_command_suggestion.assert_not_awaited() + + @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) + async def test_dont_call_suggestion_if_user_mod(self): + """Should not call command suggestion if user is a mod.""" + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=False) + self.ctx.author.roles = [MockRole(id=1234)] + self.cog.send_command_suggestion = AsyncMock() + + await self.cog.try_get_tag(self.ctx) + self.cog.send_command_suggestion.assert_not_awaited() + + async def test_call_suggestion(self): + """Should call command suggestion if user is not a mod.""" + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=False) + self.cog.send_command_suggestion = AsyncMock() + + await self.cog.try_get_tag(self.ctx) + self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") + + +class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): + """Individual error categories handler tests.""" + + def setUp(self): + self.bot = MockBot() + self.ctx = MockContext(bot=self.bot) + self.cog = ErrorHandler(self.bot) + + async def test_handle_input_error_handler_errors(self): + """Should handle each error probably.""" + test_cases = ( + { + "error": errors.MissingRequiredArgument(MagicMock()), + "call_prepared": True + }, + { + "error": errors.TooManyArguments(), + "call_prepared": True + }, + { + "error": errors.BadArgument(), + "call_prepared": True + }, + { + "error": errors.BadUnionArgument(MagicMock(), MagicMock(), MagicMock()), + "call_prepared": True + }, + { + "error": errors.ArgumentParsingError(), + "call_prepared": False + }, + { + "error": errors.UserInputError(), + "call_prepared": True + } + ) + + for case in test_cases: + with self.subTest(error=case["error"], call_prepared=case["call_prepared"]): + self.ctx.reset_mock() + self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"])) + self.ctx.send.assert_awaited_once() + if case["call_prepared"]: + self.ctx.send_help.assert_awaited_once() + else: + self.ctx.send_help.assert_not_awaited() + + async def test_handle_check_failure_errors(self): + """Should await `ctx.send` when error is check failure.""" + test_cases = ( + { + "error": errors.BotMissingPermissions(MagicMock()), + "call_ctx_send": True + }, + { + "error": errors.BotMissingRole(MagicMock()), + "call_ctx_send": True + }, + { + "error": errors.BotMissingAnyRole(MagicMock()), + "call_ctx_send": True + }, + { + "error": errors.NoPrivateMessage(), + "call_ctx_send": True + }, + { + "error": InWhitelistCheckFailure(1234), + "call_ctx_send": True + }, + { + "error": ResponseCodeError(MagicMock()), + "call_ctx_send": False + } + ) + + for case in test_cases: + with self.subTest(error=case["error"], call_ctx_send=case["call_ctx_send"]): + self.ctx.reset_mock() + await self.cog.handle_check_failure(self.ctx, case["error"]) + if case["call_ctx_send"]: + self.ctx.send.assert_awaited_once() + else: + self.ctx.send.assert_not_awaited() + + @patch("bot.exts.backend.error_handler.log") + async def test_handle_api_error(self, log_mock): + """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" + test_cases = ( + { + "error": ResponseCodeError(AsyncMock(status=400)), + "log_level": "debug" + }, + { + "error": ResponseCodeError(AsyncMock(status=404)), + "log_level": "debug" + }, + { + "error": ResponseCodeError(AsyncMock(status=550)), + "log_level": "warning" + }, + { + "error": ResponseCodeError(AsyncMock(status=1000)), + "log_level": "warning" + } + ) + + for case in test_cases: + with self.subTest(error=case["error"], log_level=case["log_level"]): + self.ctx.reset_mock() + log_mock.reset_mock() + await self.cog.handle_api_error(self.ctx, case["error"]) + self.ctx.send.assert_awaited_once() + if case["log_level"] == "warning": + log_mock.warning.assert_called_once() + else: + log_mock.debug.assert_called_once() + + @patch("bot.exts.backend.error_handler.push_scope") + @patch("bot.exts.backend.error_handler.log") + async def test_handle_unexpected_error(self, log_mock, push_scope_mock): + """Should `ctx.send` this error, error log this and sent to Sentry.""" + for case in (None, MockGuild()): + with self.subTest(guild=case): + self.ctx.reset_mock() + log_mock.reset_mock() + push_scope_mock.reset_mock() + + self.ctx.guild = case + await self.cog.handle_unexpected_error(self.ctx, errors.CommandError()) + + self.ctx.send.assert_awaited_once() + log_mock.error.assert_called_once() + push_scope_mock.assert_called_once() + + set_tag_calls = [ + call("command", self.ctx.command.qualified_name), + call("message_id", self.ctx.message.id), + call("channel_id", self.ctx.channel.id), + ] + set_extra_calls = [ + call("full_message", self.ctx.message.content) + ] + if case: + url = ( + f"https://discordapp.com/channels/" + f"{self.ctx.guild.id}/{self.ctx.channel.id}/{self.ctx.message.id}" + ) + set_extra_calls.append(call("jump_to", url)) + + push_scope_mock.set_tag.has_calls(set_tag_calls) + push_scope_mock.set_extra.has_calls(set_extra_calls) + + +class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase): + """Other `ErrorHandler` tests.""" + + def setUp(self): + self.bot = MockBot() + self.ctx = MockContext() + + async def test_get_help_command_command_specified(self): + """Should return coroutine of help command of specified command.""" + self.ctx.command = "foo" + result = ErrorHandler.get_help_command(self.ctx) + expected = self.ctx.send_help("foo") + self.assertEqual(result.__qualname__, expected.__qualname__) + self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals) + + # Await coroutines to avoid warnings + await result + await expected + + async def test_get_help_command_no_command_specified(self): + """Should return coroutine of help command.""" + self.ctx.command = None + result = ErrorHandler.get_help_command(self.ctx) + expected = self.ctx.send_help() + self.assertEqual(result.__qualname__, expected.__qualname__) + self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals) + + # Await coroutines to avoid warnings + await result + await expected + + +class ErrorHandlerSetupTests(unittest.TestCase): + """Tests for `ErrorHandler` `setup` function.""" + + def test_setup(self): + """Should call `bot.add_cog` with `ErrorHandler`.""" + bot = MockBot() + setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index a996ce477..770660fe3 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -281,6 +281,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) user = helpers.MockMember() + user.public_flags = unittest.mock.MagicMock(verified_bot=False) user.nick = None user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 @@ -297,6 +298,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) user = helpers.MockMember() + user.public_flags = unittest.mock.MagicMock(verified_bot=False) user.nick = "Cat lover" user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 diff --git a/tests/helpers.py b/tests/helpers.py index 529664e67..86cc635f8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -404,6 +404,7 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da # Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock()) +context_instance.invoked_from_error_handler = None class MockContext(CustomMockMixin, unittest.mock.MagicMock): @@ -421,6 +422,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.guild = kwargs.get('guild', MockGuild()) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) + self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) |