aboutsummaryrefslogtreecommitdiffstats
path: root/bot/bot.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/bot.py')
-rw-r--r--bot/bot.py155
1 files changed, 127 insertions, 28 deletions
diff --git a/bot/bot.py b/bot/bot.py
index 8b389b6a..87575fde 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,21 +1,43 @@
+import asyncio
+import enum
import logging
import socket
-from traceback import format_exc
-from typing import List
+from typing import Optional, Union
+import async_timeout
+import discord
from aiohttp import AsyncResolver, ClientSession, TCPConnector
-from discord import DiscordException, Embed
+from discord import DiscordException, Embed, Guild, User
from discord.ext import commands
from bot.constants import Channels, Client
+from bot.utils.decorators import mock_in_debug
log = logging.getLogger(__name__)
-__all__ = ('SeasonalBot', 'bot')
+__all__ = ("AssetType", "SeasonalBot", "bot")
+
+
+class AssetType(enum.Enum):
+ """
+ Discord media assets.
+
+ The values match exactly the kwarg keys that can be passed to `Guild.edit` or `User.edit`.
+ """
+
+ BANNER = "banner"
+ AVATAR = "avatar"
+ SERVER_ICON = "icon"
class SeasonalBot(commands.Bot):
- """Base bot instance."""
+ """
+ Base bot instance.
+
+ While in debug mode, the asset upload methods (avatar, banner, ...) will not
+ perform the upload, and will instead only log the passed download urls and pretend
+ that the upload was successful. See the `mock_in_debug` decorator for further details.
+ """
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -23,22 +45,106 @@ class SeasonalBot(commands.Bot):
connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET)
)
- def load_extensions(self, exts: List[str]) -> None:
- """Unload all current extensions, then load the given extensions."""
- # Unload all cogs
- extensions = list(self.extensions.keys())
- for extension in extensions:
- if extension not in ["bot.seasons", "bot.help"]: # We shouldn't unload the manager and help.
- self.unload_extension(extension)
-
- # Load in the list of cogs that was passed in here
- for extension in exts:
- cog = extension.split(".")[-1]
- try:
- self.load_extension(extension)
- log.info(f'Successfully loaded extension: {cog}')
- except Exception as e:
- log.error(f'Failed to load extension {cog}: {repr(e)} {format_exc()}')
+ @property
+ def member(self) -> Optional[discord.Member]:
+ """Retrieves the guild member object for the bot."""
+ guild = self.get_guild(Client.guild)
+ if not guild:
+ return None
+ return guild.me
+
+ def add_cog(self, cog: commands.Cog) -> None:
+ """
+ Delegate to super to register `cog`.
+
+ This only serves to make the info log, so that extensions don't have to.
+ """
+ super().add_cog(cog)
+ log.info(f"Cog loaded: {cog.qualified_name}")
+
+ async def on_command_error(self, context: commands.Context, exception: DiscordException) -> None:
+ """Check command errors for UserInputError and reset the cooldown if thrown."""
+ if isinstance(exception, commands.UserInputError):
+ context.command.reset_cooldown(context)
+ else:
+ await super().on_command_error(context, exception)
+
+ async def _fetch_image(self, url: str) -> bytes:
+ """Retrieve and read image from `url`."""
+ log.debug(f"Getting image from: {url}")
+ async with self.http_session.get(url) as resp:
+ return await resp.read()
+
+ async def _apply_asset(self, target: Union[Guild, User], asset: AssetType, url: str) -> bool:
+ """
+ Internal method for applying media assets to the guild or the bot.
+
+ This shouldn't be called directly. The purpose of this method is mainly generic
+ error handling to reduce needless code repetition.
+
+ Return True if upload was successful, False otherwise.
+ """
+ log.info(f"Attempting to set {asset.name}: {url}")
+
+ kwargs = {asset.value: await self._fetch_image(url)}
+ try:
+ async with async_timeout.timeout(5):
+ await target.edit(**kwargs)
+
+ except asyncio.TimeoutError:
+ log.info("Asset upload timed out")
+ return False
+
+ except discord.HTTPException as discord_error:
+ log.exception("Asset upload failed", exc_info=discord_error)
+ return False
+
+ else:
+ log.info(f"Asset successfully applied")
+ return True
+
+ @mock_in_debug(return_value=True)
+ async def set_banner(self, url: str) -> bool:
+ """Set the guild's banner to image at `url`."""
+ guild = self.get_guild(Client.guild)
+ if guild is None:
+ log.info("Failed to get guild instance, aborting asset upload")
+ return False
+
+ return await self._apply_asset(guild, AssetType.BANNER, url)
+
+ @mock_in_debug(return_value=True)
+ async def set_icon(self, url: str) -> bool:
+ """Sets the guild's icon to image at `url`."""
+ guild = self.get_guild(Client.guild)
+ if guild is None:
+ log.info("Failed to get guild instance, aborting asset upload")
+ return False
+
+ return await self._apply_asset(guild, AssetType.SERVER_ICON, url)
+
+ @mock_in_debug(return_value=True)
+ async def set_avatar(self, url: str) -> bool:
+ """Set the bot's avatar to image at `url`."""
+ return await self._apply_asset(self.user, AssetType.AVATAR, url)
+
+ @mock_in_debug(return_value=True)
+ async def set_nickname(self, new_name: str) -> bool:
+ """Set the bot nickname in the main guild to `new_name`."""
+ member = self.member
+ if member is None:
+ log.info("Failed to get bot member instance, aborting asset upload")
+ return False
+
+ log.info(f"Attempting to set nickname to {new_name}")
+ try:
+ await member.edit(nick=new_name)
+ except discord.HTTPException as discord_error:
+ log.exception("Setting nickname failed", exc_info=discord_error)
+ return False
+ else:
+ log.info("Nickname set successfully")
+ return True
async def send_log(self, title: str, details: str = None, *, icon: str = None) -> None:
"""Send an embed message to the devlog channel."""
@@ -56,12 +162,5 @@ class SeasonalBot(commands.Bot):
await devlog.send(embed=embed)
- async def on_command_error(self, context: commands.Context, exception: DiscordException) -> None:
- """Check command errors for UserInputError and reset the cooldown if thrown."""
- if isinstance(exception, commands.UserInputError):
- context.command.reset_cooldown(context)
- else:
- await super().on_command_error(context, exception)
-
bot = SeasonalBot(command_prefix=Client.prefix)