1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
|
import asyncio
import socket
import types
import warnings
from contextlib import suppress
from typing import Optional
import aiohttp
import discord
from discord.ext import commands
from botcore.async_stats import AsyncStatsClient
from botcore.site_api import APIClient
from botcore.utils import scheduling
from botcore.utils._extensions import walk_extensions
from botcore.utils.logging import get_logger
# async-rediscache is an optional dependency: record whether it could be imported
# so that the rest of the module can degrade gracefully when it is missing.
try:
    from async_rediscache import RedisSession
    REDIS_AVAILABLE = True
except ImportError:
    # Keep the name defined so annotations that reference it still evaluate.
    RedisSession = None
    REDIS_AVAILABLE = False

# Module-level logger obtained from botcore's logging helper.
log = get_logger()
class StartupError(Exception):
    """
    Signal that something went wrong while the bot was starting up.

    The error that triggered the failure is kept on the ``exception`` attribute so
    that whoever catches this wrapper can inspect or log the real cause.
    """

    def __init__(self, base: Exception):
        # Remember the underlying failure first; the wrapper itself carries no message.
        self.exception = base
        super().__init__()
class BotBase(commands.Bot):
    """A sub-class that implements many common features that Python Discord bots use."""

    def __init__(
        self,
        *args,
        guild_id: int,
        allowed_roles: list,
        http_session: aiohttp.ClientSession,
        redis_session: Optional[RedisSession] = None,
        api_client: Optional[APIClient] = None,
        statsd_url: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialise the base bot instance.

        Args:
            guild_id: The ID of the guild to use for :func:`wait_until_guild_available`.
            allowed_roles: A list of role IDs that the bot is allowed to mention.
            http_session (aiohttp.ClientSession): The session to use for the bot.
            redis_session: The `async_rediscache.RedisSession`_ to use for the bot.
            api_client: The :obj:`botcore.site_api.APIClient` instance to use for the bot.
            statsd_url: The URL of the statsd server to use for the bot. If not given,
                a dummy statsd client will be created.

        .. _async_rediscache.RedisSession: https://github.com/SebastiaanZ/async-rediscache#creating-a-redissession
        """
        # NOTE(review): `allowed_roles` is forwarded verbatim as a keyword option;
        # confirm the parent constructor actually consumes it.
        super().__init__(
            *args,
            allowed_roles=allowed_roles,
            **kwargs,
        )

        self.guild_id = guild_id
        self.http_session = http_session
        self.api_client = api_client
        self.statsd_url = statsd_url

        # `redis_session` is only set as an attribute when a usable session was given;
        # the rest of the class checks for it with `getattr(self, "redis_session", False)`.
        if redis_session and not REDIS_AVAILABLE:
            warnings.warn("redis_session kwarg passed, but async-rediscache not installed!")
        elif redis_session:
            self.redis_session = redis_session

        # Everything below is populated in `setup_hook`, once an event loop is running.
        self._resolver: Optional[aiohttp.AsyncResolver] = None
        self._connector: Optional[aiohttp.TCPConnector] = None

        self._statsd_timerhandle: Optional[asyncio.TimerHandle] = None
        self._guild_available: Optional[asyncio.Event] = None

        self.stats: Optional[AsyncStatsClient] = None

        self.all_extensions: Optional[frozenset[str]] = None

    def _connect_statsd(
        self,
        statsd_url: str,
        loop: asyncio.AbstractEventLoop,
        retry_after: int = 2,
        attempt: int = 1
    ) -> None:
        """
        Callback used to retry a connection to statsd if it should fail.

        Args:
            statsd_url: Host of the statsd server to connect to.
            loop: The event loop used to build the client and to schedule retries.
            retry_after: Seconds to wait before the next attempt; doubled on every retry.
            attempt: The 1-based number of the current attempt.
        """
        if attempt >= 8:
            log.error(
                "Reached 8 attempts trying to reconnect AsyncStatsClient to %s. "
                "Aborting and leaving the dummy statsd client in place.",
                statsd_url,
            )
            return

        try:
            # NOTE(review): `setup_hook` only calls `create_socket()` on whichever client
            # exists right after the first attempt; a client created by a later retry
            # does not get its socket created anywhere in this class - confirm intended.
            self.stats = AsyncStatsClient(loop, statsd_url, 8125, prefix="bot")
        except socket.gaierror:
            log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})")
            # Use an exponential backoff strategy for retrying, up to 8 times.
            # `loop` must be forwarded explicitly: the positional arguments given to
            # `call_later` are passed straight through to `_connect_statsd`.
            self._statsd_timerhandle = loop.call_later(
                retry_after,
                self._connect_statsd,
                statsd_url,
                loop,
                retry_after * 2,
                attempt + 1
            )

    async def load_extensions(self, module: types.ModuleType) -> None:
        """
        Load all the extensions within the given module and save them to ``self.all_extensions``.

        This should be ran in a task on the event loop to avoid deadlocks caused by ``wait_for`` calls.
        """
        await self.wait_until_guild_available()
        self.all_extensions = walk_extensions(module)

        # Each extension is loaded in its own task rather than awaited sequentially.
        for extension in self.all_extensions:
            scheduling.create_task(self.load_extension(extension))

    def _add_root_aliases(self, command: commands.Command) -> None:
        """
        Recursively add root aliases for ``command`` and any of its subcommands.

        Raises:
            commands.CommandRegistrationError: If an alias is already registered on the bot.
        """
        if isinstance(command, commands.Group):
            for subcommand in command.commands:
                self._add_root_aliases(subcommand)

        # Commands without a `root_aliases` attribute simply contribute nothing.
        for alias in getattr(command, "root_aliases", ()):
            if alias in self.all_commands:
                raise commands.CommandRegistrationError(alias, alias_conflict=True)

            self.all_commands[alias] = command

    def _remove_root_aliases(self, command: commands.Command) -> None:
        """Recursively remove root aliases for ``command`` and any of its subcommands."""
        if isinstance(command, commands.Group):
            for subcommand in command.commands:
                self._remove_root_aliases(subcommand)

        for alias in getattr(command, "root_aliases", ()):
            # `pop` with a default so an already-missing alias is not an error.
            self.all_commands.pop(alias, None)

    async def add_cog(self, cog: commands.Cog) -> None:
        """Add the given ``cog`` to the bot and log the operation."""
        await super().add_cog(cog)
        log.info(f"Cog loaded: {cog.qualified_name}")

    def add_command(self, command: commands.Command) -> None:
        """Add ``command`` as normal and then add its root aliases to the bot."""
        super().add_command(command)
        self._add_root_aliases(command)

    def remove_command(self, name: str) -> Optional[commands.Command]:
        """
        Remove a command/alias as normal and then remove its root aliases from the bot.

        Individual root aliases cannot be removed by this function.
        To remove them, either remove the entire command or manually edit `bot.all_commands`.

        Returns:
            The removed command, or ``None`` if no command was registered under ``name``.
        """
        command = super().remove_command(name)
        if command is None:
            # Even if it's a root alias, there's no way to get the Bot instance to remove the alias.
            return None

        self._remove_root_aliases(command)
        return command

    def clear(self) -> None:
        """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one."""
        raise NotImplementedError("Re-using a Bot object after closing it is not supported.")

    async def on_guild_unavailable(self, guild: discord.Guild) -> None:
        """Clear the internal guild available event when self.guild_id becomes unavailable."""
        if guild.id != self.guild_id:
            return

        self._guild_available.clear()

    async def on_guild_available(self, guild: discord.Guild) -> None:
        """
        Set the internal guild available event when self.guild_id becomes available.

        If the cache appears to still be empty (no members, no channels, or no roles), the event
        will not be set and the problem is reported through :func:`log_to_dev_log` instead.
        """
        if guild.id != self.guild_id:
            return

        if not guild.roles or not guild.members or not guild.channels:
            msg = "Guild available event was dispatched but the cache appears to still be empty!"
            # `log_to_dev_log` is a coroutine function; it must be awaited to actually run.
            await self.log_to_dev_log(msg)
            return

        self._guild_available.set()

    async def log_to_dev_log(self, message: str) -> None:
        """Log the given message to #dev-log."""
        ...

    async def wait_until_guild_available(self) -> None:
        """
        Wait until the guild that matches the ``guild_id`` given at init is available (and the cache is ready).

        The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE
        gateway event before giving up and thus not populating the cache for unavailable guilds.
        """
        await self._guild_available.wait()

    async def setup_hook(self) -> None:
        """
        An async init to startup generic services.

        Connects to statsd, and calls
        :func:`AsyncStatsClient.create_socket <botcore.async_stats.AsyncStatsClient.create_socket>`
        and :func:`ping_services`.

        Raises:
            StartupError: If :func:`ping_services` raises; the original error is wrapped.
        """
        loop = asyncio.get_running_loop()

        self._guild_available = asyncio.Event()

        self._resolver = aiohttp.AsyncResolver()
        self._connector = aiohttp.TCPConnector(
            resolver=self._resolver,
            family=socket.AF_INET,
        )
        self.http.connector = self._connector

        if getattr(self, "redis_session", False) and self.redis_session.closed:
            # If the RedisSession was somehow closed, we try to reconnect it
            # here. Normally, this shouldn't happen.
            await self.redis_session.connect()

        # Create dummy stats client first, in case `statsd_url` is unreachable within `_connect_statsd()`
        self.stats = AsyncStatsClient(loop, "127.0.0.1")
        self._connect_statsd(self.statsd_url, loop)
        await self.stats.create_socket()

        try:
            await self.ping_services()
        except Exception as e:
            # Chain explicitly so the traceback shows the real cause of the failed startup.
            raise StartupError(e) from e

    async def ping_services(self) -> None:
        """Ping all required services on setup to ensure they are up before starting."""
        ...

    async def close(self) -> None:
        """Close the Discord connection, and the aiohttp session, connector, statsd client, and resolver."""
        # Done before super().close() to allow tasks finish before the HTTP session closes.
        for ext in list(self.extensions):
            with suppress(Exception):
                await self.unload_extension(ext)

        for cog in list(self.cogs):
            with suppress(Exception):
                await self.remove_cog(cog)

        # Now actually do full close of bot
        await super().close()

        if self.api_client:
            await self.api_client.close()

        if self.http_session:
            await self.http_session.close()

        if self._connector:
            await self._connector.close()

        if self._resolver:
            await self._resolver.close()

        # `self.stats` is still None when the bot is closed before `setup_hook` has run.
        if self.stats is not None and self.stats._transport:
            self.stats._transport.close()

        if getattr(self, "redis_session", False):
            await self.redis_session.close()

        if self._statsd_timerhandle:
            self._statsd_timerhandle.cancel()
|