diff options
| author | 2020-07-17 10:08:05 -0700 | |
|---|---|---|
| committer | 2020-07-31 22:58:05 -0700 | |
| commit | da33c330a02f2ff10838d0827e8c26a045729449 (patch) | |
| tree | ceea59adb489d976d5598fa1a02192c22131cf7d | |
| parent | Decorators: more accurate return type for checks (diff) | |
Decorators: clean up imports
| -rw-r--r-- | bot/decorators.py | 50 | ||||
| -rw-r--r-- | tests/bot/test_decorators.py | 4 |
2 files changed, 26 insertions, 28 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index b9182f664..d9e5e3a83 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,15 +1,13 @@ +import asyncio import logging import random -from asyncio import Lock, create_task, sleep +import typing as t from contextlib import suppress from functools import wraps -from typing import Callable, Container, Optional, Union from weakref import WeakValueDictionary -from discord import Colour, Embed, Member -from discord.errors import NotFound -from discord.ext import commands -from discord.ext.commands import Cog, Context +from discord import Colour, Embed, Member, NotFound +from discord.ext.commands import Cog, Command, Context, check from bot.constants import Channels, ERROR_REPLIES, RedirectOutput from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check @@ -19,12 +17,12 @@ log = logging.getLogger(__name__) def in_whitelist( *, - channels: Container[int] = (), - categories: Container[int] = (), - roles: Container[int] = (), - redirect: Optional[int] = Channels.bot_commands, + channels: t.Container[int] = (), + categories: t.Container[int] = (), + roles: t.Container[int] = (), + redirect: t.Optional[int] = Channels.bot_commands, fail_silently: bool = False, -) -> commands.Command: +) -> Command: """ Check if a command was issued in a whitelisted context. @@ -42,25 +40,25 @@ def in_whitelist( """Check if command was issued in a whitelisted context.""" return in_whitelist_check(ctx, channels, categories, roles, redirect, fail_silently) - return commands.check(predicate) + return check(predicate) -def with_role(*role_ids: int) -> commands.Command: +def with_role(*role_ids: int) -> Command: """Returns True if the user has any one of the roles in role_ids.""" async def predicate(ctx: Context) -> bool: """With role checker predicate.""" return with_role_check(ctx, *role_ids) - return commands.check(predicate) + return check(predicate) -def without_role(*role_ids: int) -> commands.Command: +def without_role(*role_ids: int) -> Command: """Returns True if the user does not have any of the roles in role_ids.""" async def predicate(ctx: Context) -> bool: return without_role_check(ctx, *role_ids) - return commands.check(predicate) + return check(predicate) -def locked() -> Callable: +def locked() -> t.Callable: """ Allows the user to only run one instance of the decorated command at a time. @@ -68,12 +66,12 @@ def locked() -> Callable: This decorator must go before (below) the `command` decorator. """ - def wrap(func: Callable) -> Callable: + def wrap(func: t.Callable) -> t.Callable: func.__locks = WeakValueDictionary() @wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: - lock = func.__locks.setdefault(ctx.author.id, Lock()) + lock = func.__locks.setdefault(ctx.author.id, asyncio.Lock()) if lock.locked(): embed = Embed() embed.colour = Colour.red() @@ -86,13 +84,13 @@ def locked() -> Callable: await ctx.send(embed=embed) return - async with func.__locks.setdefault(ctx.author.id, Lock()): + async with func.__locks.setdefault(ctx.author.id, asyncio.Lock()): await func(self, ctx, *args, **kwargs) return inner return wrap -def redirect_output(destination_channel: int, bypass_roles: Container[int] = None) -> Callable: +def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable: """ Changes the channel in the context of the command to redirect the output to a certain channel. @@ -100,7 +98,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non This decorator must go before (below) the `command` decorator. """ - def wrap(func: Callable) -> Callable: + def wrap(func: t.Callable) -> t.Callable: @wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: if ctx.channel.id == destination_channel: @@ -119,14 +117,14 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non log.trace(f"Redirecting output of {ctx.author}'s command '{ctx.command.name}' to {redirect_channel.name}") ctx.channel = redirect_channel await ctx.channel.send(f"Here's the output of your command, {ctx.author.mention}") - create_task(func(self, ctx, *args, **kwargs)) + asyncio.create_task(func(self, ctx, *args, **kwargs)) message = await old_channel.send( f"Hey, {ctx.author.mention}, you can find the output of your command here: " f"{redirect_channel.mention}" ) if RedirectOutput.delete_invocation: - await sleep(RedirectOutput.delete_delay) + await asyncio.sleep(RedirectOutput.delete_delay) with suppress(NotFound): await message.delete() @@ -140,7 +138,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non return wrap -def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable: +def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable: """ Ensure the highest role of the invoking member is greater than that of the target member. @@ -152,7 +150,7 @@ def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable: This decorator must go before (below) the `command` decorator. """ - def wrap(func: Callable) -> Callable: + def wrap(func: t.Callable) -> t.Callable: @wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: try: diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index 3d450caa0..22e93c1c4 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -67,7 +67,7 @@ class InWhitelistTests(unittest.TestCase): for test_case in test_cases: # patch `commands.check` with a no-op lambda that just returns the predicate passed to it # so we can test the predicate that was generated from the specified kwargs. - with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): + with unittest.mock.patch("bot.decorators.check", new=lambda predicate: predicate): predicate = in_whitelist(**test_case.kwargs) with self.subTest(test_description=test_case.description): @@ -139,7 +139,7 @@ class InWhitelistTests(unittest.TestCase): # patch `commands.check` with a no-op lambda that just returns the predicate passed to it # so we can test the predicate that was generated from the specified kwargs. - with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): + with unittest.mock.patch("bot.decorators.check", new=lambda predicate: predicate): predicate = in_whitelist(**test_case.kwargs) with self.subTest(test_description=test_case.description): |