diff options
Diffstat (limited to 'bot/exts/backend')
| -rw-r--r-- | bot/exts/backend/__init__.py | 0 | ||||
| -rw-r--r-- | bot/exts/backend/alias.py | 87 | ||||
| -rw-r--r-- | bot/exts/backend/config_verifier.py | 40 | ||||
| -rw-r--r-- | bot/exts/backend/error_handler.py | 287 | ||||
| -rw-r--r-- | bot/exts/backend/logging.py | 42 | ||||
| -rw-r--r-- | bot/exts/backend/sync/__init__.py | 8 | ||||
| -rw-r--r-- | bot/exts/backend/sync/_cog.py | 180 | ||||
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 347 |
8 files changed, 991 insertions, 0 deletions
diff --git a/bot/exts/backend/__init__.py b/bot/exts/backend/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/bot/exts/backend/__init__.py diff --git a/bot/exts/backend/alias.py b/bot/exts/backend/alias.py new file mode 100644 index 000000000..c6ba8d6f3 --- /dev/null +++ b/bot/exts/backend/alias.py @@ -0,0 +1,87 @@ +import inspect +import logging + +from discord import Colour, Embed +from discord.ext.commands import ( + Cog, Command, Context, + clean_content, command, group, +) + +from bot.bot import Bot +from bot.converters import TagNameConverter +from bot.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +class Alias (Cog): + """Aliases for commonly used commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + async def invoke(self, ctx: Context, cmd_name: str, *args, **kwargs) -> None: + """Invokes a command with args and kwargs.""" + log.debug(f"{cmd_name} was invoked through an alias") + cmd = self.bot.get_command(cmd_name) + if not cmd: + return log.info(f'Did not find command "{cmd_name}" to invoke.') + elif not await cmd.can_run(ctx): + return log.info( + f'{str(ctx.author)} tried to run the command "{cmd_name}" but lacks permission.' + ) + + await ctx.invoke(cmd, *args, **kwargs) + + @command(name='aliases') + async def aliases_command(self, ctx: Context) -> None: + """Show configured aliases on the bot.""" + embed = Embed( + title='Configured aliases', + colour=Colour.blue() + ) + await LinePaginator.paginate( + ( + f"• `{ctx.prefix}{value.name}` " + f"=> `{ctx.prefix}{name[:-len('_alias')].replace('_', ' ')}`" + for name, value in inspect.getmembers(self) + if isinstance(value, Command) and name.endswith('_alias') + ), + ctx, embed, empty=False, max_lines=20 + ) + + @command(name="exception", hidden=True) + async def tags_get_traceback_alias(self, ctx: Context) -> None: + """Alias for invoking <prefix>tags get traceback.""" + await self.invoke(ctx, "tags get", tag_name="traceback") + + @group(name="get", + aliases=("show", "g"), + hidden=True, + invoke_without_command=True) + async def get_group_alias(self, ctx: Context) -> None: + """Group for reverse aliases for commands like `tags get`, allowing for `get tags` or `get docs`.""" + pass + + @get_group_alias.command(name="tags", aliases=("tag", "t"), hidden=True) + async def tags_get_alias( + self, ctx: Context, *, tag_name: TagNameConverter = None + ) -> None: + """ + Alias for invoking <prefix>tags get [tag_name]. + + tag_name: str - tag to be viewed. + """ + await self.invoke(ctx, "tags get", tag_name=tag_name) + + @get_group_alias.command(name="docs", aliases=("doc", "d"), hidden=True) + async def docs_get_alias( + self, ctx: Context, symbol: clean_content = None + ) -> None: + """Alias for invoking <prefix>docs get [symbol].""" + await self.invoke(ctx, "docs get", symbol) + + +def setup(bot: Bot) -> None: + """Load the Alias cog.""" + bot.add_cog(Alias(bot)) diff --git a/bot/exts/backend/config_verifier.py b/bot/exts/backend/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/exts/backend/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py new file mode 100644 index 000000000..f9d4de638 --- /dev/null +++ b/bot/exts/backend/error_handler.py @@ -0,0 +1,287 @@ +import contextlib +import logging +import typing as t + +from discord import Embed +from discord.ext.commands import Cog, Context, errors +from sentry_sdk import push_scope + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Colours +from bot.converters import TagNameConverter +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + + +class ErrorHandler(Cog): + """Handles errors emitted from commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + def _get_error_embed(self, title: str, body: str) -> Embed: + """Return an embed that contains the exception.""" + return Embed( + title=title, + colour=Colours.soft_red, + description=body + ) + + @Cog.listener() + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: + """ + Provide generic command error handling. + + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command: + * If it matches shh+ or unshh+, the channel is silenced or unsilenced respectively. + Otherwise if it matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` + """ + command = ctx.command + + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if await self.try_silence(ctx): + return + if ctx.channel.id != Channels.verification: + # Try to look for a tag with the command's name + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + @staticmethod + def get_help_command(ctx: Context) -> t.Coroutine: + """Return a prepared `help` command invocation coroutine.""" + if ctx.command: + return ctx.send_help(ctx.command) + + return ctx.send_help() + + async def try_silence(self, ctx: Context) -> bool: + """ + Attempt to invoke the silence or unsilence command if invoke with matches a pattern. + + Respecting the checks if: + * invoked with `shh+` silence channel for amount of h's*2 with max of 15. + * invoked with `unshh+` unsilence channel + Return bool depending on success of command. + """ + command = ctx.invoked_with.lower() + silence_command = self.bot.get_command("silence") + ctx.invoked_from_error_handler = True + try: + if not await silence_command.can_run(ctx): + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + except errors.CommandError: + log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") + return False + if command.startswith("shh"): + await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) + return True + elif command.startswith("unshh"): + await ctx.invoke(self.bot.get_command("unsilence")) + return True + return False + + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + + try: + tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) + except errors.BadArgument: + log.debug( + f"{ctx.author} tried to use an invalid command " + f"and the fallback tag failed validation in TagNameConverter." + ) + else: + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=tag_name) + # Return to not raise the exception + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + prepared_help_command = self.get_help_command(ctx) + + if isinstance(e, errors.MissingRequiredArgument): + embed = self._get_error_embed("Missing required argument", e.param.name) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.missing_required_argument") + elif isinstance(e, errors.TooManyArguments): + embed = self._get_error_embed("Too many arguments", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.too_many_arguments") + elif isinstance(e, errors.BadArgument): + embed = self._get_error_embed("Bad argument", str(e)) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.bad_argument") + elif isinstance(e, errors.BadUnionArgument): + embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}") + await ctx.send(embed=embed) + self.bot.stats.incr("errors.bad_union_argument") + elif isinstance(e, errors.ArgumentParsingError): + embed = self._get_error_embed("Argument parsing error", str(e)) + await ctx.send(embed=embed) + self.bot.stats.incr("errors.argument_parsing_error") + else: + embed = self._get_error_embed( + "Input error", + "Something about your input seems off. Check the arguments and try again." + ) + await ctx.send(embed=embed) + await prepared_help_command + self.bot.stats.incr("errors.other_user_input_error") + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InWhitelistCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + ctx.bot.stats.incr("errors.bot_permission_error") + await ctx.send( + "Sorry, it looks like I don't have the permissions or roles I need to do that." + ) + elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): + ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") + await ctx.send(e) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + ctx.bot.stats.incr("errors.api_error_404") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + ctx.bot.stats.incr("errors.api_error_400") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") + ctx.bot.stats.incr("errors.api_internal_server_error") + else: + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") + ctx.bot.stats.incr(f"errors.api_error_{e.status}") + + @staticmethod + async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n" + f"```{e.__class__.__name__}: {e}```" + ) + + ctx.bot.stats.incr("errors.unexpected") + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) + + +def setup(bot: Bot) -> None: + """Load the ErrorHandler cog.""" + bot.add_cog(ErrorHandler(bot)) diff --git a/bot/exts/backend/logging.py b/bot/exts/backend/logging.py new file mode 100644 index 000000000..94fa2b139 --- /dev/null +++ b/bot/exts/backend/logging.py @@ -0,0 +1,42 @@ +import logging + +from discord import Embed +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels, DEBUG_MODE + + +log = logging.getLogger(__name__) + + +class Logging(Cog): + """Debug logging module.""" + + def __init__(self, bot: Bot): + self.bot = bot + + self.bot.loop.create_task(self.startup_greeting()) + + async def startup_greeting(self) -> None: + """Announce our presence to the configured devlog channel.""" + await self.bot.wait_until_guild_available() + log.info("Bot connected!") + + embed = Embed(description="Connected!") + embed.set_author( + name="Python Bot", + url="https://github.com/python-discord/bot", + icon_url=( + "https://raw.githubusercontent.com/" + "python-discord/branding/master/logos/logo_circle/logo_circle_large.png" + ) + ) + + if not DEBUG_MODE: + await self.bot.get_channel(Channels.dev_log).send(embed=embed) + + +def setup(bot: Bot) -> None: + """Load the Logging cog.""" + bot.add_cog(Logging(bot)) diff --git a/bot/exts/backend/sync/__init__.py b/bot/exts/backend/sync/__init__.py new file mode 100644 index 000000000..829098f79 --- /dev/null +++ b/bot/exts/backend/sync/__init__.py @@ -0,0 +1,8 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: + """Load the Sync cog.""" + # Defer import to reduce side effects from importing the sync package. + from bot.exts.backend.sync._cog import Sync + bot.add_cog(Sync(bot)) diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py new file mode 100644 index 000000000..6e85e2b7d --- /dev/null +++ b/bot/exts/backend/sync/_cog.py @@ -0,0 +1,180 @@ +import logging +from typing import Any, Dict + +from discord import Member, Role, User +from discord.ext import commands +from discord.ext.commands import Cog, Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.exts.backend.sync import _syncers + +log = logging.getLogger(__name__) + + +class Sync(Cog): + """Captures relevant events and sends them to the site.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.role_syncer = _syncers.RoleSyncer(self.bot) + self.user_syncer = _syncers.UserSyncer(self.bot) + + self.bot.loop.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Syncs the roles/users of the guild with the database.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch(f"bot/users/{user_id}", json=json) + except ResponseCodeError as e: + if e.response.status != 404: + raise + if not ignore_404: + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + + @Cog.listener() + async def on_guild_role_create(self, role: Role) -> None: + """Adds newly create role to the database table over the API.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.post( + 'bot/roles', + json={ + 'colour': role.colour.value, + 'id': role.id, + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + } + ) + + @Cog.listener() + async def on_guild_role_delete(self, role: Role) -> None: + """Deletes role from the database when it's deleted from the guild.""" + if role.guild.id != constants.Guild.id: + return + + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + @Cog.listener() + async def on_guild_role_update(self, before: Role, after: Role) -> None: + """Syncs role with the database if any of the stored attributes were updated.""" + if after.guild.id != constants.Guild.id: + return + + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: + await self.bot.api_client.put( + f'bot/roles/{after.id}', + json={ + 'colour': after.colour.value, + 'id': after.id, + 'name': after.name, + 'permissions': after.permissions.value, + 'position': after.position, + } + ) + + @Cog.listener() + async def on_member_join(self, member: Member) -> None: + """ + Adds a new user or updates existing user to the database when a member joins the guild. + + If the joining member is a user that is already known to the database (i.e., a user that + previously left), it will update the user's information. If the user is not yet known by + the database, the user is added. + """ + if member.guild.id != constants.Guild.id: + return + + packed = { + 'discriminator': int(member.discriminator), + 'id': member.id, + 'in_guild': True, + 'name': member.name, + 'roles': sorted(role.id for role in member.roles) + } + + got_error = False + + try: + # First try an update of the user to set the `in_guild` field and other + # fields that may have changed since the last time we've seen them. + await self.bot.api_client.put(f'bot/users/{member.id}', json=packed) + + except ResponseCodeError as e: + # If we didn't get 404, something else broke - propagate it up. + if e.response.status != 404: + raise + + got_error = True # yikes + + if got_error: + # If we got `404`, the user is new. Create them. + await self.bot.api_client.post('bot/users', json=packed) + + @Cog.listener() + async def on_member_remove(self, member: Member) -> None: + """Set the in_guild field to False when a member leaves the guild.""" + if member.guild.id != constants.Guild.id: + return + + await self.patch_user(member.id, json={"in_guild": False}) + + @Cog.listener() + async def on_member_update(self, before: Member, after: Member) -> None: + """Update the roles of the member in the database if a change is detected.""" + if after.guild.id != constants.Guild.id: + return + + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, json=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + attrs = ("name", "discriminator") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + } + # A 404 likely means the user is in another guild. + await self.patch_user(after.id, json=updated_information, ignore_404=True) + + @commands.group(name='sync') + @commands.has_permissions(administrator=True) + async def sync_group(self, ctx: Context) -> None: + """Run synchronizations between the bot and site manually.""" + + @sync_group.command(name='roles') + @commands.has_permissions(administrator=True) + async def sync_roles_command(self, ctx: Context) -> None: + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) + + @sync_group.command(name='users') + @commands.has_permissions(administrator=True) + async def sync_users_command(self, ctx: Context) -> None: + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py new file mode 100644 index 000000000..f7ba811bc --- /dev/null +++ b/bot/exts/backend/sync/_syncers.py @@ -0,0 +1,347 @@ +import abc +import asyncio +import logging +import typing as t +from collections import namedtuple +from functools import partial + +import discord +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot + +log = logging.getLogger(__name__) + +# These objects are declared as namedtuples because tuples are hashable, +# something that we make use of when diffing site roles against guild roles. +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) + + +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' + ) + + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + allowed_roles = [discord.Object(constants.Roles.core_developers)] + message = await channel.send( + f"{self._CORE_DEV_MENTION}{msg_content}", + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) + ) + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS + ) + + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except asyncio.TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.info(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) |