diff options
| author | 2020-10-17 19:57:29 -0700 | |
|---|---|---|
| committer | 2020-10-17 20:15:17 -0700 | |
| commit | 9676866990523266d39fc26c4fe6bfa28a8ca9e4 (patch) | |
| tree | c056180633d17dbe16d793c5d19837388c1a03cd | |
| parent | Move logging set up to a separate module (diff) | |
Move bot creation code from __main__.py to bot.py
| -rw-r--r-- | bot/__main__.py | 55 | ||||
| -rw-r--r-- | bot/bot.py | 56 | 
2 files changed, 58 insertions, 53 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index f3204c18a..9847c1849 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,58 +1,7 @@ -import asyncio - -import discord -from async_rediscache import RedisSession -from discord.ext.commands import when_mentioned_or -  import bot  from bot import constants  from bot.bot import Bot -from bot.utils.extensions import EXTENSIONS - - -# Create the redis session instance. -redis_session = RedisSession( -    address=(constants.Redis.host, constants.Redis.port), -    password=constants.Redis.password, -    minsize=1, -    maxsize=20, -    use_fakeredis=constants.Redis.use_fakeredis, -    global_namespace="bot", -) - -# Connect redis session to ensure it's connected before we try to access Redis -# from somewhere within the bot. We create the event loop in the same way -# discord.py normally does and pass it to the bot's __init__. -loop = asyncio.get_event_loop() -loop.run_until_complete(redis_session.connect()) - - -# Instantiate the bot. -allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] -intents = discord.Intents().all() -intents.presences = False -intents.dm_typing = False -intents.dm_reactions = False -intents.invites = False -intents.webhooks = False -intents.integrations = False -bot.instance = Bot( -    redis_session=redis_session, -    loop=loop, -    command_prefix=when_mentioned_or(constants.Bot.prefix), -    activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), -    case_insensitive=True, -    max_messages=10_000, -    allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), -    intents=intents, -) - -# Load extensions. -extensions = set(EXTENSIONS)  # Create a mutable copy. -if not constants.HelpChannels.enable: -    extensions.remove("bot.exts.help_channels") - -for extension in extensions: -    bot.instance.load_extension(extension) +bot.instance = Bot.create() +bot.instance.load_extensions()  bot.instance.run(constants.Bot.token) diff --git a/bot/bot.py b/bot/bot.py index 892bb3325..36cf7d30a 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -95,6 +95,43 @@ class Bot(commands.Bot):          # Build the FilterList cache          self.loop.create_task(self.cache_filter_list_data()) +    @classmethod +    def create(cls) -> "Bot": +        """Create and return an instance of a Bot.""" +        loop = asyncio.get_event_loop() +        allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] + +        intents = discord.Intents().all() +        intents.presences = False +        intents.dm_typing = False +        intents.dm_reactions = False +        intents.invites = False +        intents.webhooks = False +        intents.integrations = False + +        return cls( +            redis_session=_create_redis_session(loop), +            loop=loop, +            command_prefix=commands.when_mentioned_or(constants.Bot.prefix), +            activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), +            case_insensitive=True, +            max_messages=10_000, +            allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), +            intents=intents, +        ) + +    def load_extensions(self) -> None: +        """Load all enabled extensions.""" +        # Must be done here to avoid a circular import. +        from bot.utils.extensions import EXTENSIONS + +        extensions = set(EXTENSIONS)  # Create a mutable copy. +        if not constants.HelpChannels.enable: +            extensions.remove("bot.exts.help_channels") + +        for extension in extensions: +            self.load_extension(extension) +      def add_cog(self, cog: commands.Cog) -> None:          """Adds a "cog" to the bot and logs the operation."""          super().add_cog(cog) @@ -243,3 +280,22 @@ class Bot(commands.Bot):          for alias in getattr(command, "root_aliases", ()):              self.all_commands.pop(alias, None) + + +def _create_redis_session(loop: asyncio.AbstractEventLoop) -> RedisSession: +    """ +    Create and connect to a redis session. + +    Ensure the connection is established before returning to prevent race conditions. +    `loop` is the event loop on which to connect. The Bot should use this same event loop. +    """ +    redis_session = RedisSession( +        address=(constants.Redis.host, constants.Redis.port), +        password=constants.Redis.password, +        minsize=1, +        maxsize=20, +        use_fakeredis=constants.Redis.use_fakeredis, +        global_namespace="bot", +    ) +    loop.run_until_complete(redis_session.connect()) +    return redis_session | 
