diff options
Diffstat (limited to 'bot')
-rw-r--r-- | bot/__main__.py | 3 | ||||
-rw-r--r-- | bot/constants.py | 12 | ||||
-rw-r--r-- | bot/exts/__init__.py | 23 | ||||
-rw-r--r-- | bot/exts/evergreen/8bitify.py | 2 | ||||
-rw-r--r-- | bot/exts/evergreen/bookmark.py | 3 | ||||
-rw-r--r-- | bot/exts/evergreen/snakes/__init__.py | 2 | ||||
-rw-r--r-- | bot/exts/evergreen/snakes/_converter.py (renamed from bot/exts/evergreen/snakes/converter.py) | 2 | ||||
-rw-r--r-- | bot/exts/evergreen/snakes/_snakes_cog.py (renamed from bot/exts/evergreen/snakes/snakes_cog.py) | 4 | ||||
-rw-r--r-- | bot/exts/evergreen/snakes/_utils.py (renamed from bot/exts/evergreen/snakes/utils.py) | 0 | ||||
-rw-r--r-- | bot/exts/evergreen/wikipedia.py | 29 | ||||
-rw-r--r-- | bot/exts/halloween/hacktober-issue-finder.py | 12 | ||||
-rw-r--r-- | bot/exts/halloween/hacktoberstats.py | 39 | ||||
-rw-r--r-- | bot/exts/halloween/spookysound.py | 48 | ||||
-rw-r--r-- | bot/exts/halloween/timeleft.py | 32 | ||||
-rw-r--r-- | bot/exts/utils/__init__.py | 0 | ||||
-rw-r--r-- | bot/exts/utils/extensions.py | 265 | ||||
-rw-r--r-- | bot/utils/checks.py | 164 | ||||
-rw-r--r-- | bot/utils/converters.py | 16 | ||||
-rw-r--r-- | bot/utils/extensions.py | 34 |
19 files changed, 569 insertions, 121 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index 0ffd6143..cd2d43a9 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -5,8 +5,9 @@ from sentry_sdk.integrations.logging import LoggingIntegration from bot.bot import bot from bot.constants import Client, STAFF_ROLES, WHITELISTED_CHANNELS -from bot.exts import walk_extensions from bot.utils.decorators import in_channel_check +from bot.utils.extensions import walk_extensions + sentry_logging = LoggingIntegration( level=logging.DEBUG, diff --git a/bot/constants.py b/bot/constants.py index 7c8f72cb..7ec8ac27 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -11,7 +11,6 @@ __all__ = ( "Client", "Colours", "Emojis", - "Hacktoberfest", "Icons", "Lovefest", "Month", @@ -75,7 +74,7 @@ class Channels(NamedTuple): python_discussion = 267624335836053506 show_your_projects = int(environ.get("CHANNEL_SHOW_YOUR_PROJECTS", 303934982764625920)) show_your_projects_discussion = 360148304664723466 - hacktoberfest_2019 = 628184417646411776 + hacktoberfest_2020 = 760857070781071431 class Client(NamedTuple): @@ -84,6 +83,7 @@ class Client(NamedTuple): token = environ.get("SEASONALBOT_TOKEN") sentry_dsn = environ.get("SEASONALBOT_SENTRY_DSN") debug = environ.get("SEASONALBOT_DEBUG", "").lower() == "true" + github_bot_repo = "https://github.com/python-discord/seasonalbot" # Override seasonal locks: 1 (January) to 12 (December) month_override = int(environ["MONTH_OVERRIDE"]) if "MONTH_OVERRIDE" in environ else None @@ -122,9 +122,10 @@ class Emojis: pull_request_closed = "<:PRClosed:629695470519713818>" merge = "<:PRMerged:629695470570176522>" - -class Hacktoberfest(NamedTuple): - voice_id = 514420006474219521 + status_online = "<:status_online:470326272351010816>" + status_idle = "<:status_idle:470326266625785866>" + status_dnd = "<:status_dnd:470326272082313216>" + status_offline = "<:status_offline:470326266537705472>" class Icons: @@ -177,6 +178,7 @@ class Roles(NamedTuple): verified = 352427296948486144 helpers = 267630620367257601 rockstars = 458226413825294336 + core_developers = 587606783669829632 class Tokens(NamedTuple): diff --git a/bot/exts/__init__.py b/bot/exts/__init__.py index 25deb9af..13f484ac 100644 --- a/bot/exts/__init__.py +++ b/bot/exts/__init__.py @@ -1,9 +1,8 @@ import logging import pkgutil -from pathlib import Path from typing import Iterator -__all__ = ("get_package_names", "walk_extensions") +__all__ = ("get_package_names",) log = logging.getLogger(__name__) @@ -13,23 +12,3 @@ def get_package_names() -> Iterator[str]: for package in pkgutil.iter_modules(__path__): if package.ispkg: yield package.name - - -def walk_extensions() -> Iterator[str]: - """ - Iterate dot-separated paths to all extensions. - - The strings are formatted in a way such that the bot's `load_extension` - method can take them. Use this to load all available extensions. - - This intentionally doesn't make use of pkgutil's `walk_packages`, as we only - want to build paths to extensions - not recursively all modules. For some - extensions, the `setup` function is in the package's __init__ file, while - modules nested under the package are only helpers. Constructing the paths - ourselves serves our purpose better. - """ - base_path = Path(__path__[0]) - - for package in get_package_names(): - for extension in pkgutil.iter_modules([base_path.joinpath(package)]): - yield f"bot.exts.{package}.{extension.name}" diff --git a/bot/exts/evergreen/8bitify.py b/bot/exts/evergreen/8bitify.py index 60062fc1..c048d9bf 100644 --- a/bot/exts/evergreen/8bitify.py +++ b/bot/exts/evergreen/8bitify.py @@ -14,7 +14,7 @@ class EightBitify(commands.Cog): @staticmethod def pixelate(image: Image) -> Image: """Takes an image and pixelates it.""" - return image.resize((32, 32)).resize((1024, 1024)) + return image.resize((32, 32), resample=Image.NEAREST).resize((1024, 1024), resample=Image.NEAREST) @staticmethod def quantize(image: Image) -> Image: diff --git a/bot/exts/evergreen/bookmark.py b/bot/exts/evergreen/bookmark.py index 73908702..5fa05d2e 100644 --- a/bot/exts/evergreen/bookmark.py +++ b/bot/exts/evergreen/bookmark.py @@ -5,6 +5,7 @@ import discord from discord.ext import commands from bot.constants import Colours, ERROR_REPLIES, Emojis, Icons +from bot.utils.converters import WrappedMessageConverter log = logging.getLogger(__name__) @@ -19,7 +20,7 @@ class Bookmark(commands.Cog): async def bookmark( self, ctx: commands.Context, - target_message: discord.Message, + target_message: WrappedMessageConverter, *, title: str = "Bookmark" ) -> None: diff --git a/bot/exts/evergreen/snakes/__init__.py b/bot/exts/evergreen/snakes/__init__.py index 2eae2751..bc42f0c2 100644 --- a/bot/exts/evergreen/snakes/__init__.py +++ b/bot/exts/evergreen/snakes/__init__.py @@ -2,7 +2,7 @@ import logging from discord.ext import commands -from bot.exts.evergreen.snakes.snakes_cog import Snakes +from bot.exts.evergreen.snakes._snakes_cog import Snakes log = logging.getLogger(__name__) diff --git a/bot/exts/evergreen/snakes/converter.py b/bot/exts/evergreen/snakes/_converter.py index 55609b8e..eee248cf 100644 --- a/bot/exts/evergreen/snakes/converter.py +++ b/bot/exts/evergreen/snakes/_converter.py @@ -7,7 +7,7 @@ import discord from discord.ext.commands import Context, Converter from fuzzywuzzy import fuzz -from bot.exts.evergreen.snakes.utils import SNAKE_RESOURCES +from bot.exts.evergreen.snakes._utils import SNAKE_RESOURCES from bot.utils import disambiguate log = logging.getLogger(__name__) diff --git a/bot/exts/evergreen/snakes/snakes_cog.py b/bot/exts/evergreen/snakes/_snakes_cog.py index 9bbad9fe..a846274b 100644 --- a/bot/exts/evergreen/snakes/snakes_cog.py +++ b/bot/exts/evergreen/snakes/_snakes_cog.py @@ -18,8 +18,8 @@ from discord import Colour, Embed, File, Member, Message, Reaction from discord.ext.commands import BadArgument, Bot, Cog, CommandError, Context, bot_has_permissions, group from bot.constants import ERROR_REPLIES, Tokens -from bot.exts.evergreen.snakes import utils -from bot.exts.evergreen.snakes.converter import Snake +from bot.exts.evergreen.snakes import _utils as utils +from bot.exts.evergreen.snakes._converter import Snake from bot.utils.decorators import locked log = logging.getLogger(__name__) diff --git a/bot/exts/evergreen/snakes/utils.py b/bot/exts/evergreen/snakes/_utils.py index 7d6caf04..7d6caf04 100644 --- a/bot/exts/evergreen/snakes/utils.py +++ b/bot/exts/evergreen/snakes/_utils.py diff --git a/bot/exts/evergreen/wikipedia.py b/bot/exts/evergreen/wikipedia.py index c1fff873..be36e2c4 100644 --- a/bot/exts/evergreen/wikipedia.py +++ b/bot/exts/evergreen/wikipedia.py @@ -1,8 +1,9 @@ import asyncio import datetime import logging -from typing import List +from typing import List, Optional +from aiohttp import client_exceptions from discord import Color, Embed, Message from discord.ext import commands @@ -14,7 +15,7 @@ SEARCH_API = "https://en.wikipedia.org/w/api.php?action=query&list=search&srsear WIKIPEDIA_URL = "https://en.wikipedia.org/wiki/{title}" -class WikipediaCog(commands.Cog): +class WikipediaSearch(commands.Cog): """Get info from wikipedia.""" def __init__(self, bot: commands.Bot): @@ -26,20 +27,22 @@ class WikipediaCog(commands.Cog): """Formating wikipedia link with index and title.""" return f'`{index}` [{title}]({WIKIPEDIA_URL.format(title=title.replace(" ", "_"))})' - async def search_wikipedia(self, search_term: str) -> List[str]: + async def search_wikipedia(self, search_term: str) -> Optional[List[str]]: """Search wikipedia and return the first 10 pages found.""" - async with self.http_session.get(SEARCH_API.format(search_term=search_term)) as response: - data = await response.json() - pages = [] + async with self.http_session.get(SEARCH_API.format(search_term=search_term)) as response: + try: + data = await response.json() - search_results = data["query"]["search"] + search_results = data["query"]["search"] - # Ignore pages with "may refer to" - for search_result in search_results: - log.info("trying to append titles") - if "may refer to" not in search_result["snippet"]: - pages.append(search_result["title"]) + # Ignore pages with "may refer to" + for search_result in search_results: + log.info("trying to append titles") + if "may refer to" not in search_result["snippet"]: + pages.append(search_result["title"]) + except client_exceptions.ContentTypeError: + pages = None log.info("Finished appending titles") return pages @@ -108,4 +111,4 @@ class WikipediaCog(commands.Cog): def setup(bot: commands.Bot) -> None: """Wikipedia Cog load.""" - bot.add_cog(WikipediaCog(bot)) + bot.add_cog(WikipediaSearch(bot)) diff --git a/bot/exts/halloween/hacktober-issue-finder.py b/bot/exts/halloween/hacktober-issue-finder.py index b5ad1c4f..78acf391 100644 --- a/bot/exts/halloween/hacktober-issue-finder.py +++ b/bot/exts/halloween/hacktober-issue-finder.py @@ -7,13 +7,19 @@ import aiohttp import discord from discord.ext import commands -from bot.constants import Month +from bot.constants import Month, Tokens from bot.utils.decorators import in_month log = logging.getLogger(__name__) URL = "https://api.github.com/search/issues?per_page=100&q=is:issue+label:hacktoberfest+language:python+state:open" -HEADERS = {"Accept": "application / vnd.github.v3 + json"} + +REQUEST_HEADERS = { + "User-Agent": "Python Discord Hacktoberbot", + "Accept": "application / vnd.github.v3 + json" +} +if GITHUB_TOKEN := Tokens.github: + REQUEST_HEADERS["Authorization"] = f"token {GITHUB_TOKEN}" class HacktoberIssues(commands.Cog): @@ -66,7 +72,7 @@ class HacktoberIssues(commands.Cog): url += f"&page={page}" log.debug(f"making api request to url: {url}") - async with session.get(url, headers=HEADERS) as response: + async with session.get(url, headers=REQUEST_HEADERS) as response: if response.status != 200: log.error(f"expected 200 status (got {response.status}) from the GitHub api.") await ctx.send(f"ERROR: expected 200 status (got {response.status}) from the GitHub api.") diff --git a/bot/exts/halloween/hacktoberstats.py b/bot/exts/halloween/hacktoberstats.py index db5e37f2..ed1755e3 100644 --- a/bot/exts/halloween/hacktoberstats.py +++ b/bot/exts/halloween/hacktoberstats.py @@ -10,7 +10,7 @@ import aiohttp import discord from discord.ext import commands -from bot.constants import Channels, Month, WHITELISTED_CHANNELS +from bot.constants import Channels, Month, Tokens, WHITELISTED_CHANNELS from bot.utils.decorators import in_month, override_in_channel from bot.utils.persist import make_persistent @@ -18,7 +18,16 @@ log = logging.getLogger(__name__) CURRENT_YEAR = datetime.now().year # Used to construct GH API query PRS_FOR_SHIRT = 4 # Minimum number of PRs before a shirt is awarded -HACKTOBER_WHITELIST = WHITELISTED_CHANNELS + (Channels.hacktoberfest_2019,) +HACKTOBER_WHITELIST = WHITELISTED_CHANNELS + (Channels.hacktoberfest_2020,) + +REQUEST_HEADERS = {"User-Agent": "Python Discord Hacktoberbot"} +if GITHUB_TOKEN := Tokens.github: + REQUEST_HEADERS["Authorization"] = f"token {GITHUB_TOKEN}" + +GITHUB_NONEXISTENT_USER_MESSAGE = ( + "The listed users cannot be searched either because the users do not exist " + "or you do not have permission to view the users." +) class HacktoberStats(commands.Cog): @@ -29,7 +38,7 @@ class HacktoberStats(commands.Cog): self.link_json = make_persistent(Path("bot", "resources", "halloween", "github_links.json")) self.linked_accounts = self.load_linked_users() - @in_month(Month.OCTOBER) + @in_month(Month.SEPTEMBER, Month.OCTOBER, Month.NOVEMBER) @commands.group(name="hacktoberstats", aliases=("hackstats",), invoke_without_command=True) @override_in_channel(HACKTOBER_WHITELIST) async def hacktoberstats_group(self, ctx: commands.Context, github_username: str = None) -> None: @@ -57,7 +66,7 @@ class HacktoberStats(commands.Cog): await self.get_stats(ctx, github_username) - @in_month(Month.OCTOBER) + @in_month(Month.SEPTEMBER, Month.OCTOBER, Month.NOVEMBER) @hacktoberstats_group.command(name="link") @override_in_channel(HACKTOBER_WHITELIST) async def link_user(self, ctx: commands.Context, github_username: str = None) -> None: @@ -92,7 +101,7 @@ class HacktoberStats(commands.Cog): logging.info(f"{author_id} tried to link a GitHub account but didn't provide a username") await ctx.send(f"{author_mention}, a GitHub username is required to link your account") - @in_month(Month.OCTOBER) + @in_month(Month.SEPTEMBER, Month.OCTOBER, Month.NOVEMBER) @hacktoberstats_group.command(name="unlink") @override_in_channel(HACKTOBER_WHITELIST) async def unlink_user(self, ctx: commands.Context) -> None: @@ -175,11 +184,11 @@ class HacktoberStats(commands.Cog): n = pr_stats['n_prs'] if n >= PRS_FOR_SHIRT: - shirtstr = f"**{github_username} has earned a tshirt!**" + shirtstr = f"**{github_username} has earned a T-shirt or a tree!**" elif n == PRS_FOR_SHIRT - 1: - shirtstr = f"**{github_username} is 1 PR away from a tshirt!**" + shirtstr = f"**{github_username} is 1 PR away from a T-shirt or a tree!**" else: - shirtstr = f"**{github_username} is {PRS_FOR_SHIRT - n} PRs away from a tshirt!**" + shirtstr = f"**{github_username} is {PRS_FOR_SHIRT - n} PRs away from a T-shirt or a tree!**" stats_embed = discord.Embed( title=f"{github_username}'s Hacktoberfest", @@ -196,7 +205,7 @@ class HacktoberStats(commands.Cog): stats_embed.set_author( name="Hacktoberfest", url="https://hacktoberfest.digitalocean.com", - icon_url="https://hacktoberfest.digitalocean.com/pretty_logo.png" + icon_url="https://avatars1.githubusercontent.com/u/35706162?s=200&v=4" ) stats_embed.add_field( name="Top 5 Repositories:", @@ -242,16 +251,22 @@ class HacktoberStats(commands.Cog): f"&per_page={per_page}" ) - headers = {"user-agent": "Discord Python Hacktoberbot"} async with aiohttp.ClientSession() as session: - async with session.get(query_url, headers=headers) as resp: + async with session.get(query_url, headers=REQUEST_HEADERS) as resp: jsonresp = await resp.json() if "message" in jsonresp.keys(): # One of the parameters is invalid, short circuit for now api_message = jsonresp["errors"][0]["message"] - logging.error(f"GitHub API request for '{github_username}' failed with message: {api_message}") + + # Ignore logging non-existent users or users we do not have permission to see + if api_message == GITHUB_NONEXISTENT_USER_MESSAGE: + logging.debug(f"No GitHub user found named '{github_username}'") + else: + logging.error(f"GitHub API request for '{github_username}' failed with message: {api_message}") + return + else: if jsonresp["total_count"] == 0: # Short circuit if there aren't any PRs diff --git a/bot/exts/halloween/spookysound.py b/bot/exts/halloween/spookysound.py deleted file mode 100644 index 569a9153..00000000 --- a/bot/exts/halloween/spookysound.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging -import random -from pathlib import Path - -import discord -from discord.ext import commands - -from bot.bot import SeasonalBot -from bot.constants import Hacktoberfest - -log = logging.getLogger(__name__) - - -class SpookySound(commands.Cog): - """A cog that plays a spooky sound in a voice channel on command.""" - - def __init__(self, bot: SeasonalBot): - self.bot = bot - self.sound_files = list(Path("bot/resources/halloween/spookysounds").glob("*.mp3")) - self.channel = None - - @commands.cooldown(rate=1, per=1) - @commands.command(brief="Play a spooky sound, restricted to once per 2 mins") - async def spookysound(self, ctx: commands.Context) -> None: - """ - Connect to the Hacktoberbot voice channel, play a random spooky sound, then disconnect. - - Cannot be used more than once in 2 minutes. - """ - if not self.channel: - await self.bot.wait_until_guild_available() - self.channel = self.bot.get_channel(Hacktoberfest.voice_id) - - await ctx.send("Initiating spooky sound...") - file_path = random.choice(self.sound_files) - src = discord.FFmpegPCMAudio(str(file_path.resolve())) - voice = await self.channel.connect() - voice.play(src, after=lambda e: self.bot.loop.create_task(self.disconnect(voice))) - - @staticmethod - async def disconnect(voice: discord.VoiceClient) -> None: - """Helper method to disconnect a given voice client.""" - await voice.disconnect() - - -def setup(bot: SeasonalBot) -> None: - """Spooky sound Cog load.""" - bot.add_cog(SpookySound(bot)) diff --git a/bot/exts/halloween/timeleft.py b/bot/exts/halloween/timeleft.py index 295acc89..47adb09b 100644 --- a/bot/exts/halloween/timeleft.py +++ b/bot/exts/halloween/timeleft.py @@ -13,20 +13,23 @@ class TimeLeft(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot - @staticmethod - def in_october() -> bool: - """Return True if the current month is October.""" - return datetime.utcnow().month == 10 + def in_hacktober(self) -> bool: + """Return True if the current time is within Hacktoberfest.""" + _, end, start = self.load_date() + + now = datetime.utcnow() + + return start <= now <= end @staticmethod - def load_date() -> Tuple[int, datetime, datetime]: + def load_date() -> Tuple[datetime, datetime, datetime]: """Return of a tuple of the current time and the end and start times of the next October.""" now = datetime.utcnow() year = now.year if now.month > 10: year += 1 - end = datetime(year, 11, 1, 11, 59, 59) - start = datetime(year, 10, 1) + end = datetime(year, 11, 1, 12) # November 1st 12:00 (UTC-12:00) + start = datetime(year, 9, 30, 10) # September 30th 10:00 (UTC+14:00) return now, end, start @commands.command() @@ -35,16 +38,23 @@ class TimeLeft(commands.Cog): Calculates the time left until the end of Hacktober. Whilst in October, displays the days, hours and minutes left. - Only displays the days left until the beginning and end whilst in a different month + Only displays the days left until the beginning and end whilst in a different month. + + This factors in that Hacktoberfest starts when it is October anywhere in the world + and ends with the same rules. It treats the start as UTC+14:00 and the end as + UTC-12. """ now, end, start = self.load_date() diff = end - now days, seconds = diff.days, diff.seconds - if self.in_october(): + if self.in_hacktober(): minutes = seconds // 60 hours, minutes = divmod(minutes, 60) - await ctx.send(f"There is currently only {days} days, {hours} hours and {minutes}" - "minutes left until the end of Hacktober.") + + await ctx.send( + f"There are {days} days, {hours} hours and {minutes}" + f" minutes left until the end of Hacktober." + ) else: start_diff = start - now start_days = start_diff.days diff --git a/bot/exts/utils/__init__.py b/bot/exts/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/bot/exts/utils/__init__.py diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py new file mode 100644 index 00000000..102a0416 --- /dev/null +++ b/bot/exts/utils/extensions.py @@ -0,0 +1,265 @@ +import functools +import logging +import typing as t +from enum import Enum + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group + +from bot import exts +from bot.bot import SeasonalBot as Bot +from bot.constants import Client, Emojis, MODERATION_ROLES, Roles +from bot.utils.checks import with_role_check +from bot.utils.extensions import EXTENSIONS, unqualify +from bot.utils.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +UNLOAD_BLACKLIST = {f"{exts.__name__}.utils.extensions"} +BASE_PATH_LEN = len(exts.__name__.split(".")) + + +class Action(Enum): + """Represents an action to perform on an extension.""" + + # Need to be partial otherwise they are considered to be function definitions. + LOAD = functools.partial(Bot.load_extension) + UNLOAD = functools.partial(Bot.unload_extension) + RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + if argument == unqualify(ext): + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise commands.BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): + """Extension management commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) + async def extensions_group(self, ctx: Context) -> None: + """Load, unload, reload, and list loaded extensions.""" + await ctx.send_help(ctx.command) + + @extensions_group.command(name="load", aliases=("l",)) + async def load_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Load extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "*" in extensions or "**" in extensions: + extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + + msg = self.batch_manage(Action.LOAD, *extensions) + await ctx.send(msg) + + @extensions_group.command(name="unload", aliases=("ul",)) + async def unload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Unload currently loaded extensions given their fully qualified or unqualified names. + + If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + + if blacklisted: + msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" + else: + if "*" in extensions or "**" in extensions: + extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + + msg = self.batch_manage(Action.UNLOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="reload", aliases=("r",), root_aliases=("reload",)) + async def reload_command(self, ctx: Context, *extensions: Extension) -> None: + r""" + Reload extensions given their fully qualified or unqualified names. + + If an extension fails to be reloaded, it will be rolled-back to the prior working state. + + If '\*' is given as the name, all currently loaded extensions will be reloaded. + If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. + """ # noqa: W605 + if not extensions: + await ctx.send_help(ctx.command) + return + + if "**" in extensions: + extensions = EXTENSIONS + elif "*" in extensions: + extensions = set(self.bot.extensions.keys()) | set(extensions) + extensions.remove("*") + + msg = self.batch_manage(Action.RELOAD, *extensions) + + await ctx.send(msg) + + @extensions_group.command(name="list", aliases=("all",)) + async def list_command(self, ctx: Context) -> None: + """ + Get a list of all extensions, including their loaded status. + + Grey indicates that the extension is unloaded. + Green indicates that the extension is currently loaded. + """ + embed = Embed(colour=Colour.blurple()) + embed.set_author( + name="Extensions List", + url=Client.github_bot_repo, + icon_url=str(self.bot.user.avatar_url) + ) + + lines = [] + categories = self.group_extension_statuses() + for category, extensions in sorted(categories.items()): + # Treat each category as a single line by concatenating everything. + # This ensures the paginator will not cut off a page in the middle of a category. + category = category.replace("_", " ").title() + extensions = "\n".join(sorted(extensions)) + lines.append(f"**{category}**\n{extensions}\n") + + log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") + await LinePaginator.paginate(lines, ctx, embed, max_size=1200, empty=False) + + def group_extension_statuses(self) -> t.Mapping[str, str]: + """Return a mapping of extension names and statuses to their categories.""" + categories = {} + + for ext in EXTENSIONS: + if ext in self.bot.extensions: + status = Emojis.status_online + else: + status = Emojis.status_offline + + path = ext.split(".") + if len(path) > BASE_PATH_LEN + 1: + category = " - ".join(path[BASE_PATH_LEN:-1]) + else: + category = "uncategorised" + + categories.setdefault(category, []).append(f"{status} {path[-1]}") + + return categories + + def batch_manage(self, action: Action, *extensions: str) -> str: + """ + Apply an action to multiple extensions and return a message with the results. + + If only one extension is given, it is deferred to `manage()`. + """ + if len(extensions) == 1: + msg, _ = self.manage(action, extensions[0]) + return msg + + verb = action.name.lower() + failures = {} + + for extension in extensions: + _, error = self.manage(action, extension) + if error: + failures[extension] = error + + emoji = ":x:" if failures else ":ok_hand:" + msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + + if failures: + failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) + msg += f"\nFailures:```{failures}```" + + log.debug(f"Batch {verb}ed extensions.") + + return msg + + def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: + """Apply an action to an extension and return the status message and any error message.""" + verb = action.name.lower() + error_msg = None + + try: + action.value(self.bot, ext) + except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): + if action is Action.RELOAD: + # When reloading, just load the extension if it was not loaded. + return self.manage(Action.LOAD, ext) + + msg = f":x: Extension `{ext}` is already {verb}ed." + log.debug(msg[4:]) + except Exception as e: + if hasattr(e, "original"): + e = e.original + + log.exception(f"Extension '{ext}' failed to {verb}.") + + error_msg = f"{e.__class__.__name__}: {e}" + msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" + else: + msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." + log.debug(msg[10:]) + + return msg, error_msg + + # This cannot be static (must have a __func__ attribute). + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators and core developers to invoke the commands in this cog.""" + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) + + # This cannot be static (must have a __func__ attribute). + async def cog_command_error(self, ctx: Context, error: Exception) -> None: + """Handle BadArgument errors locally to prevent the help command from showing.""" + if isinstance(error, commands.BadArgument): + await ctx.send(str(error)) + error.handled = True + + +def setup(bot: Bot) -> None: + """Load the Extensions cog.""" + bot.add_cog(Extensions(bot)) diff --git a/bot/utils/checks.py b/bot/utils/checks.py new file mode 100644 index 00000000..3031a271 --- /dev/null +++ b/bot/utils/checks.py @@ -0,0 +1,164 @@ +import datetime +import logging +from typing import Callable, Container, Iterable, Optional + +from discord.ext.commands import ( + BucketType, + CheckFailure, + Cog, + Command, + CommandOnCooldown, + Context, + Cooldown, + CooldownMapping, +) + +from bot import constants + +log = logging.getLogger(__name__) + + +class InWhitelistCheckFailure(CheckFailure): + """Raised when the `in_whitelist` check fails.""" + + def __init__(self, redirect_channel: Optional[int]) -> None: + self.redirect_channel = redirect_channel + + if redirect_channel: + redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" + else: + redirect_message = "" + + error_message = f"You are not allowed to use that command{redirect_message}." + + super().__init__(error_message) + + +def in_whitelist_check( + ctx: Context, + channels: Container[int] = (), + categories: Container[int] = (), + roles: Container[int] = (), + redirect: Optional[int] = constants.Channels.seasonalbot_commands, + fail_silently: bool = False, +) -> bool: + """ + Check if a command was issued in a whitelisted context. + + The whitelists that can be provided are: + + - `channels`: a container with channel ids for whitelisted channels + - `categories`: a container with category ids for whitelisted categories + - `roles`: a container with with role ids for whitelisted roles + + If the command was invoked in a context that was not whitelisted, the member is either + redirected to the `redirect` channel that was passed (default: #bot-commands) or simply + told that they're not allowed to use this particular command (if `None` was passed). + """ + if redirect and redirect not in channels: + # It does not make sense for the channel whitelist to not contain the redirection + # channel (if applicable). That's why we add the redirection channel to the `channels` + # container if it's not already in it. As we allow any container type to be passed, + # we first create a tuple in order to safely add the redirection channel. + # + # Note: It's possible for the redirect channel to be in a whitelisted category, but + # there's no easy way to check that and as a channel can easily be moved in and out of + # categories, it's probably not wise to rely on its category in any case. + channels = tuple(channels) + (redirect,) + + if channels and ctx.channel.id in channels: + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") + return True + + # Only check the category id if we have a category whitelist and the channel has a `category_id` + if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") + return True + + # Only check the roles whitelist if we have one and ensure the author's roles attribute returns + # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). + if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") + return True + + log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + + # Some commands are secret, and should produce no feedback at all. + if not fail_silently: + raise InWhitelistCheckFailure(redirect) + return False + + +def with_role_check(ctx: Context, *role_ids: int) -> bool: + """Returns True if the user has any one of the roles in role_ids.""" + if not ctx.guild: # Return False in a DM + log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " + "This command is restricted by the with_role decorator. Rejecting request.") + return False + + for role in ctx.author.roles: + if role.id in role_ids: + log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.") + return True + + log.trace(f"{ctx.author} does not have the required role to use " + f"the '{ctx.command.name}' command, so the request is rejected.") + return False + + +def without_role_check(ctx: Context, *role_ids: int) -> bool: + """Returns True if the user does not have any of the roles in role_ids.""" + if not ctx.guild: # Return False in a DM + log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " + "This command is restricted by the without_role decorator. Rejecting request.") + return False + + author_roles = [role.id for role in ctx.author.roles] + check = all(role not in author_roles for role in role_ids) + log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " + f"The result of the without_role check was {check}.") + return check + + +def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, + bypass_roles: Iterable[int]) -> Callable: + """ + Applies a cooldown to a command, but allows members with certain roles to be ignored. + + NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future. + """ + # Make it a set so lookup is hash based. + bypass = set(bypass_roles) + + # This handles the actual cooldown logic. + buckets = CooldownMapping(Cooldown(rate, per, type)) + + # Will be called after the command has been parse but before it has been invoked, ensures that + # the cooldown won't be updated if the user screws up their input to the command. + async def predicate(cog: Cog, ctx: Context) -> None: + nonlocal bypass, buckets + + if any(role.id in bypass for role in ctx.author.roles): + return + + # Cooldown logic, taken from discord.py internals. + current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp() + bucket = buckets.get_bucket(ctx.message) + retry_after = bucket.update_rate_limit(current) + if retry_after: + raise CommandOnCooldown(bucket, retry_after) + + def wrapper(command: Command) -> Command: + # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it + # so I just made it raise an error when the decorator is applied before the actual command object exists. + # + # If the `before_invoke` detail is ever a problem then I can quickly just swap over. + if not isinstance(command, Command): + raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. ' + 'This means it has to be above the command decorator in the code.') + + command._before_invoke = predicate + + return command + + return wrapper diff --git a/bot/utils/converters.py b/bot/utils/converters.py new file mode 100644 index 00000000..228714c9 --- /dev/null +++ b/bot/utils/converters.py @@ -0,0 +1,16 @@ +import discord +from discord.ext.commands.converter import MessageConverter + + +class WrappedMessageConverter(MessageConverter): + """A converter that handles embed-suppressed links like <http://example.com>.""" + + async def convert(self, ctx: discord.ext.commands.Context, argument: str) -> discord.Message: + """Wrap the commands.MessageConverter to handle <> delimited message links.""" + # It's possible to wrap a message in [<>] as well, and it's supported because its easy + if argument.startswith("[") and argument.endswith("]"): + argument = argument[1:-1] + if argument.startswith("<") and argument.endswith(">"): + argument = argument[1:-1] + + return await super().convert(ctx, argument) diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py new file mode 100644 index 00000000..50350ea8 --- /dev/null +++ b/bot/utils/extensions.py @@ -0,0 +1,34 @@ +import importlib +import inspect +import pkgutil +from typing import Iterator, NoReturn + +from bot import exts + + +def unqualify(name: str) -> str: + """Return an unqualified name given a qualified module/package `name`.""" + return name.rsplit(".", maxsplit=1)[-1] + + +def walk_extensions() -> Iterator[str]: + """Yield extension names from the bot.exts subpackage.""" + + def on_error(name: str) -> NoReturn: + raise ImportError(name=name) # pragma: no cover + + for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error): + if unqualify(module.name).startswith("_"): + # Ignore module/package names starting with an underscore. + continue + + if module.ispkg: + imported = importlib.import_module(module.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. + continue + + yield module.name + + +EXTENSIONS = frozenset(walk_extensions()) |