aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-07-17 10:08:05 -0700
committerGravatar MarkKoz <[email protected]>2020-07-31 22:58:05 -0700
commitda33c330a02f2ff10838d0827e8c26a045729449 (patch)
treeceea59adb489d976d5598fa1a02192c22131cf7d
parentDecorators: more accurate return type for checks (diff)
Decorators: clean up imports
-rw-r--r--bot/decorators.py50
-rw-r--r--tests/bot/test_decorators.py4
2 files changed, 26 insertions, 28 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index b9182f664..d9e5e3a83 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,15 +1,13 @@
+import asyncio
import logging
import random
-from asyncio import Lock, create_task, sleep
+import typing as t
from contextlib import suppress
from functools import wraps
-from typing import Callable, Container, Optional, Union
from weakref import WeakValueDictionary
-from discord import Colour, Embed, Member
-from discord.errors import NotFound
-from discord.ext import commands
-from discord.ext.commands import Cog, Context
+from discord import Colour, Embed, Member, NotFound
+from discord.ext.commands import Cog, Command, Context, check
from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check
@@ -19,12 +17,12 @@ log = logging.getLogger(__name__)
def in_whitelist(
*,
- channels: Container[int] = (),
- categories: Container[int] = (),
- roles: Container[int] = (),
- redirect: Optional[int] = Channels.bot_commands,
+ channels: t.Container[int] = (),
+ categories: t.Container[int] = (),
+ roles: t.Container[int] = (),
+ redirect: t.Optional[int] = Channels.bot_commands,
fail_silently: bool = False,
-) -> commands.Command:
+) -> Command:
"""
Check if a command was issued in a whitelisted context.
@@ -42,25 +40,25 @@ def in_whitelist(
"""Check if command was issued in a whitelisted context."""
return in_whitelist_check(ctx, channels, categories, roles, redirect, fail_silently)
- return commands.check(predicate)
+ return check(predicate)
-def with_role(*role_ids: int) -> commands.Command:
+def with_role(*role_ids: int) -> Command:
"""Returns True if the user has any one of the roles in role_ids."""
async def predicate(ctx: Context) -> bool:
"""With role checker predicate."""
return with_role_check(ctx, *role_ids)
- return commands.check(predicate)
+ return check(predicate)
-def without_role(*role_ids: int) -> commands.Command:
+def without_role(*role_ids: int) -> Command:
"""Returns True if the user does not have any of the roles in role_ids."""
async def predicate(ctx: Context) -> bool:
return without_role_check(ctx, *role_ids)
- return commands.check(predicate)
+ return check(predicate)
-def locked() -> Callable:
+def locked() -> t.Callable:
"""
Allows the user to only run one instance of the decorated command at a time.
@@ -68,12 +66,12 @@ def locked() -> Callable:
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def wrap(func: t.Callable) -> t.Callable:
func.__locks = WeakValueDictionary()
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- lock = func.__locks.setdefault(ctx.author.id, Lock())
+ lock = func.__locks.setdefault(ctx.author.id, asyncio.Lock())
if lock.locked():
embed = Embed()
embed.colour = Colour.red()
@@ -86,13 +84,13 @@ def locked() -> Callable:
await ctx.send(embed=embed)
return
- async with func.__locks.setdefault(ctx.author.id, Lock()):
+ async with func.__locks.setdefault(ctx.author.id, asyncio.Lock()):
await func(self, ctx, *args, **kwargs)
return inner
return wrap
-def redirect_output(destination_channel: int, bypass_roles: Container[int] = None) -> Callable:
+def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable:
"""
Changes the channel in the context of the command to redirect the output to a certain channel.
@@ -100,7 +98,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def wrap(func: t.Callable) -> t.Callable:
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
if ctx.channel.id == destination_channel:
@@ -119,14 +117,14 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
log.trace(f"Redirecting output of {ctx.author}'s command '{ctx.command.name}' to {redirect_channel.name}")
ctx.channel = redirect_channel
await ctx.channel.send(f"Here's the output of your command, {ctx.author.mention}")
- create_task(func(self, ctx, *args, **kwargs))
+ asyncio.create_task(func(self, ctx, *args, **kwargs))
message = await old_channel.send(
f"Hey, {ctx.author.mention}, you can find the output of your command here: "
f"{redirect_channel.mention}"
)
if RedirectOutput.delete_invocation:
- await sleep(RedirectOutput.delete_delay)
+ await asyncio.sleep(RedirectOutput.delete_delay)
with suppress(NotFound):
await message.delete()
@@ -140,7 +138,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
return wrap
-def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
+def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
"""
Ensure the highest role of the invoking member is greater than that of the target member.
@@ -152,7 +150,7 @@ def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def wrap(func: t.Callable) -> t.Callable:
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
try:
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
index 3d450caa0..22e93c1c4 100644
--- a/tests/bot/test_decorators.py
+++ b/tests/bot/test_decorators.py
@@ -67,7 +67,7 @@ class InWhitelistTests(unittest.TestCase):
for test_case in test_cases:
# patch `commands.check` with a no-op lambda that just returns the predicate passed to it
# so we can test the predicate that was generated from the specified kwargs.
- with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ with unittest.mock.patch("bot.decorators.check", new=lambda predicate: predicate):
predicate = in_whitelist(**test_case.kwargs)
with self.subTest(test_description=test_case.description):
@@ -139,7 +139,7 @@ class InWhitelistTests(unittest.TestCase):
# patch `commands.check` with a no-op lambda that just returns the predicate passed to it
# so we can test the predicate that was generated from the specified kwargs.
- with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ with unittest.mock.patch("bot.decorators.check", new=lambda predicate: predicate):
predicate = in_whitelist(**test_case.kwargs)
with self.subTest(test_description=test_case.description):