diff options
Diffstat (limited to 'bot/bot.py')
-rw-r--r-- | bot/bot.py | 155 |
1 files changed, 127 insertions, 28 deletions
@@ -1,21 +1,43 @@ +import asyncio +import enum import logging import socket -from traceback import format_exc -from typing import List +from typing import Optional, Union +import async_timeout +import discord from aiohttp import AsyncResolver, ClientSession, TCPConnector -from discord import DiscordException, Embed +from discord import DiscordException, Embed, Guild, User from discord.ext import commands from bot.constants import Channels, Client +from bot.utils.decorators import mock_in_debug log = logging.getLogger(__name__) -__all__ = ('SeasonalBot', 'bot') +__all__ = ("AssetType", "SeasonalBot", "bot") + + +class AssetType(enum.Enum): + """ + Discord media assets. + + The values match exactly the kwarg keys that can be passed to `Guild.edit` or `User.edit`. + """ + + BANNER = "banner" + AVATAR = "avatar" + SERVER_ICON = "icon" class SeasonalBot(commands.Bot): - """Base bot instance.""" + """ + Base bot instance. + + While in debug mode, the asset upload methods (avatar, banner, ...) will not + perform the upload, and will instead only log the passed download urls and pretend + that the upload was successful. See the `mock_in_debug` decorator for further details. + """ def __init__(self, **kwargs): super().__init__(**kwargs) @@ -23,22 +45,106 @@ class SeasonalBot(commands.Bot): connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET) ) - def load_extensions(self, exts: List[str]) -> None: - """Unload all current extensions, then load the given extensions.""" - # Unload all cogs - extensions = list(self.extensions.keys()) - for extension in extensions: - if extension not in ["bot.seasons", "bot.help"]: # We shouldn't unload the manager and help. - self.unload_extension(extension) - - # Load in the list of cogs that was passed in here - for extension in exts: - cog = extension.split(".")[-1] - try: - self.load_extension(extension) - log.info(f'Successfully loaded extension: {cog}') - except Exception as e: - log.error(f'Failed to load extension {cog}: {repr(e)} {format_exc()}') + @property + def member(self) -> Optional[discord.Member]: + """Retrieves the guild member object for the bot.""" + guild = self.get_guild(Client.guild) + if not guild: + return None + return guild.me + + def add_cog(self, cog: commands.Cog) -> None: + """ + Delegate to super to register `cog`. + + This only serves to make the info log, so that extensions don't have to. + """ + super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + async def on_command_error(self, context: commands.Context, exception: DiscordException) -> None: + """Check command errors for UserInputError and reset the cooldown if thrown.""" + if isinstance(exception, commands.UserInputError): + context.command.reset_cooldown(context) + else: + await super().on_command_error(context, exception) + + async def _fetch_image(self, url: str) -> bytes: + """Retrieve and read image from `url`.""" + log.debug(f"Getting image from: {url}") + async with self.http_session.get(url) as resp: + return await resp.read() + + async def _apply_asset(self, target: Union[Guild, User], asset: AssetType, url: str) -> bool: + """ + Internal method for applying media assets to the guild or the bot. + + This shouldn't be called directly. The purpose of this method is mainly generic + error handling to reduce needless code repetition. + + Return True if upload was successful, False otherwise. + """ + log.info(f"Attempting to set {asset.name}: {url}") + + kwargs = {asset.value: await self._fetch_image(url)} + try: + async with async_timeout.timeout(5): + await target.edit(**kwargs) + + except asyncio.TimeoutError: + log.info("Asset upload timed out") + return False + + except discord.HTTPException as discord_error: + log.exception("Asset upload failed", exc_info=discord_error) + return False + + else: + log.info(f"Asset successfully applied") + return True + + @mock_in_debug(return_value=True) + async def set_banner(self, url: str) -> bool: + """Set the guild's banner to image at `url`.""" + guild = self.get_guild(Client.guild) + if guild is None: + log.info("Failed to get guild instance, aborting asset upload") + return False + + return await self._apply_asset(guild, AssetType.BANNER, url) + + @mock_in_debug(return_value=True) + async def set_icon(self, url: str) -> bool: + """Sets the guild's icon to image at `url`.""" + guild = self.get_guild(Client.guild) + if guild is None: + log.info("Failed to get guild instance, aborting asset upload") + return False + + return await self._apply_asset(guild, AssetType.SERVER_ICON, url) + + @mock_in_debug(return_value=True) + async def set_avatar(self, url: str) -> bool: + """Set the bot's avatar to image at `url`.""" + return await self._apply_asset(self.user, AssetType.AVATAR, url) + + @mock_in_debug(return_value=True) + async def set_nickname(self, new_name: str) -> bool: + """Set the bot nickname in the main guild to `new_name`.""" + member = self.member + if member is None: + log.info("Failed to get bot member instance, aborting asset upload") + return False + + log.info(f"Attempting to set nickname to {new_name}") + try: + await member.edit(nick=new_name) + except discord.HTTPException as discord_error: + log.exception("Setting nickname failed", exc_info=discord_error) + return False + else: + log.info("Nickname set successfully") + return True async def send_log(self, title: str, details: str = None, *, icon: str = None) -> None: """Send an embed message to the devlog channel.""" @@ -56,12 +162,5 @@ class SeasonalBot(commands.Bot): await devlog.send(embed=embed) - async def on_command_error(self, context: commands.Context, exception: DiscordException) -> None: - """Check command errors for UserInputError and reset the cooldown if thrown.""" - if isinstance(exception, commands.UserInputError): - context.command.reset_cooldown(context) - else: - await super().on_command_error(context, exception) - bot = SeasonalBot(command_prefix=Client.prefix) |