aboutsummaryrefslogtreecommitdiffstats
path: root/bot/bot.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/bot.py')
-rw-r--r--bot/bot.py148
1 files changed, 128 insertions, 20 deletions
diff --git a/bot/bot.py b/bot/bot.py
index 8b389b6a..47d63de9 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,13 +1,17 @@
+import asyncio
+import contextlib
import logging
import socket
-from traceback import format_exc
-from typing import List
+from typing import Optional
+import async_timeout
+import discord
from aiohttp import AsyncResolver, ClientSession, TCPConnector
from discord import DiscordException, Embed
from discord.ext import commands
from bot.constants import Channels, Client
+from bot.utils.decorators import mock_in_debug
log = logging.getLogger(__name__)
@@ -15,7 +19,13 @@ __all__ = ('SeasonalBot', 'bot')
class SeasonalBot(commands.Bot):
- """Base bot instance."""
+ """
+ Base bot instance.
+
+ While in debug mode, the asset upload methods (avatar, banner, ...) will not
+ perform the upload, and will instead only log the passed download urls and pretend
+ that the upload was successful. See the `mock_in_debug` decorator for further details.
+ """
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -23,23 +33,6 @@ class SeasonalBot(commands.Bot):
connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET)
)
- def load_extensions(self, exts: List[str]) -> None:
- """Unload all current extensions, then load the given extensions."""
- # Unload all cogs
- extensions = list(self.extensions.keys())
- for extension in extensions:
- if extension not in ["bot.seasons", "bot.help"]: # We shouldn't unload the manager and help.
- self.unload_extension(extension)
-
- # Load in the list of cogs that was passed in here
- for extension in exts:
- cog = extension.split(".")[-1]
- try:
- self.load_extension(extension)
- log.info(f'Successfully loaded extension: {cog}')
- except Exception as e:
- log.error(f'Failed to load extension {cog}: {repr(e)} {format_exc()}')
-
async def send_log(self, title: str, details: str = None, *, icon: str = None) -> None:
"""Send an embed message to the devlog channel."""
devlog = self.get_channel(Channels.devlog)
@@ -63,5 +56,120 @@ class SeasonalBot(commands.Bot):
else:
await super().on_command_error(context, exception)
+ @property
+ def member(self) -> Optional[discord.Member]:
+ """Retrieves the guild member object for the bot."""
+ guild = bot.get_guild(Client.guild)
+ if not guild:
+ return None
+ return guild.me
+
+ @mock_in_debug(return_value=True)
+ async def set_avatar(self, url: str) -> bool:
+ """Sets the bot's avatar based on a URL."""
+ # Track old avatar hash for later comparison
+ old_avatar = bot.user.avatar
+
+ image = await self._fetch_image(url)
+ with contextlib.suppress(discord.HTTPException, asyncio.TimeoutError):
+ async with async_timeout.timeout(5):
+ await bot.user.edit(avatar=image)
+
+ if bot.user.avatar != old_avatar:
+ log.debug(f"Avatar changed to {url}")
+ return True
+
+ log.warning(f"Changing avatar failed: {url}")
+ return False
+
+ @mock_in_debug(return_value=True)
+ async def set_banner(self, url: str) -> bool:
+ """Sets the guild's banner based on the provided `url`."""
+ guild = bot.get_guild(Client.guild)
+ old_banner = guild.banner
+
+ image = await self._fetch_image(url)
+ with contextlib.suppress(discord.HTTPException, asyncio.TimeoutError):
+ async with async_timeout.timeout(5):
+ await guild.edit(banner=image)
+
+ new_banner = bot.get_guild(Client.guild).banner
+ if new_banner != old_banner:
+ log.debug(f"Banner changed to {url}")
+ return True
+
+ log.warning(f"Changing banner failed: {url}")
+ return False
+
+ @mock_in_debug(return_value=True)
+ async def set_icon(self, url: str) -> bool:
+ """Sets the guild's icon based on a URL."""
+ guild = bot.get_guild(Client.guild)
+ # Track old icon hash for later comparison
+ old_icon = guild.icon
+
+ image = await self._fetch_image(url)
+ with contextlib.suppress(discord.HTTPException, asyncio.TimeoutError):
+ async with async_timeout.timeout(5):
+ await guild.edit(icon=image)
+
+ new_icon = bot.get_guild(Client.guild).icon
+ if new_icon != old_icon:
+ log.debug(f"Icon changed to {url}")
+ return True
+
+ log.warning(f"Changing icon failed: {url}")
+ return False
+
+ async def _fetch_image(self, url: str) -> bytes:
+ """Retrieve an image based on a URL."""
+ log.debug(f"Getting image from: {url}")
+ async with self.http_session.get(url) as resp:
+ return await resp.read()
+
+ @mock_in_debug(return_value=True)
+ async def set_username(self, new_name: str, nick_only: bool = False) -> Optional[bool]:
+ """
+ Set the bot username and/or nickname to given new name.
+
+ Returns True/False based on success, or None if nickname fallback also failed.
+ """
+ old_username = self.user.name
+
+ if nick_only:
+ return await self.set_nickname(new_name)
+
+ if old_username == new_name:
+ # since the username is correct, make sure nickname is removed
+ return await self.set_nickname()
+
+ log.debug(f"Changing username to {new_name}")
+ with contextlib.suppress(discord.HTTPException):
+ await bot.user.edit(username=new_name, nick=None)
+
+ if not new_name == self.member.display_name:
+ # name didn't change, try to changing nickname as fallback
+ if await self.set_nickname(new_name):
+ log.warning(f"Changing username failed, changed nickname instead.")
+ return False
+ log.warning(f"Changing username and nickname failed.")
+ return None
+
+ return True
+
+ @mock_in_debug(return_value=True)
+ async def set_nickname(self, new_name: str = None) -> bool:
+ """Set the bot nickname in the main guild."""
+ old_display_name = self.member.display_name
+
+ if old_display_name == new_name:
+ return False
+
+ log.debug(f"Changing nickname to {new_name}")
+ with contextlib.suppress(discord.HTTPException):
+ await self.member.edit(nick=new_name)
+
+ return not old_display_name == self.member.display_name
+
bot = SeasonalBot(command_prefix=Client.prefix)