diff options
-rw-r--r-- | bot/__main__.py | 5 | ||||
-rw-r--r-- | bot/bot.py | 9 | ||||
-rw-r--r-- | bot/constants.py | 62 | ||||
-rw-r--r-- | bot/decorators.py | 93 | ||||
-rw-r--r-- | bot/seasons/christmas/adventofcode.py | 2 | ||||
-rw-r--r-- | bot/seasons/easter/egg_hunt/cog.py | 3 | ||||
-rw-r--r-- | bot/seasons/easter/egg_hunt/constants.py | 3 | ||||
-rw-r--r-- | bot/seasons/evergreen/error_handler.py | 15 | ||||
-rw-r--r-- | bot/seasons/evergreen/issues.py | 2 | ||||
-rw-r--r-- | bot/seasons/season.py | 3 |
10 files changed, 166 insertions, 31 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index a3b68ec1..9dc0b173 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,8 +1,11 @@ import logging -from bot.constants import Client, bot +from bot.bot import bot +from bot.constants import Client, STAFF_ROLES, WHITELISTED_CHANNELS +from bot.decorators import in_channel_check log = logging.getLogger(__name__) +bot.add_check(in_channel_check(*WHITELISTED_CHANNELS, bypass_roles=STAFF_ROLES)) bot.load_extension("bot.seasons") bot.run(Client.token) @@ -7,11 +7,11 @@ from aiohttp import AsyncResolver, ClientSession, TCPConnector from discord import Embed from discord.ext import commands -from bot import constants +from bot.constants import Channels, Client log = logging.getLogger(__name__) -__all__ = ('SeasonalBot',) +__all__ = ('SeasonalBot', 'bot') class SeasonalBot(commands.Bot): @@ -42,7 +42,7 @@ class SeasonalBot(commands.Bot): async def send_log(self, title: str, details: str = None, *, icon: str = None): """Send an embed message to the devlog channel.""" - devlog = self.get_channel(constants.Channels.devlog) + devlog = self.get_channel(Channels.devlog) if not devlog: log.warning("Log failed to send. Devlog channel not found.") @@ -62,3 +62,6 @@ class SeasonalBot(commands.Bot): context.command.reset_cooldown(context) else: await super().on_command_error(context, exception) + + +bot = SeasonalBot(command_prefix=Client.prefix) diff --git a/bot/constants.py b/bot/constants.py index 8902d918..dbf35754 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -2,11 +2,10 @@ import logging from os import environ from typing import NamedTuple -from bot.bot import SeasonalBot - __all__ = ( - "AdventOfCode", "Channels", "Client", "Colours", "Emojis", "Hacktoberfest", "Roles", - "Tokens", "ERROR_REPLIES", "bot" + "AdventOfCode", "Channels", "Client", "Colours", "Emojis", "Hacktoberfest", "Roles", "Tokens", + "WHITELISTED_CHANNELS", "STAFF_ROLES", "MODERATION_ROLES", + "POSITIVE_REPLIES", "NEGATIVE_REPLIES", "ERROR_REPLIES", ) log = logging.getLogger(__name__) @@ -118,6 +117,58 @@ class Tokens(NamedTuple): youtube = environ.get("YOUTUBE_API_KEY") +# Default role combinations +MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner +STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner + +# Whitelisted channels +WHITELISTED_CHANNELS = ( + Channels.bot, Channels.seasonalbot_commands, + Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2, + Channels.devtest, +) + +# Bot replies +NEGATIVE_REPLIES = [ + "Noooooo!!", + "Nope.", + "I'm sorry Dave, I'm afraid I can't do that.", + "I don't think so.", + "Not gonna happen.", + "Out of the question.", + "Huh? No.", + "Nah.", + "Naw.", + "Not likely.", + "No way, José.", + "Not in a million years.", + "Fat chance.", + "Certainly not.", + "NEGATORY.", + "Nuh-uh.", + "Not in my house!", +] + +POSITIVE_REPLIES = [ + "Yep.", + "Absolutely!", + "Can do!", + "Affirmative!", + "Yeah okay.", + "Sure.", + "Sure thing!", + "You're the boss!", + "Okay.", + "No problem.", + "I got you.", + "Alright.", + "You got it!", + "ROGER THAT", + "Of course!", + "Aye aye, cap'n!", + "I'll allow it.", +] + ERROR_REPLIES = [ "Please don't do that.", "You have to stop.", @@ -130,6 +181,3 @@ ERROR_REPLIES = [ "Noooooo!!", "I can't believe you've done this", ] - - -bot = SeasonalBot(command_prefix=Client.prefix) diff --git a/bot/decorators.py b/bot/decorators.py index dfe80e5c..02cf4b8a 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,24 +1,33 @@ import logging import random +import typing from asyncio import Lock from functools import wraps from weakref import WeakValueDictionary from discord import Colour, Embed from discord.ext import commands -from discord.ext.commands import Context +from discord.ext.commands import CheckFailure, Context from bot.constants import ERROR_REPLIES log = logging.getLogger(__name__) +class InChannelCheckFailure(CheckFailure): + """Check failure when the user runs a command in a non-whitelisted channel.""" + + pass + + def with_role(*role_ids: int): """Check to see whether the invoking user has any of the roles specified in role_ids.""" async def predicate(ctx: Context): if not ctx.guild: # Return False in a DM - log.debug(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " - "This command is restricted by the with_role decorator. Rejecting request.") + log.debug( + f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " + "This command is restricted by the with_role decorator. Rejecting request." + ) return False for role in ctx.author.roles: @@ -26,8 +35,10 @@ def with_role(*role_ids: int): log.debug(f"{ctx.author} has the '{role.name}' role, and passes the check.") return True - log.debug(f"{ctx.author} does not have the required role to use " - f"the '{ctx.command.name}' command, so the request is rejected.") + log.debug( + f"{ctx.author} does not have the required role to use " + f"the '{ctx.command.name}' command, so the request is rejected." + ) return False return commands.check(predicate) @@ -36,26 +47,74 @@ def without_role(*role_ids: int): """Check whether the invoking user does not have all of the roles specified in role_ids.""" async def predicate(ctx: Context): if not ctx.guild: # Return False in a DM - log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " - "This command is restricted by the without_role decorator. Rejecting request.") + log.debug( + f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " + "This command is restricted by the without_role decorator. Rejecting request." + ) return False author_roles = [role.id for role in ctx.author.roles] check = all(role not in author_roles for role in role_ids) - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The result of the without_role check was {check}.") + log.debug( + f"{ctx.author} tried to call the '{ctx.command.name}' command. " + f"The result of the without_role check was {check}." + ) return check return commands.check(predicate) -def in_channel(channel_id): - """Check that the command invocation is in the channel specified by channel_id.""" - async def predicate(ctx: Context): - check = ctx.channel.id == channel_id - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The result of the in_channel check was {check}.") - return check - return commands.check(predicate) +def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]: + """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" + def predicate(ctx: Context) -> bool: + if not ctx.guild: + log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") + return True + + if ctx.channel.id in channels: + log.debug( + f"{ctx.author} tried to call the '{ctx.command.name}' command " + f"and the command was used in a whitelisted channel." + ) + return True + + if hasattr(ctx.command.callback, "in_channel_override"): + log.debug( + f"{ctx.author} called the '{ctx.command.name}' command " + f"and the command was whitelisted to bypass the in_channel check." + ) + return True + + if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles): + log.debug( + f"{ctx.author} called the '{ctx.command.name}' command and " + f"had a role to bypass the in_channel check." + ) + return True + + log.debug( + f"{ctx.author} tried to call the '{ctx.command.name}' command. " + f"The in_channel check failed." + ) + + channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) + raise InChannelCheckFailure( + f"Sorry, but you may only use this command within {channels_str}." + ) + + return predicate + + +in_channel = commands.check(in_channel_check) + + +def override_in_channel(func: typing.Callable) -> typing.Callable: + """ + Set command callback attribute for detection in `in_channel_check`. + + This decorator has to go before (below) below the `command` decorator. + """ + func.in_channel_override = True + return func def locked(): diff --git a/bot/seasons/christmas/adventofcode.py b/bot/seasons/christmas/adventofcode.py index 08b07e83..a9e72805 100644 --- a/bot/seasons/christmas/adventofcode.py +++ b/bot/seasons/christmas/adventofcode.py @@ -14,6 +14,7 @@ from discord.ext import commands from pytz import timezone from bot.constants import AdventOfCode as AocConfig, Channels, Colours, Emojis, Tokens +from bot.decorators import override_in_channel log = logging.getLogger(__name__) @@ -125,6 +126,7 @@ class AdventOfCode(commands.Cog): self.status_task = asyncio.ensure_future(self.bot.loop.create_task(status_coro)) @commands.group(name="adventofcode", aliases=("aoc",), invoke_without_command=True) + @override_in_channel async def adventofcode_group(self, ctx: commands.Context): """All of the Advent of Code commands.""" await ctx.send_help(ctx.command) diff --git a/bot/seasons/easter/egg_hunt/cog.py b/bot/seasons/easter/egg_hunt/cog.py index 30fd3284..a4ad27df 100644 --- a/bot/seasons/easter/egg_hunt/cog.py +++ b/bot/seasons/easter/egg_hunt/cog.py @@ -9,7 +9,8 @@ from pathlib import Path import discord from discord.ext import commands -from bot.constants import Channels, Client, Roles as MainRoles, bot +from bot.bot import bot +from bot.constants import Channels, Client, Roles as MainRoles from bot.decorators import with_role from .constants import Colours, EggHuntSettings, Emoji, Roles diff --git a/bot/seasons/easter/egg_hunt/constants.py b/bot/seasons/easter/egg_hunt/constants.py index c7d9818b..02f6e9f2 100644 --- a/bot/seasons/easter/egg_hunt/constants.py +++ b/bot/seasons/easter/egg_hunt/constants.py @@ -2,7 +2,8 @@ import os from discord import Colour -from bot.constants import Channels, Client, bot +from bot.bot import bot +from bot.constants import Channels, Client GUILD = bot.get_guild(Client.guild) diff --git a/bot/seasons/evergreen/error_handler.py b/bot/seasons/evergreen/error_handler.py index f4457f8f..6690cf89 100644 --- a/bot/seasons/evergreen/error_handler.py +++ b/bot/seasons/evergreen/error_handler.py @@ -1,10 +1,15 @@ import logging
import math
+import random
import sys
import traceback
+from discord import Colour, Embed
from discord.ext import commands
+from bot.constants import NEGATIVE_REPLIES
+from bot.decorators import InChannelCheckFailure
+
log = logging.getLogger(__name__)
@@ -34,6 +39,16 @@ class CommandErrorHandler(commands.Cog): error = getattr(error, 'original', error)
+ if isinstance(error, InChannelCheckFailure):
+ logging.debug(
+ f"{ctx.author} the command '{ctx.command}', but they did not have "
+ f"permissions to run commands in the channel {ctx.channel}!"
+ )
+ embed = Embed(colour=Colour.red())
+ embed.title = random.choice(NEGATIVE_REPLIES)
+ embed.description = str(error)
+ return await ctx.send(embed=embed)
+
if isinstance(error, commands.CommandNotFound):
return logging.debug(
f"{ctx.author} called '{ctx.message.content}' but no command was found."
diff --git a/bot/seasons/evergreen/issues.py b/bot/seasons/evergreen/issues.py index 2a31a2e1..f19a1129 100644 --- a/bot/seasons/evergreen/issues.py +++ b/bot/seasons/evergreen/issues.py @@ -4,6 +4,7 @@ import discord from discord.ext import commands from bot.constants import Colours +from bot.decorators import override_in_channel log = logging.getLogger(__name__) @@ -15,6 +16,7 @@ class Issues(commands.Cog): self.bot = bot @commands.command(aliases=("issues",)) + @override_in_channel async def issue(self, ctx, number: int, repository: str = "seasonalbot", user: str = "python-discord"): """Command to retrieve issues from a GitHub repository.""" api_url = f"https://api.github.com/repos/{user}/{repository}/issues/{number}" diff --git a/bot/seasons/season.py b/bot/seasons/season.py index 3b623040..c88ef2a7 100644 --- a/bot/seasons/season.py +++ b/bot/seasons/season.py @@ -12,7 +12,8 @@ import async_timeout import discord from discord.ext import commands -from bot.constants import Channels, Client, Roles, bot +from bot.bot import bot +from bot.constants import Channels, Client, Roles from bot.decorators import with_role log = logging.getLogger(__name__) |