aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Hassan Abouelela <[email protected]>2021-05-04 05:11:36 +0300
committerGravatar GitHub <[email protected]>2021-05-04 05:11:36 +0300
commit6a3551782b9c8613f1314265b6c092360ebf4131 (patch)
tree1a84e444a27e40a2ffab6d49020d12399818e0bc
parentAdds Missing Voice Version Of Tests (diff)
parentMerge pull request #1556 from ToxicKidz/mod-ping-off-embed-timestamp (diff)
Merge branch 'main' into voicechannel-mute
-rw-r--r--bot/constants.py10
-rw-r--r--bot/exts/backend/error_handler.py22
-rw-r--r--bot/exts/info/code_snippets.py265
-rw-r--r--bot/exts/info/information.py5
-rw-r--r--bot/exts/moderation/infraction/infractions.py18
-rw-r--r--bot/exts/moderation/modpings.py8
-rw-r--r--bot/exts/moderation/stream.py30
-rw-r--r--bot/exts/utils/reminders.py13
-rw-r--r--bot/log.py37
-rw-r--r--config-default.yml9
-rw-r--r--tests/bot/exts/backend/test_error_handler.py550
-rw-r--r--tests/bot/exts/info/test_information.py2
-rw-r--r--tests/helpers.py2
13 files changed, 932 insertions, 39 deletions
diff --git a/bot/constants.py b/bot/constants.py
index 916ae77e6..0c602f19b 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -175,13 +175,14 @@ class YAMLGetter(type):
if cls.subsection is not None:
return _CONFIG_YAML[cls.section][cls.subsection][name]
return _CONFIG_YAML[cls.section][name]
- except KeyError:
+ except KeyError as e:
dotted_path = '.'.join(
(cls.section, cls.subsection, name)
if cls.subsection is not None else (cls.section, name)
)
- log.critical(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.")
- raise
+ # Only an INFO log since this can be caught through `hasattr` or `getattr`.
+ log.info(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.")
+ raise AttributeError(repr(name)) from e
def __getitem__(cls, name):
return cls.__getattr__(name)
@@ -199,6 +200,7 @@ class Bot(metaclass=YAMLGetter):
prefix: str
sentry_dsn: Optional[str]
token: str
+ trace_loggers: Optional[str]
class Redis(metaclass=YAMLGetter):
@@ -279,6 +281,8 @@ class Emojis(metaclass=YAMLGetter):
badge_partner: str
badge_staff: str
badge_verified_bot_developer: str
+ verified_bot: str
+ bot: str
defcon_shutdown: str # noqa: E704
defcon_unshutdown: str # noqa: E704
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index da0e94a7e..d8de177f5 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -1,4 +1,3 @@
-import contextlib
import difflib
import logging
import typing as t
@@ -60,7 +59,7 @@ class ErrorHandler(Cog):
log.trace(f"Command {command} had its error already handled locally; ignoring.")
return
- if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
+ if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False):
if await self.try_silence(ctx):
return
# Try to look for a tag with the command's name
@@ -162,9 +161,8 @@ class ErrorHandler(Cog):
f"and the fallback tag failed validation in TagNameConverter."
)
else:
- with contextlib.suppress(ResponseCodeError):
- if await ctx.invoke(tags_get_command, tag_name=tag_name):
- return
+ if await ctx.invoke(tags_get_command, tag_name=tag_name):
+ return
if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
await self.send_command_suggestion(ctx, ctx.invoked_with)
@@ -214,32 +212,30 @@ class ErrorHandler(Cog):
* ArgumentParsingError: send an error message
* Other: send an error message and the help command
"""
- prepared_help_command = self.get_help_command(ctx)
-
if isinstance(e, errors.MissingRequiredArgument):
embed = self._get_error_embed("Missing required argument", e.param.name)
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.missing_required_argument")
elif isinstance(e, errors.TooManyArguments):
embed = self._get_error_embed("Too many arguments", str(e))
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.too_many_arguments")
elif isinstance(e, errors.BadArgument):
embed = self._get_error_embed("Bad argument", str(e))
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.bad_argument")
elif isinstance(e, errors.BadUnionArgument):
embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}")
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.bad_union_argument")
elif isinstance(e, errors.ArgumentParsingError):
embed = self._get_error_embed("Argument parsing error", str(e))
await ctx.send(embed=embed)
- prepared_help_command.close()
+ self.get_help_command(ctx).close()
self.bot.stats.incr("errors.argument_parsing_error")
else:
embed = self._get_error_embed(
@@ -247,7 +243,7 @@ class ErrorHandler(Cog):
"Something about your input seems off. Check the arguments and try again."
)
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.other_user_input_error")
@staticmethod
diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py
new file mode 100644
index 000000000..06885410b
--- /dev/null
+++ b/bot/exts/info/code_snippets.py
@@ -0,0 +1,265 @@
+import logging
+import re
+import textwrap
+from typing import Any
+from urllib.parse import quote_plus
+
+from aiohttp import ClientResponseError
+from discord import Message
+from discord.ext.commands import Cog
+
+from bot.bot import Bot
+from bot.constants import Channels
+from bot.utils.messages import wait_for_deletion
+
+log = logging.getLogger(__name__)
+
+GITHUB_RE = re.compile(
+ r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/'
+ r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)'
+)
+
+GITHUB_GIST_RE = re.compile(
+ r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*'
+ r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?'
+ r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)'
+)
+
+GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'}
+
+GITLAB_RE = re.compile(
+ r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)'
+ r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)'
+)
+
+BITBUCKET_RE = re.compile(
+ r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)'
+ r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)'
+)
+
+
+class CodeSnippets(Cog):
+ """
+ Cog that parses and sends code snippets to Discord.
+
+ Matches each message against a regex and prints the contents of all matched snippets.
+ """
+
+ async def _fetch_response(self, url: str, response_format: str, **kwargs) -> Any:
+ """Makes http requests using aiohttp."""
+ async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response:
+ if response_format == 'text':
+ return await response.text()
+ elif response_format == 'json':
+ return await response.json()
+
+ def _find_ref(self, path: str, refs: tuple) -> tuple:
+ """Loops through all branches and tags to find the required ref."""
+ # Base case: there is no slash in the branch name
+ ref, file_path = path.split('/', 1)
+ # In case there are slashes in the branch name, we loop through all branches and tags
+ for possible_ref in refs:
+ if path.startswith(possible_ref['name'] + '/'):
+ ref = possible_ref['name']
+ file_path = path[len(ref) + 1:]
+ break
+ return ref, file_path
+
+ async def _fetch_github_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub repo."""
+ # Search the GitHub API for the specified branch
+ branches = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/branches',
+ 'json',
+ headers=GITHUB_HEADERS
+ )
+ tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS)
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+
+ file_contents = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}',
+ 'text',
+ headers=GITHUB_HEADERS,
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_github_gist_snippet(
+ self,
+ gist_id: str,
+ revision: str,
+ file_path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub gist."""
+ gist_json = await self._fetch_response(
+ f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}',
+ 'json',
+ headers=GITHUB_HEADERS,
+ )
+
+ # Check each file in the gist for the specified file
+ for gist_file in gist_json['files']:
+ if file_path == gist_file.lower().replace('.', '-'):
+ file_contents = await self._fetch_response(
+ gist_json['files'][gist_file]['raw_url'],
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line)
+ return ''
+
+ async def _fetch_gitlab_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitLab repo."""
+ enc_repo = quote_plus(repo)
+
+ # Searches the GitLab API for the specified branch
+ branches = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches',
+ 'json'
+ )
+ tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json')
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+ enc_ref = quote_plus(ref)
+ enc_file_path = quote_plus(file_path)
+
+ file_contents = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_bitbucket_snippet(
+ self,
+ repo: str,
+ ref: str,
+ file_path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a BitBucket repo."""
+ file_contents = await self._fetch_response(
+ f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str:
+ """
+ Given the entire file contents and target lines, creates a code block.
+
+ First, we split the file contents into a list of lines and then keep and join only the required
+ ones together.
+
+ We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent
+ markdown injection.
+
+ Finally, we surround the code with ``` characters.
+ """
+ # Parse start_line and end_line into integers
+ if end_line is None:
+ start_line = end_line = int(start_line)
+ else:
+ start_line = int(start_line)
+ end_line = int(end_line)
+
+ split_file_contents = file_contents.splitlines()
+
+ # Make sure that the specified lines are in range
+ if start_line > end_line:
+ start_line, end_line = end_line, start_line
+ if start_line > len(split_file_contents) or end_line < 1:
+ return ''
+ start_line = max(1, start_line)
+ end_line = min(len(split_file_contents), end_line)
+
+ # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection
+ required = '\n'.join(split_file_contents[start_line - 1:end_line])
+ required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
+
+ # Extracts the code language and checks whether it's a "valid" language
+ language = file_path.split('/')[-1].split('.')[-1]
+ trimmed_language = language.replace('-', '').replace('+', '').replace('_', '')
+ is_valid_language = trimmed_language.isalnum()
+ if not is_valid_language:
+ language = ''
+
+ # Adds a label showing the file path to the snippet
+ if start_line == end_line:
+ ret = f'`{file_path}` line {start_line}\n'
+ else:
+ ret = f'`{file_path}` lines {start_line} to {end_line}\n'
+
+ if len(required) != 0:
+ return f'{ret}```{language}\n{required}```'
+ # Returns an empty codeblock if the snippet is empty
+ return f'{ret}``` ```'
+
+ def __init__(self, bot: Bot):
+ """Initializes the cog's bot."""
+ self.bot = bot
+
+ self.pattern_handlers = [
+ (GITHUB_RE, self._fetch_github_snippet),
+ (GITHUB_GIST_RE, self._fetch_github_gist_snippet),
+ (GITLAB_RE, self._fetch_gitlab_snippet),
+ (BITBUCKET_RE, self._fetch_bitbucket_snippet)
+ ]
+
+ @Cog.listener()
+ async def on_message(self, message: Message) -> None:
+ """Checks if the message has a snippet link, removes the embed, then sends the snippet contents."""
+ if not message.author.bot:
+ all_snippets = []
+
+ for pattern, handler in self.pattern_handlers:
+ for match in pattern.finditer(message.content):
+ try:
+ snippet = await handler(**match.groupdict())
+ all_snippets.append((match.start(), snippet))
+ except ClientResponseError as error:
+ error_message = error.message # noqa: B306
+ log.log(
+ logging.DEBUG if error.status == 404 else logging.ERROR,
+ f'Failed to fetch code snippet from {match[0]!r}: {error.status} '
+ f'{error_message} for GET {error.request_info.real_url.human_repr()}'
+ )
+
+ # Sorts the list of snippets by their match index and joins them into a single message
+ message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets)))
+
+ if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15:
+ await message.edit(suppress=True)
+ if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands:
+ # Redirects to #bot-commands if the snippet contents are too long
+ await self.bot.wait_until_guild_available()
+ await message.channel.send(('The snippet you tried to send was too long. Please '
+ f'see <#{Channels.bot_commands}> for the full snippet.'))
+ bot_commands_channel = self.bot.get_channel(Channels.bot_commands)
+ await wait_for_deletion(
+ await bot_commands_channel.send(message_to_send),
+ (message.author.id,)
+ )
+ else:
+ await wait_for_deletion(
+ await message.channel.send(message_to_send),
+ (message.author.id,)
+ )
+
+
+def setup(bot: Bot) -> None:
+ """Load the CodeSnippets cog."""
+ bot.add_cog(CodeSnippets(bot))
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 5e2c4b417..834fee1b4 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -230,6 +230,11 @@ class Information(Cog):
if on_server and user.nick:
name = f"{user.nick} ({name})"
+ if user.public_flags.verified_bot:
+ name += f" {constants.Emojis.verified_bot}"
+ elif user.bot:
+ name += f" {constants.Emojis.bot}"
+
badges = []
for badge, is_set in user.public_flags:
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index d89e80acc..38d1ffc0e 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -54,8 +54,12 @@ class Infractions(InfractionScheduler, commands.Cog):
# region: Permanent infractions
@command()
- async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None:
+ async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:
"""Warn a user for the given reason."""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False)
if infraction is None:
return
@@ -63,8 +67,12 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_infraction(ctx, infraction, user)
@command()
- async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None:
+ async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:
"""Kick a user for the given reason."""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
await self.apply_kick(ctx, user, reason)
@command()
@@ -100,7 +108,7 @@ class Infractions(InfractionScheduler, commands.Cog):
@command(aliases=["mute"])
async def tempmute(
self, ctx: Context,
- user: Member,
+ user: FetchedMember,
duration: t.Optional[Expiry] = None,
*,
reason: t.Optional[str] = None
@@ -122,6 +130,10 @@ class Infractions(InfractionScheduler, commands.Cog):
If no duration is given, a one hour duration is used by default.
"""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
if duration is None:
duration = await Duration().convert(ctx, "1h")
await self.apply_mute(ctx, user, reason, expires_at=duration)
diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py
index 2f180e594..1ad5005de 100644
--- a/bot/exts/moderation/modpings.py
+++ b/bot/exts/moderation/modpings.py
@@ -3,11 +3,11 @@ import logging
from async_rediscache import RedisCache
from dateutil.parser import isoparse
-from discord import Member
+from discord import Embed, Member
from discord.ext.commands import Cog, Context, group, has_any_role
from bot.bot import Bot
-from bot.constants import Emojis, Guild, MODERATION_ROLES, Roles
+from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles
from bot.converters import Expiry
from bot.utils.scheduling import Scheduler
@@ -104,7 +104,9 @@ class ModPings(Cog):
self._role_scheduler.cancel(mod.id)
self._role_scheduler.schedule_at(duration, mod.id, self.reapply_role(mod))
- await ctx.send(f"{Emojis.check_mark} Moderators role has been removed until {until_date}.")
+ embed = Embed(timestamp=duration, colour=Colours.bright_green)
+ embed.set_footer(text="Moderators role has been removed until", icon_url=Icons.green_checkmark)
+ await ctx.send(embed=embed)
@modpings_group.command(name='on')
@has_any_role(*MODERATION_ROLES)
diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py
index 1dbb2a46b..fd856a7f4 100644
--- a/bot/exts/moderation/stream.py
+++ b/bot/exts/moderation/stream.py
@@ -70,6 +70,28 @@ class Stream(commands.Cog):
self._revoke_streaming_permission(member)
)
+ async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None:
+ """Suspend a member's stream."""
+ await self.bot.wait_until_guild_available()
+ voice_state = member.voice
+
+ if not voice_state:
+ return
+
+ # If the user is streaming.
+ if voice_state.self_stream:
+ # End user's stream by moving them to AFK voice channel and back.
+ original_vc = voice_state.channel
+ await member.move_to(ctx.guild.afk_channel)
+ await member.move_to(original_vc)
+
+ # Notify.
+ await ctx.send(f"{member.mention}'s stream has been suspended!")
+ log.debug(f"Successfully suspended stream from {member} ({member.id}).")
+ return
+
+ log.debug(f"No stream found to suspend from {member} ({member.id}).")
+
@commands.command(aliases=("streaming",))
@commands.has_any_role(*MODERATION_ROLES)
async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None:
@@ -170,10 +192,12 @@ class Stream(commands.Cog):
await ctx.send(f"{Emojis.check_mark} Revoked the permission to stream from {member.mention}.")
log.debug(f"Successfully revoked streaming permission from {member} ({member.id}).")
- return
- await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!")
- log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!")
+ else:
+ await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!")
+ log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!")
+
+ await self._suspend_stream(ctx, member)
@commands.command(aliases=('lstream',))
@commands.has_any_role(*MODERATION_ROLES)
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index 3113a1149..6c21920a1 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -90,15 +90,18 @@ class Reminders(Cog):
delivery_dt: t.Optional[datetime],
) -> None:
"""Send an embed confirming the reminder change was made successfully."""
- embed = discord.Embed()
- embed.colour = discord.Colour.green()
- embed.title = random.choice(POSITIVE_REPLIES)
- embed.description = on_success
+ embed = discord.Embed(
+ description=on_success,
+ colour=discord.Colour.green(),
+ title=random.choice(POSITIVE_REPLIES)
+ )
footer_str = f"ID: {reminder_id}"
+
if delivery_dt:
# Reminder deletion will have a `None` `delivery_dt`
- footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}"
+ footer_str += ', Due'
+ embed.timestamp = delivery_dt
embed.set_footer(text=footer_str)
diff --git a/bot/log.py b/bot/log.py
index e92233a33..4e20c005e 100644
--- a/bot/log.py
+++ b/bot/log.py
@@ -20,7 +20,6 @@ def setup() -> None:
logging.addLevelName(TRACE_LEVEL, "TRACE")
Logger.trace = _monkeypatch_trace
- log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO
format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
log_format = logging.Formatter(format_string)
@@ -30,7 +29,6 @@ def setup() -> None:
file_handler.setFormatter(log_format)
root_log = logging.getLogger()
- root_log.setLevel(log_level)
root_log.addHandler(file_handler)
if "COLOREDLOGS_LEVEL_STYLES" not in os.environ:
@@ -44,11 +42,9 @@ def setup() -> None:
if "COLOREDLOGS_LOG_FORMAT" not in os.environ:
coloredlogs.DEFAULT_LOG_FORMAT = format_string
- if "COLOREDLOGS_LOG_LEVEL" not in os.environ:
- coloredlogs.DEFAULT_LOG_LEVEL = log_level
-
- coloredlogs.install(logger=root_log, stream=sys.stdout)
+ coloredlogs.install(level=logging.TRACE, logger=root_log, stream=sys.stdout)
+ root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO)
logging.getLogger("discord").setLevel(logging.WARNING)
logging.getLogger("websockets").setLevel(logging.WARNING)
logging.getLogger("chardet").setLevel(logging.WARNING)
@@ -57,6 +53,8 @@ def setup() -> None:
# Set back to the default of INFO even if asyncio's debug mode is enabled.
logging.getLogger("asyncio").setLevel(logging.INFO)
+ _set_trace_loggers()
+
def setup_sentry() -> None:
"""Set up the Sentry logging integrations."""
@@ -86,3 +84,30 @@ def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
"""
if self.isEnabledFor(TRACE_LEVEL):
self._log(TRACE_LEVEL, msg, args, **kwargs)
+
+
+def _set_trace_loggers() -> None:
+ """
+ Set loggers to the trace level according to the value from the BOT_TRACE_LOGGERS env var.
+
+ When the env var is a list of logger names delimited by a comma,
+ each of the listed loggers will be set to the trace level.
+
+ If this list is prefixed with a "!", all of the loggers except the listed ones will be set to the trace level.
+
+ Otherwise if the env var begins with a "*",
+ the root logger is set to the trace level and other contents are ignored.
+ """
+ level_filter = constants.Bot.trace_loggers
+ if level_filter:
+ if level_filter.startswith("*"):
+ logging.getLogger().setLevel(logging.TRACE)
+
+ elif level_filter.startswith("!"):
+ logging.getLogger().setLevel(logging.TRACE)
+ for logger_name in level_filter.strip("!,").split(","):
+ logging.getLogger(logger_name).setLevel(logging.DEBUG)
+
+ else:
+ for logger_name in level_filter.strip(",").split(","):
+ logging.getLogger(logger_name).setLevel(logging.TRACE)
diff --git a/config-default.yml b/config-default.yml
index 9f4c9b80b..b9f6b40ac 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -1,7 +1,8 @@
bot:
- prefix: "!"
- sentry_dsn: !ENV "BOT_SENTRY_DSN"
- token: !ENV "BOT_TOKEN"
+ prefix: "!"
+ sentry_dsn: !ENV "BOT_SENTRY_DSN"
+ token: !ENV "BOT_TOKEN"
+ trace_loggers: !ENV "BOT_TRACE_LOGGERS"
clean:
# Maximum number of messages to traverse for clean commands
@@ -46,6 +47,8 @@ style:
badge_partner: "<:partner:748666453242413136>"
badge_staff: "<:discord_staff:743882896498098226>"
badge_verified_bot_developer: "<:verified_bot_dev:743882897299210310>"
+ bot: "<:bot:812712599464443914>"
+ verified_bot: "<:verified_bot:811645219220750347>"
defcon_shutdown: "<:defcondisabled:470326273952972810>"
defcon_unshutdown: "<:defconenabled:470326274213150730>"
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
new file mode 100644
index 000000000..bd4fb5942
--- /dev/null
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -0,0 +1,550 @@
+import unittest
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+from discord.ext.commands import errors
+
+from bot.api import ResponseCodeError
+from bot.errors import InvalidInfractedUser, LockedResourceError
+from bot.exts.backend.error_handler import ErrorHandler, setup
+from bot.exts.info.tags import Tags
+from bot.exts.moderation.silence import Silence
+from bot.utils.checks import InWhitelistCheckFailure
+from tests.helpers import MockBot, MockContext, MockGuild, MockRole
+
+
+class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for error handler functionality."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext(bot=self.bot)
+
+ async def test_error_handler_already_handled(self):
+ """Should not do anything when error is already handled by local error handler."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ error = errors.CommandError()
+ error.handled = "foo"
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_command_not_found_error_not_invoked_by_handler(self):
+ """Should try first (un)silence channel, when fail, try to get tag."""
+ error = errors.CommandNotFound()
+ test_cases = (
+ {
+ "try_silence_return": True,
+ "called_try_get_tag": False
+ },
+ {
+ "try_silence_return": False,
+ "called_try_get_tag": False
+ },
+ {
+ "try_silence_return": False,
+ "called_try_get_tag": True
+ }
+ )
+ cog = ErrorHandler(self.bot)
+ cog.try_silence = AsyncMock()
+ cog.try_get_tag = AsyncMock()
+
+ for case in test_cases:
+ with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):
+ self.ctx.reset_mock()
+ cog.try_silence.reset_mock(return_value=True)
+ cog.try_get_tag.reset_mock()
+
+ cog.try_silence.return_value = case["try_silence_return"]
+ self.ctx.channel.id = 1234
+
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+
+ if case["try_silence_return"]:
+ cog.try_get_tag.assert_not_awaited()
+ cog.try_silence.assert_awaited_once()
+ else:
+ cog.try_silence.assert_awaited_once()
+ cog.try_get_tag.assert_awaited_once()
+
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_command_not_found_error_invoked_by_handler(self):
+ """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""
+ ctx = MockContext(bot=self.bot, invoked_from_error_handler=True)
+
+ cog = ErrorHandler(self.bot)
+ cog.try_silence = AsyncMock()
+ cog.try_get_tag = AsyncMock()
+
+ error = errors.CommandNotFound()
+
+ self.assertIsNone(await cog.on_command_error(ctx, error))
+
+ cog.try_silence.assert_not_awaited()
+ cog.try_get_tag.assert_not_awaited()
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_user_input_error(self):
+ """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ cog.handle_user_input_error = AsyncMock()
+ error = errors.UserInputError()
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
+
+ async def test_error_handler_check_failure(self):
+ """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ cog.handle_check_failure = AsyncMock()
+ error = errors.CheckFailure()
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
+
+ async def test_error_handler_command_on_cooldown(self):
+ """Should send error with `ctx.send` when error is `CommandOnCooldown`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ error = errors.CommandOnCooldown(10, 9)
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.ctx.send.assert_awaited_once_with(error)
+
+ async def test_error_handler_command_invoke_error(self):
+ """Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_api_error = AsyncMock()
+ cog.handle_unexpected_error = AsyncMock()
+ test_cases = (
+ {
+ "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))),
+ "expect_mock_call": cog.handle_api_error
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(TypeError)),
+ "expect_mock_call": cog.handle_unexpected_error
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))),
+ "expect_mock_call": "send"
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))),
+ "expect_mock_call": "send"
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):
+ self.ctx.send.reset_mock()
+ self.assertIsNone(await cog.on_command_error(*case["args"]))
+ if case["expect_mock_call"] == "send":
+ self.ctx.send.assert_awaited_once()
+ else:
+ case["expect_mock_call"].assert_awaited_once_with(
+ self.ctx, case["args"][1].original
+ )
+
+ async def test_error_handler_conversion_error(self):
+ """Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_api_error = AsyncMock()
+ cog.handle_unexpected_error = AsyncMock()
+ cases = (
+ {
+ "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())),
+ "mock_function_to_call": cog.handle_api_error
+ },
+ {
+ "error": errors.ConversionError(AsyncMock(), TypeError),
+ "mock_function_to_call": cog.handle_unexpected_error
+ }
+ )
+
+ for case in cases:
+ with self.subTest(**case):
+ self.assertIsNone(await cog.on_command_error(self.ctx, case["error"]))
+ case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)
+
+ async def test_error_handler_two_other_errors(self):
+ """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_unexpected_error = AsyncMock()
+ errs = (
+ errors.MaxConcurrencyReached(1, MagicMock()),
+ errors.ExtensionError(name="foo")
+ )
+
+ for err in errs:
+ with self.subTest(error=err):
+ cog.handle_unexpected_error.reset_mock()
+ self.assertIsNone(await cog.on_command_error(self.ctx, err))
+ cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
+
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_error_handler_other_errors(self, log_mock):
+ """Should `log.debug` other errors."""
+ cog = ErrorHandler(self.bot)
+ error = errors.DisabledCommand() # Use this just as a other error
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ log_mock.debug.assert_called_once()
+
+
+class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Test for helper functions that handle `CommandNotFound` error."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.silence = Silence(self.bot)
+ self.bot.get_command.return_value = self.silence.silence
+ self.ctx = MockContext(bot=self.bot)
+ self.cog = ErrorHandler(self.bot)
+
+ async def test_try_silence_context_invoked_from_error_handler(self):
+ """Should set `Context.invoked_from_error_handler` to `True`."""
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_silence(self.ctx)
+ self.assertTrue(hasattr(self.ctx, "invoked_from_error_handler"))
+ self.assertTrue(self.ctx.invoked_from_error_handler)
+
+ async def test_try_silence_get_command(self):
+ """Should call `get_command` with `silence`."""
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_silence(self.ctx)
+ self.bot.get_command.assert_called_once_with("silence")
+
+ async def test_try_silence_no_permissions_to_run(self):
+ """Should return `False` because missing permissions."""
+ self.ctx.invoked_with = "foo"
+ self.bot.get_command.return_value.can_run = AsyncMock(return_value=False)
+ self.assertFalse(await self.cog.try_silence(self.ctx))
+
+ async def test_try_silence_no_permissions_to_run_command_error(self):
+ """Should return `False` because `CommandError` raised (no permissions)."""
+ self.ctx.invoked_with = "foo"
+ self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError())
+ self.assertFalse(await self.cog.try_silence(self.ctx))
+
+ async def test_try_silence_silencing(self):
+ """Should run silence command with correct arguments."""
+ self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)
+ test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh")
+
+ for case in test_cases:
+ with self.subTest(message=case):
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = case
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(
+ self.bot.get_command.return_value,
+ duration=min(case.count("h")*2, 15)
+ )
+
+ async def test_try_silence_unsilence(self):
+ """Should call unsilence command."""
+ self.silence.silence.can_run = AsyncMock(return_value=True)
+ test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh")
+
+ for case in test_cases:
+ with self.subTest(message=case):
+ self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence)
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = case
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence)
+
+ async def test_try_silence_no_match(self):
+ """Should return `False` when message don't match."""
+ self.ctx.invoked_with = "foo"
+ self.assertFalse(await self.cog.try_silence(self.ctx))
+
+
+class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for `try_get_tag` function."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext()
+ self.tag = Tags(self.bot)
+ self.cog = ErrorHandler(self.bot)
+ self.bot.get_command.return_value = self.tag.get_command
+
+ async def test_try_get_tag_get_command(self):
+ """Should call `Bot.get_command` with `tags get` argument."""
+ self.bot.get_command.reset_mock()
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_get_tag(self.ctx)
+ self.bot.get_command.assert_called_once_with("tags get")
+
+ async def test_try_get_tag_invoked_from_error_handler(self):
+ """`self.ctx` should have `invoked_from_error_handler` `True`."""
+ self.ctx.invoked_from_error_handler = False
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_get_tag(self.ctx)
+ self.assertTrue(self.ctx.invoked_from_error_handler)
+
+ async def test_try_get_tag_no_permissions(self):
+ """Test how to handle checks failing."""
+ self.tag.get_command.can_run = AsyncMock(return_value=False)
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+
+ async def test_try_get_tag_command_error(self):
+ """Should call `on_command_error` when `CommandError` raised."""
+ err = errors.CommandError()
+ self.tag.get_command.can_run = AsyncMock(side_effect=err)
+ self.cog.on_command_error = AsyncMock()
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
+
+ @patch("bot.exts.backend.error_handler.TagNameConverter")
+ async def test_try_get_tag_convert_success(self, tag_converter):
+ """Converting tag should successful."""
+ self.ctx.invoked_with = "foo"
+ tag_converter.convert = AsyncMock(return_value="foo")
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ tag_converter.convert.assert_awaited_once_with(self.ctx, "foo")
+ self.ctx.invoke.assert_awaited_once()
+
+ @patch("bot.exts.backend.error_handler.TagNameConverter")
+ async def test_try_get_tag_convert_fail(self, tag_converter):
+ """Converting tag should raise `BadArgument`."""
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = "bar"
+ tag_converter.convert = AsyncMock(side_effect=errors.BadArgument())
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.ctx.invoke.assert_not_awaited()
+
+ async def test_try_get_tag_ctx_invoke(self):
+ """Should call `ctx.invoke` with proper args/kwargs."""
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo")
+
+ async def test_dont_call_suggestion_tag_sent(self):
+ """Should never call command suggestion if tag is already sent."""
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=True)
+ self.cog.send_command_suggestion = AsyncMock()
+
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_not_awaited()
+
+ @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234])
+ async def test_dont_call_suggestion_if_user_mod(self):
+ """Should not call command suggestion if user is a mod."""
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=False)
+ self.ctx.author.roles = [MockRole(id=1234)]
+ self.cog.send_command_suggestion = AsyncMock()
+
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_not_awaited()
+
+ async def test_call_suggestion(self):
+ """Should call command suggestion if user is not a mod."""
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=False)
+ self.cog.send_command_suggestion = AsyncMock()
+
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo")
+
+
+class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Individual error categories handler tests."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext(bot=self.bot)
+ self.cog = ErrorHandler(self.bot)
+
+ async def test_handle_input_error_handler_errors(self):
+ """Should handle each error probably."""
+ test_cases = (
+ {
+ "error": errors.MissingRequiredArgument(MagicMock()),
+ "call_prepared": True
+ },
+ {
+ "error": errors.TooManyArguments(),
+ "call_prepared": True
+ },
+ {
+ "error": errors.BadArgument(),
+ "call_prepared": True
+ },
+ {
+ "error": errors.BadUnionArgument(MagicMock(), MagicMock(), MagicMock()),
+ "call_prepared": True
+ },
+ {
+ "error": errors.ArgumentParsingError(),
+ "call_prepared": False
+ },
+ {
+ "error": errors.UserInputError(),
+ "call_prepared": True
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(error=case["error"], call_prepared=case["call_prepared"]):
+ self.ctx.reset_mock()
+ self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"]))
+ self.ctx.send.assert_awaited_once()
+ if case["call_prepared"]:
+ self.ctx.send_help.assert_awaited_once()
+ else:
+ self.ctx.send_help.assert_not_awaited()
+
+ async def test_handle_check_failure_errors(self):
+ """Should await `ctx.send` when error is check failure."""
+ test_cases = (
+ {
+ "error": errors.BotMissingPermissions(MagicMock()),
+ "call_ctx_send": True
+ },
+ {
+ "error": errors.BotMissingRole(MagicMock()),
+ "call_ctx_send": True
+ },
+ {
+ "error": errors.BotMissingAnyRole(MagicMock()),
+ "call_ctx_send": True
+ },
+ {
+ "error": errors.NoPrivateMessage(),
+ "call_ctx_send": True
+ },
+ {
+ "error": InWhitelistCheckFailure(1234),
+ "call_ctx_send": True
+ },
+ {
+ "error": ResponseCodeError(MagicMock()),
+ "call_ctx_send": False
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(error=case["error"], call_ctx_send=case["call_ctx_send"]):
+ self.ctx.reset_mock()
+ await self.cog.handle_check_failure(self.ctx, case["error"])
+ if case["call_ctx_send"]:
+ self.ctx.send.assert_awaited_once()
+ else:
+ self.ctx.send.assert_not_awaited()
+
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_handle_api_error(self, log_mock):
+ """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code."""
+ test_cases = (
+ {
+ "error": ResponseCodeError(AsyncMock(status=400)),
+ "log_level": "debug"
+ },
+ {
+ "error": ResponseCodeError(AsyncMock(status=404)),
+ "log_level": "debug"
+ },
+ {
+ "error": ResponseCodeError(AsyncMock(status=550)),
+ "log_level": "warning"
+ },
+ {
+ "error": ResponseCodeError(AsyncMock(status=1000)),
+ "log_level": "warning"
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(error=case["error"], log_level=case["log_level"]):
+ self.ctx.reset_mock()
+ log_mock.reset_mock()
+ await self.cog.handle_api_error(self.ctx, case["error"])
+ self.ctx.send.assert_awaited_once()
+ if case["log_level"] == "warning":
+ log_mock.warning.assert_called_once()
+ else:
+ log_mock.debug.assert_called_once()
+
+ @patch("bot.exts.backend.error_handler.push_scope")
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_handle_unexpected_error(self, log_mock, push_scope_mock):
+ """Should `ctx.send` this error, error log this and sent to Sentry."""
+ for case in (None, MockGuild()):
+ with self.subTest(guild=case):
+ self.ctx.reset_mock()
+ log_mock.reset_mock()
+ push_scope_mock.reset_mock()
+
+ self.ctx.guild = case
+ await self.cog.handle_unexpected_error(self.ctx, errors.CommandError())
+
+ self.ctx.send.assert_awaited_once()
+ log_mock.error.assert_called_once()
+ push_scope_mock.assert_called_once()
+
+ set_tag_calls = [
+ call("command", self.ctx.command.qualified_name),
+ call("message_id", self.ctx.message.id),
+ call("channel_id", self.ctx.channel.id),
+ ]
+ set_extra_calls = [
+ call("full_message", self.ctx.message.content)
+ ]
+ if case:
+ url = (
+ f"https://discordapp.com/channels/"
+ f"{self.ctx.guild.id}/{self.ctx.channel.id}/{self.ctx.message.id}"
+ )
+ set_extra_calls.append(call("jump_to", url))
+
+ push_scope_mock.set_tag.has_calls(set_tag_calls)
+ push_scope_mock.set_extra.has_calls(set_extra_calls)
+
+
+class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Other `ErrorHandler` tests."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext()
+
+ async def test_get_help_command_command_specified(self):
+ """Should return coroutine of help command of specified command."""
+ self.ctx.command = "foo"
+ result = ErrorHandler.get_help_command(self.ctx)
+ expected = self.ctx.send_help("foo")
+ self.assertEqual(result.__qualname__, expected.__qualname__)
+ self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
+
+ # Await coroutines to avoid warnings
+ await result
+ await expected
+
+ async def test_get_help_command_no_command_specified(self):
+ """Should return coroutine of help command."""
+ self.ctx.command = None
+ result = ErrorHandler.get_help_command(self.ctx)
+ expected = self.ctx.send_help()
+ self.assertEqual(result.__qualname__, expected.__qualname__)
+ self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
+
+ # Await coroutines to avoid warnings
+ await result
+ await expected
+
+
+class ErrorHandlerSetupTests(unittest.TestCase):
+ """Tests for `ErrorHandler` `setup` function."""
+
+ def test_setup(self):
+ """Should call `bot.add_cog` with `ErrorHandler`."""
+ bot = MockBot()
+ setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index a996ce477..770660fe3 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -281,6 +281,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
+ user.public_flags = unittest.mock.MagicMock(verified_bot=False)
user.nick = None
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
user.colour = 0
@@ -297,6 +298,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
+ user.public_flags = unittest.mock.MagicMock(verified_bot=False)
user.nick = "Cat lover"
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
user.colour = 0
diff --git a/tests/helpers.py b/tests/helpers.py
index 529664e67..86cc635f8 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -404,6 +404,7 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da
# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
+context_instance.invoked_from_error_handler = None
class MockContext(CustomMockMixin, unittest.mock.MagicMock):
@@ -421,6 +422,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
+ self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)
attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock())