aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/__init__.py43
-rw-r--r--bot/utils/checks.py30
-rw-r--r--bot/utils/converters.py83
-rw-r--r--bot/utils/decorators.py2
-rw-r--r--bot/utils/extensions.py4
-rw-r--r--bot/utils/halloween/spookifications.py10
-rw-r--r--bot/utils/pagination.py20
7 files changed, 141 insertions, 51 deletions
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 35ef0a7b..bef12d25 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -3,7 +3,7 @@ import contextlib
import re
import string
from datetime import datetime
-from typing import Iterable, List
+from typing import Iterable, List, Optional
import discord
from discord.ext.commands import BadArgument, Context
@@ -31,8 +31,13 @@ def resolve_current_month() -> Month:
async def disambiguate(
- ctx: Context, entries: List[str], *, timeout: float = 30,
- entries_per_page: int = 20, empty: bool = False, embed: discord.Embed = None
+ ctx: Context,
+ entries: List[str],
+ *,
+ timeout: float = 30,
+ entries_per_page: int = 20,
+ empty: bool = False,
+ embed: Optional[discord.Embed] = None
) -> str:
"""
Has the user choose between multiple entries in case one could not be chosen automatically.
@@ -43,25 +48,29 @@ async def disambiguate(
or if the user makes an invalid choice.
"""
if len(entries) == 0:
- raise BadArgument('No matches found.')
+ raise BadArgument("No matches found.")
if len(entries) == 1:
return entries[0]
- choices = (f'{index}: {entry}' for index, entry in enumerate(entries, start=1))
+ choices = (f"{index}: {entry}" for index, entry in enumerate(entries, start=1))
def check(message: discord.Message) -> bool:
- return (message.content.isdigit()
- and message.author == ctx.author
- and message.channel == ctx.channel)
+ return (
+ message.content.isdecimal()
+ and message.author == ctx.author
+ and message.channel == ctx.channel
+ )
try:
if embed is None:
embed = discord.Embed()
- coro1 = ctx.bot.wait_for('message', check=check, timeout=timeout)
- coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=entries_per_page,
- empty=empty, max_size=6000, timeout=9000)
+ coro1 = ctx.bot.wait_for("message", check=check, timeout=timeout)
+ coro2 = LinePaginator.paginate(
+ choices, ctx, embed=embed, max_lines=entries_per_page,
+ empty=empty, max_size=6000, timeout=9000
+ )
# wait_for timeout will go to except instead of the wait_for thing as I expected
futures = [asyncio.ensure_future(coro1), asyncio.ensure_future(coro2)]
@@ -74,7 +83,7 @@ async def disambiguate(
if result is None:
for coro in pending:
coro.cancel()
- raise BadArgument('Canceled.')
+ raise BadArgument("Canceled.")
# Pagination was not initiated, only one page
if result.author == ctx.bot.user:
@@ -85,19 +94,19 @@ async def disambiguate(
for coro in pending:
coro.cancel()
except asyncio.TimeoutError:
- raise BadArgument('Timed out.')
+ raise BadArgument("Timed out.")
- # Guaranteed to not error because of isdigit() in check
+ # Guaranteed to not error because of isdecimal() in check
index = int(result.content)
try:
return entries[index - 1]
except IndexError:
- raise BadArgument('Invalid choice.')
+ raise BadArgument("Invalid choice.")
def replace_many(
- sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False
+ sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False
) -> str:
"""
Replaces multiple substrings in a string given a mapping of strings.
@@ -139,7 +148,7 @@ def replace_many(
return replacement
# Clean punctuation from word so string methods work
- cleaned_word = word.translate(str.maketrans('', '', string.punctuation))
+ cleaned_word = word.translate(str.maketrans("", "", string.punctuation))
if cleaned_word.isupper():
return replacement.upper()
elif cleaned_word[0].isupper():
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index 9dd4dde0..c06b6870 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -92,8 +92,10 @@ def in_whitelist_check(
def with_role_check(ctx: Context, *role_ids: int) -> bool:
"""Returns True if the user has any one of the roles in role_ids."""
if not ctx.guild: # Return False in a DM
- log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
- "This command is restricted by the with_role decorator. Rejecting request.")
+ log.trace(
+ f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
+ "This command is restricted by the with_role decorator. Rejecting request."
+ )
return False
for role in ctx.author.roles:
@@ -101,22 +103,28 @@ def with_role_check(ctx: Context, *role_ids: int) -> bool:
log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")
return True
- log.trace(f"{ctx.author} does not have the required role to use "
- f"the '{ctx.command.name}' command, so the request is rejected.")
+ log.trace(
+ f"{ctx.author} does not have the required role to use "
+ f"the '{ctx.command.name}' command, so the request is rejected."
+ )
return False
def without_role_check(ctx: Context, *role_ids: int) -> bool:
"""Returns True if the user does not have any of the roles in role_ids."""
if not ctx.guild: # Return False in a DM
- log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
- "This command is restricted by the without_role decorator. Rejecting request.")
+ log.trace(
+ f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
+ "This command is restricted by the without_role decorator. Rejecting request."
+ )
return False
author_roles = [role.id for role in ctx.author.roles]
check = all(role not in author_roles for role in role_ids)
- log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the without_role check was {check}.")
+ log.trace(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The result of the without_role check was {check}."
+ )
return check
@@ -154,8 +162,10 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy
#
# If the `before_invoke` detail is ever a problem then I can quickly just swap over.
if not isinstance(command, Command):
- raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. '
- 'This means it has to be above the command decorator in the code.')
+ raise TypeError(
+ "Decorator `cooldown_with_role_bypass` must be applied after the command decorator. "
+ "This means it has to be above the command decorator in the code."
+ )
command._before_invoke = predicate
diff --git a/bot/utils/converters.py b/bot/utils/converters.py
index 27804170..fe2c980c 100644
--- a/bot/utils/converters.py
+++ b/bot/utils/converters.py
@@ -1,12 +1,14 @@
+from datetime import datetime
+from typing import Tuple, Union
+
import discord
-from discord.ext.commands import BadArgument, Context
-from discord.ext.commands.converter import Converter, MessageConverter
+from discord.ext import commands
-class WrappedMessageConverter(MessageConverter):
+class WrappedMessageConverter(commands.MessageConverter):
"""A converter that handles embed-suppressed links like <http://example.com>."""
- async def convert(self, ctx: discord.ext.commands.Context, argument: str) -> discord.Message:
+ async def convert(self, ctx: commands.Context, argument: str) -> discord.Message:
"""Wrap the commands.MessageConverter to handle <> delimited message links."""
# It's possible to wrap a message in [<>] as well, and it's supported because its easy
if argument.startswith("[") and argument.endswith("]"):
@@ -17,11 +19,78 @@ class WrappedMessageConverter(MessageConverter):
return await super().convert(ctx, argument)
-class Subreddit(Converter):
+class CoordinateConverter(commands.Converter):
+ """Converter for Coordinates."""
+
+ @staticmethod
+ async def convert(ctx: commands.Context, coordinate: str) -> Tuple[int, int]:
+ """Take in a coordinate string and turn it into an (x, y) tuple."""
+ if len(coordinate) not in (2, 3):
+ raise commands.BadArgument("Invalid co-ordinate provided.")
+
+ coordinate = coordinate.lower()
+ if coordinate[0].isalpha():
+ digit = coordinate[1:]
+ letter = coordinate[0]
+ else:
+ digit = coordinate[:-1]
+ letter = coordinate[-1]
+
+ if not digit.isdecimal():
+ raise commands.BadArgument
+
+ x = ord(letter) - ord("a")
+ y = int(digit) - 1
+
+ if (not 0 <= x <= 9) or (not 0 <= y <= 9):
+ raise commands.BadArgument
+ return x, y
+
+
+SourceType = Union[commands.Command, commands.Cog]
+
+
+class SourceConverter(commands.Converter):
+ """Convert an argument into a command or cog."""
+
+ @staticmethod
+ async def convert(ctx: commands.Context, argument: str) -> SourceType:
+ """Convert argument into source object."""
+ cog = ctx.bot.get_cog(argument)
+ if cog:
+ return cog
+
+ cmd = ctx.bot.get_command(argument)
+ if cmd:
+ return cmd
+
+ raise commands.BadArgument(
+ f"Unable to convert `{argument}` to valid command or Cog."
+ )
+
+
+class DateConverter(commands.Converter):
+ """Parse SOL or earth date (in format YYYY-MM-DD) into `int` or `datetime`. When invalid input, raise error."""
+
+ @staticmethod
+ async def convert(ctx: commands.Context, argument: str) -> Union[int, datetime]:
+ """Parse date (SOL or earth) into `datetime` or `int`. When invalid value, raise error."""
+ if argument.isdecimal():
+ return int(argument)
+ try:
+ date = datetime.strptime(argument, "%Y-%m-%d")
+ except ValueError:
+ raise commands.BadArgument(
+ f"Can't convert `{argument}` to `datetime` in format `YYYY-MM-DD` or `int` in SOL."
+ )
+ return date
+
+
+class Subreddit(commands.Converter):
"""Forces a string to begin with "r/" and checks if it's a valid subreddit."""
@staticmethod
- async def convert(ctx: Context, sub: str) -> str:
+ async def convert(ctx: commands.Context, sub: str) -> str:
"""
Force sub to begin with "r/" and check if it's a valid subreddit.
@@ -39,7 +108,7 @@ class Subreddit(Converter):
json = await resp.json()
if not json["data"]["children"]:
- raise BadArgument(
+ raise commands.BadArgument(
f"The subreddit `{sub}` either doesn't exist, or it has no posts."
)
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 60066dc4..c0783144 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -269,7 +269,7 @@ def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context],
channels.update(channel.id for channel in category.text_channels)
if channels:
- channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
+ channels_str = ", ".join(f"<#{c_id}>" for c_id in channels)
message = f"Sorry, but you may only use this command within {channels_str}."
else:
message = "Sorry, but you may not use this command."
diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py
index 459588a1..cd491c4b 100644
--- a/bot/utils/extensions.py
+++ b/bot/utils/extensions.py
@@ -35,8 +35,8 @@ def walk_extensions() -> Iterator[str]:
async def invoke_help_command(ctx: Context) -> None:
"""Invoke the help command or default help command if help extensions is not loaded."""
- if 'bot.exts.evergreen.help' in ctx.bot.extensions:
- help_command = ctx.bot.get_command('help')
+ if "bot.exts.evergreen.help" in ctx.bot.extensions:
+ help_command = ctx.bot.get_command("help")
await ctx.invoke(help_command, ctx.command.qualified_name)
return
await ctx.send_help(ctx.command)
diff --git a/bot/utils/halloween/spookifications.py b/bot/utils/halloween/spookifications.py
index 11f69850..f69dd6fd 100644
--- a/bot/utils/halloween/spookifications.py
+++ b/bot/utils/halloween/spookifications.py
@@ -13,16 +13,16 @@ def inversion(im: Image) -> Image:
Returns an inverted image when supplied with an Image object.
"""
- im = im.convert('RGB')
+ im = im.convert("RGB")
inv = ImageOps.invert(im)
return inv
def pentagram(im: Image) -> Image:
"""Adds pentagram to the image."""
- im = im.convert('RGB')
+ im = im.convert("RGB")
wt, ht = im.size
- penta = Image.open('bot/resources/halloween/bloody-pentagram.png')
+ penta = Image.open("bot/resources/halloween/bloody-pentagram.png")
penta = penta.resize((wt, ht))
im.paste(penta, (0, 0), penta)
return im
@@ -35,9 +35,9 @@ def bat(im: Image) -> Image:
The bat silhoutte is of a size at least one-fifths that of the original image and may be rotated
up to 90 degrees anti-clockwise.
"""
- im = im.convert('RGB')
+ im = im.convert("RGB")
wt, ht = im.size
- bat = Image.open('bot/resources/halloween/bat-clipart.png')
+ bat = Image.open("bot/resources/halloween/bat-clipart.png")
bat_size = randint(wt//10, wt//7)
rot = randint(0, 90)
bat = bat.resize((bat_size, bat_size))
diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py
index 917275c0..742281d7 100644
--- a/bot/utils/pagination.py
+++ b/bot/utils/pagination.py
@@ -27,7 +27,7 @@ class EmptyPaginatorEmbed(Exception):
class LinePaginator(Paginator):
"""A class that aids in paginating code blocks for Discord messages."""
- def __init__(self, prefix: str = '```', suffix: str = '```', max_size: int = 2000, max_lines: int = None):
+ def __init__(self, prefix: str = "```", suffix: str = "```", max_size: int = 2000, max_lines: int = None):
"""
Overrides the Paginator.__init__ from inside discord.ext.commands.
@@ -45,7 +45,7 @@ class LinePaginator(Paginator):
self._count = len(prefix) + 1 # prefix + newline
self._pages = []
- def add_line(self, line: str = '', *, empty: bool = False) -> None:
+ def add_line(self, line: str = "", *, empty: bool = False) -> None:
"""
Adds a line to the current page.
@@ -57,7 +57,7 @@ class LinePaginator(Paginator):
If `empty` is True, an empty line will be placed after the a given `line`.
"""
if len(line) > self.max_size - len(self.prefix) - 2:
- raise RuntimeError('Line exceeds maximum page size %s' % (self.max_size - len(self.prefix) - 2))
+ raise RuntimeError("Line exceeds maximum page size %s" % (self.max_size - len(self.prefix) - 2))
if self.max_lines is not None:
if self._linecount >= self.max_lines:
@@ -72,7 +72,7 @@ class LinePaginator(Paginator):
self._current_page.append(line)
if empty:
- self._current_page.append('')
+ self._current_page.append("")
self._count += 1
@classmethod
@@ -80,7 +80,7 @@ class LinePaginator(Paginator):
prefix: str = "", suffix: str = "", max_lines: Optional[int] = None,
max_size: int = 500, empty: bool = True, restrict_to_user: User = None,
timeout: int = 300, footer_text: str = None, url: str = None,
- exception_on_empty_embed: bool = False):
+ exception_on_empty_embed: bool = False) -> None:
"""
Use a paginator and set of reactions to provide pagination over a set of lines.
@@ -158,7 +158,8 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("There's less than two pages, so we won't paginate - sending single page on its own")
- return await ctx.send(embed=embed)
+ await ctx.send(embed=embed)
+ return
else:
if footer_text:
embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
@@ -283,7 +284,7 @@ class ImagePaginator(Paginator):
self.images = []
self._pages = []
- def add_line(self, line: str = '', *, empty: bool = False) -> None:
+ def add_line(self, line: str = "", *, empty: bool = False) -> None:
"""
Adds a line to each page, usually just 1 line in this context.
@@ -303,7 +304,7 @@ class ImagePaginator(Paginator):
@classmethod
async def paginate(cls, pages: List[Tuple[str, str]], ctx: Context, embed: Embed,
prefix: str = "", suffix: str = "", timeout: int = 300,
- exception_on_empty_embed: bool = False):
+ exception_on_empty_embed: bool = False) -> None:
"""
Use a paginator and set of reactions to provide pagination over a set of title/image pairs.
@@ -353,7 +354,8 @@ class ImagePaginator(Paginator):
embed.set_image(url=image)
if len(paginator.pages) <= 1:
- return await ctx.send(embed=embed)
+ await ctx.send(embed=embed)
+ return
embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")
message = await ctx.send(embed=embed)