aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/checks.py164
-rw-r--r--bot/utils/converters.py16
-rw-r--r--bot/utils/decorators.py22
-rw-r--r--bot/utils/exceptions.py6
-rw-r--r--bot/utils/extensions.py34
-rw-r--r--bot/utils/persist.py69
6 files changed, 215 insertions, 96 deletions
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
new file mode 100644
index 00000000..9dd4dde0
--- /dev/null
+++ b/bot/utils/checks.py
@@ -0,0 +1,164 @@
+import datetime
+import logging
+from typing import Callable, Container, Iterable, Optional
+
+from discord.ext.commands import (
+ BucketType,
+ CheckFailure,
+ Cog,
+ Command,
+ CommandOnCooldown,
+ Context,
+ Cooldown,
+ CooldownMapping,
+)
+
+from bot import constants
+
+log = logging.getLogger(__name__)
+
+
+class InWhitelistCheckFailure(CheckFailure):
+ """Raised when the `in_whitelist` check fails."""
+
+ def __init__(self, redirect_channel: Optional[int]) -> None:
+ self.redirect_channel = redirect_channel
+
+ if redirect_channel:
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ redirect_message = ""
+
+ error_message = f"You are not allowed to use that command{redirect_message}."
+
+ super().__init__(error_message)
+
+
+def in_whitelist_check(
+ ctx: Context,
+ channels: Container[int] = (),
+ categories: Container[int] = (),
+ roles: Container[int] = (),
+ redirect: Optional[int] = constants.Channels.community_bot_commands,
+ fail_silently: bool = False,
+) -> bool:
+ """
+ Check if a command was issued in a whitelisted context.
+
+ The whitelists that can be provided are:
+
+ - `channels`: a container with channel ids for whitelisted channels
+ - `categories`: a container with category ids for whitelisted categories
+ - `roles`: a container with with role ids for whitelisted roles
+
+ If the command was invoked in a context that was not whitelisted, the member is either
+ redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
+ told that they're not allowed to use this particular command (if `None` was passed).
+ """
+ if redirect and redirect not in channels:
+ # It does not make sense for the channel whitelist to not contain the redirection
+ # channel (if applicable). That's why we add the redirection channel to the `channels`
+ # container if it's not already in it. As we allow any container type to be passed,
+ # we first create a tuple in order to safely add the redirection channel.
+ #
+ # Note: It's possible for the redirect channel to be in a whitelisted category, but
+ # there's no easy way to check that and as a channel can easily be moved in and out of
+ # categories, it's probably not wise to rely on its category in any case.
+ channels = tuple(channels) + (redirect,)
+
+ if channels and ctx.channel.id in channels:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
+ return True
+
+ # Only check the category id if we have a category whitelist and the channel has a `category_id`
+ if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
+ return True
+
+ # Only check the roles whitelist if we have one and ensure the author's roles attribute returns
+ # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
+ if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
+ return True
+
+ log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")
+
+ # Some commands are secret, and should produce no feedback at all.
+ if not fail_silently:
+ raise InWhitelistCheckFailure(redirect)
+ return False
+
+
+def with_role_check(ctx: Context, *role_ids: int) -> bool:
+ """Returns True if the user has any one of the roles in role_ids."""
+ if not ctx.guild: # Return False in a DM
+ log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
+ "This command is restricted by the with_role decorator. Rejecting request.")
+ return False
+
+ for role in ctx.author.roles:
+ if role.id in role_ids:
+ log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")
+ return True
+
+ log.trace(f"{ctx.author} does not have the required role to use "
+ f"the '{ctx.command.name}' command, so the request is rejected.")
+ return False
+
+
+def without_role_check(ctx: Context, *role_ids: int) -> bool:
+ """Returns True if the user does not have any of the roles in role_ids."""
+ if not ctx.guild: # Return False in a DM
+ log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
+ "This command is restricted by the without_role decorator. Rejecting request.")
+ return False
+
+ author_roles = [role.id for role in ctx.author.roles]
+ check = all(role not in author_roles for role in role_ids)
+ log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The result of the without_role check was {check}.")
+ return check
+
+
+def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
+ bypass_roles: Iterable[int]) -> Callable:
+ """
+ Applies a cooldown to a command, but allows members with certain roles to be ignored.
+
+ NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future.
+ """
+ # Make it a set so lookup is hash based.
+ bypass = set(bypass_roles)
+
+ # This handles the actual cooldown logic.
+ buckets = CooldownMapping(Cooldown(rate, per, type))
+
+ # Will be called after the command has been parse but before it has been invoked, ensures that
+ # the cooldown won't be updated if the user screws up their input to the command.
+ async def predicate(cog: Cog, ctx: Context) -> None:
+ nonlocal bypass, buckets
+
+ if any(role.id in bypass for role in ctx.author.roles):
+ return
+
+ # Cooldown logic, taken from discord.py internals.
+ current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp()
+ bucket = buckets.get_bucket(ctx.message)
+ retry_after = bucket.update_rate_limit(current)
+ if retry_after:
+ raise CommandOnCooldown(bucket, retry_after)
+
+ def wrapper(command: Command) -> Command:
+ # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it
+ # so I just made it raise an error when the decorator is applied before the actual command object exists.
+ #
+ # If the `before_invoke` detail is ever a problem then I can quickly just swap over.
+ if not isinstance(command, Command):
+ raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. '
+ 'This means it has to be above the command decorator in the code.')
+
+ command._before_invoke = predicate
+
+ return command
+
+ return wrapper
diff --git a/bot/utils/converters.py b/bot/utils/converters.py
new file mode 100644
index 00000000..228714c9
--- /dev/null
+++ b/bot/utils/converters.py
@@ -0,0 +1,16 @@
+import discord
+from discord.ext.commands.converter import MessageConverter
+
+
+class WrappedMessageConverter(MessageConverter):
+ """A converter that handles embed-suppressed links like <http://example.com>."""
+
+ async def convert(self, ctx: discord.ext.commands.Context, argument: str) -> discord.Message:
+ """Wrap the commands.MessageConverter to handle <> delimited message links."""
+ # It's possible to wrap a message in [<>] as well, and it's supported because its easy
+ if argument.startswith("[") and argument.endswith("]"):
+ argument = argument[1:-1]
+ if argument.startswith("<") and argument.endswith(">"):
+ argument = argument[1:-1]
+
+ return await super().convert(ctx, argument)
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 9e6ef73d..9cdaad3f 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -11,7 +11,7 @@ from discord import Colour, Embed
from discord.ext import commands
from discord.ext.commands import CheckFailure, Command, Context
-from bot.constants import Client, ERROR_REPLIES, Month
+from bot.constants import ERROR_REPLIES, Month
from bot.utils import human_months, resolve_current_month
ONE_DAY = 24 * 60 * 60
@@ -298,23 +298,3 @@ def locked() -> t.Union[t.Callable, None]:
return await func(self, ctx, *args, **kwargs)
return inner
return wrap
-
-
-def mock_in_debug(return_value: t.Any) -> t.Callable:
- """
- Short-circuit function execution if in debug mode and return `return_value`.
-
- The original function name, and the incoming args and kwargs are DEBUG level logged
- upon each call. This is useful for expensive operations, i.e. media asset uploads
- that are prone to rate-limits but need to be tested extensively.
- """
- def decorator(func: t.Callable) -> t.Callable:
- @functools.wraps(func)
- async def wrapped(*args, **kwargs) -> t.Any:
- """Short-circuit and log if in debug mode."""
- if Client.debug:
- log.debug(f"Function {func.__name__} called with args: {args}, kwargs: {kwargs}")
- return return_value
- return await func(*args, **kwargs)
- return wrapped
- return decorator
diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py
index dc62debe..2b1c1b31 100644
--- a/bot/utils/exceptions.py
+++ b/bot/utils/exceptions.py
@@ -1,9 +1,3 @@
-class BrandingError(Exception):
- """Exception raised by the BrandingManager cog."""
-
- pass
-
-
class UserNotPlayingError(Exception):
"""Will raised when user try to use game commands when not playing."""
diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py
new file mode 100644
index 00000000..50350ea8
--- /dev/null
+++ b/bot/utils/extensions.py
@@ -0,0 +1,34 @@
+import importlib
+import inspect
+import pkgutil
+from typing import Iterator, NoReturn
+
+from bot import exts
+
+
+def unqualify(name: str) -> str:
+ """Return an unqualified name given a qualified module/package `name`."""
+ return name.rsplit(".", maxsplit=1)[-1]
+
+
+def walk_extensions() -> Iterator[str]:
+ """Yield extension names from the bot.exts subpackage."""
+
+ def on_error(name: str) -> NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error):
+ if unqualify(module.name).startswith("_"):
+ # Ignore module/package names starting with an underscore.
+ continue
+
+ if module.ispkg:
+ imported = importlib.import_module(module.name)
+ if not inspect.isfunction(getattr(imported, "setup", None)):
+ # If it lacks a setup function, it's not an extension.
+ continue
+
+ yield module.name
+
+
+EXTENSIONS = frozenset(walk_extensions())
diff --git a/bot/utils/persist.py b/bot/utils/persist.py
deleted file mode 100644
index 1e178569..00000000
--- a/bot/utils/persist.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import sqlite3
-from pathlib import Path
-from shutil import copyfile
-
-from bot.exts import get_package_names
-
-DIRECTORY = Path("data") # directory that has a persistent volume mapped to it
-
-
-def make_persistent(file_path: Path) -> Path:
- """
- Copy datafile at the provided file_path to the persistent data directory.
-
- A persistent data file is needed by some features in order to not lose data
- after bot rebuilds.
-
- This function will ensure that a clean data file with default schema,
- structure or data is copied over to the persistent volume before returning
- the path to this new persistent version of the file.
-
- If the persistent file already exists, it won't be overwritten with the
- clean default file, just returning the Path instead to the existing file.
-
- Note: Avoid using the same file name as other features in the same seasons
- as otherwise only one datafile can be persistent and will be returned for
- both cases.
-
- Ensure that all open files are using explicit appropriate encoding to avoid
- encoding errors from diffent OS systems.
-
- Example Usage:
- >>> import json
- >>> template_datafile = Path("bot", "resources", "evergreen", "myfile.json")
- >>> path_to_persistent_file = make_persistent(template_datafile)
- >>> print(path_to_persistent_file)
- data/evergreen/myfile.json
- >>> with path_to_persistent_file.open("w+", encoding="utf8") as f:
- >>> data = json.load(f)
- """
- # ensure the persistent data directory exists
- DIRECTORY.mkdir(exist_ok=True)
-
- if not file_path.is_file():
- raise OSError(f"File not found at {file_path}.")
-
- # detect season in datafile path for assigning to subdirectory
- season = next((s for s in get_package_names() if s in file_path.parts), None)
-
- if season:
- # make sure subdirectory exists first
- subdirectory = Path(DIRECTORY, season)
- subdirectory.mkdir(exist_ok=True)
-
- persistent_path = Path(subdirectory, file_path.name)
-
- else:
- persistent_path = Path(DIRECTORY, file_path.name)
-
- # copy base/template datafile to persistent directory
- if not persistent_path.exists():
- copyfile(file_path, persistent_path)
-
- return persistent_path
-
-
-def sqlite(db_path: Path) -> sqlite3.Connection:
- """Copy sqlite file to the persistent data directory and return an open connection."""
- persistent_path = make_persistent(db_path)
- return sqlite3.connect(persistent_path)