diff options
| author | 2019-09-10 21:56:25 -0400 | |
|---|---|---|
| committer | 2019-09-10 21:56:25 -0400 | |
| commit | 4f2ca226fe61b62e4e560805f3adbc2abd3d5c16 (patch) | |
| tree | 9d2fa3c6653919dc6e2f0f450bd0488d77dcd9ce | |
| parent | Docstring linting chunk 6 (diff) | |
Docstring linting chunk 7
Whew
| -rw-r--r-- | bot/interpreter.py | 14 | ||||
| -rw-r--r-- | bot/rules/attachments.py | 2 | ||||
| -rw-r--r-- | bot/rules/burst.py | 2 | ||||
| -rw-r--r-- | bot/rules/burst_shared.py | 2 | ||||
| -rw-r--r-- | bot/rules/chars.py | 2 | ||||
| -rw-r--r-- | bot/rules/discord_emojis.py | 2 | ||||
| -rw-r--r-- | bot/rules/duplicates.py | 2 | ||||
| -rw-r--r-- | bot/rules/links.py | 2 | ||||
| -rw-r--r-- | bot/rules/mentions.py | 2 | ||||
| -rw-r--r-- | bot/rules/newlines.py | 2 | ||||
| -rw-r--r-- | bot/rules/role_mentions.py | 2 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 37 | ||||
| -rw-r--r-- | bot/utils/checks.py | 18 | ||||
| -rw-r--r-- | bot/utils/messages.py | 63 | ||||
| -rw-r--r-- | bot/utils/moderation.py | 8 | ||||
| -rw-r--r-- | bot/utils/scheduling.py | 43 | ||||
| -rw-r--r-- | bot/utils/time.py | 48 | 
17 files changed, 96 insertions, 155 deletions
| diff --git a/bot/interpreter.py b/bot/interpreter.py index 06343db39..6ea49e026 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -1,5 +1,8 @@  from code import InteractiveInterpreter  from io import StringIO +from typing import Any + +from discord.ext.commands import Bot, Context  CODE_TEMPLATE = """  async def _func(): @@ -8,13 +11,20 @@ async def _func():  class Interpreter(InteractiveInterpreter): +    """ +    Subclass InteractiveInterpreter to specify custom run functionality. + +    Helper class for internal eval +    """ +      write_callable = None -    def __init__(self, bot): +    def __init__(self, bot: Bot):          _locals = {"bot": bot}          super().__init__(_locals) -    async def run(self, code, ctx, io, *args, **kwargs): +    async def run(self, code: str, ctx: Context, io: StringIO, *args, **kwargs) -> Any: +        """Execute the provided source code as the bot & return the output."""          self.locals["_rvalue"] = []          self.locals["ctx"] = ctx          self.locals["print"] = lambda x: io.write(f"{x}\n") diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 47b927101..e71b96183 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply attachment spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/burst.py b/bot/rules/burst.py index 80c79be60..8859f8d51 100644 --- a/bot/rules/burst.py +++ b/bot/rules/burst.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply burst message spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/burst_shared.py b/bot/rules/burst_shared.py index 2cb7b5200..b8c73ecb4 100644 --- a/bot/rules/burst_shared.py +++ b/bot/rules/burst_shared.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply burst repeated message spam filter."""      total_recent = len(recent_messages)      if total_recent > config['max']: diff --git a/bot/rules/chars.py b/bot/rules/chars.py index d05e3cd83..ae8ac93ef 100644 --- a/bot/rules/chars.py +++ b/bot/rules/chars.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply excessive character count detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/discord_emojis.py b/bot/rules/discord_emojis.py index e4f957ddb..87d129f37 100644 --- a/bot/rules/discord_emojis.py +++ b/bot/rules/discord_emojis.py @@ -14,7 +14,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply emoji spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/duplicates.py b/bot/rules/duplicates.py index 763fc9983..8648fd955 100644 --- a/bot/rules/duplicates.py +++ b/bot/rules/duplicates.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply duplicate message spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/links.py b/bot/rules/links.py index fa4043fcb..924f092b1 100644 --- a/bot/rules/links.py +++ b/bot/rules/links.py @@ -14,7 +14,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply link spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 45c47b6ba..3372fd1e1 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply user mention spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/newlines.py b/bot/rules/newlines.py index fdad6ffd3..d04f8c9ed 100644 --- a/bot/rules/newlines.py +++ b/bot/rules/newlines.py @@ -11,7 +11,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply newline spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/rules/role_mentions.py b/bot/rules/role_mentions.py index 2177a73b5..a8b819d0d 100644 --- a/bot/rules/role_mentions.py +++ b/bot/rules/role_mentions.py @@ -10,7 +10,7 @@ async def apply(      recent_messages: List[Message],      config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - +    """Apply role mention spam detection filter."""      relevant_messages = tuple(          msg          for msg in recent_messages diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 4c99d50e8..141559657 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,3 +1,5 @@ +from typing import Any, Generator, Hashable, Iterable +  class CaseInsensitiveDict(dict):      """ @@ -7,50 +9,59 @@ class CaseInsensitiveDict(dict):      """      @classmethod -    def _k(cls, key): +    def _k(cls, key: Hashable) -> Any: +        """Return lowered key if a string-like is passed, otherwise pass key straight through."""          return key.lower() if isinstance(key, str) else key      def __init__(self, *args, **kwargs):          super(CaseInsensitiveDict, self).__init__(*args, **kwargs)          self._convert_keys() -    def __getitem__(self, key): +    def __getitem__(self, key: Hashable) -> Any: +        """Case insensitive __setitem__."""          return super(CaseInsensitiveDict, self).__getitem__(self.__class__._k(key)) -    def __setitem__(self, key, value): +    def __setitem__(self, key: Hashable, value: Any): +        """Case insensitive __setitem__."""          super(CaseInsensitiveDict, self).__setitem__(self.__class__._k(key), value) -    def __delitem__(self, key): +    def __delitem__(self, key: Hashable) -> Any: +        """Case insensitive __delitem__."""          return super(CaseInsensitiveDict, self).__delitem__(self.__class__._k(key)) -    def __contains__(self, key): +    def __contains__(self, key: Hashable) -> bool: +        """Case insensitive __contains__."""          return super(CaseInsensitiveDict, self).__contains__(self.__class__._k(key)) -    def pop(self, key, *args, **kwargs): +    def pop(self, key: Hashable, *args, **kwargs) -> Any: +        """Case insensitive pop."""          return super(CaseInsensitiveDict, self).pop(self.__class__._k(key), *args, **kwargs) -    def get(self, key, *args, **kwargs): +    def get(self, key: Hashable, *args, **kwargs) -> Any: +        """Case insensitive get."""          return super(CaseInsensitiveDict, self).get(self.__class__._k(key), *args, **kwargs) -    def setdefault(self, key, *args, **kwargs): +    def setdefault(self, key: Hashable, *args, **kwargs) -> Any: +        """Case insensitive setdefault."""          return super(CaseInsensitiveDict, self).setdefault(self.__class__._k(key), *args, **kwargs) -    def update(self, E=None, **F): +    def update(self, E: Any = None, **F) -> None: +        """Case insensitive update."""          super(CaseInsensitiveDict, self).update(self.__class__(E))          super(CaseInsensitiveDict, self).update(self.__class__(**F)) -    def _convert_keys(self): +    def _convert_keys(self) -> None: +        """Helper method to lowercase all existing string-like keys."""          for k in list(self.keys()):              v = super(CaseInsensitiveDict, self).pop(k)              self.__setitem__(k, v) -def chunks(iterable, size): +def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]:      """ -    Generator that allows you to iterate over any indexable collection in `size`-length chunks +    Generator that allows you to iterate over any indexable collection in `size`-length chunks.      Found: https://stackoverflow.com/a/312464/4022104      """ -      for i in range(0, len(iterable), size):          yield iterable[i:i + size] diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 37dc657f7..1f4c1031b 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -6,11 +6,7 @@ log = logging.getLogger(__name__)  def with_role_check(ctx: Context, *role_ids: int) -> bool: -    """ -    Returns True if the user has any one -    of the roles in role_ids. -    """ - +    """Returns True if the user has any one of the roles in role_ids."""      if not ctx.guild:  # Return False in a DM          log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "                    "This command is restricted by the with_role decorator. Rejecting request.") @@ -27,11 +23,7 @@ def with_role_check(ctx: Context, *role_ids: int) -> bool:  def without_role_check(ctx: Context, *role_ids: int) -> bool: -    """ -    Returns True if the user does not have any -    of the roles in role_ids. -    """ - +    """Returns True if the user does not have any of the roles in role_ids."""      if not ctx.guild:  # Return False in a DM          log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "                    "This command is restricted by the without_role decorator. Rejecting request.") @@ -45,11 +37,7 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:  def in_channel_check(ctx: Context, channel_id: int) -> bool: -    """ -    Checks if the command was executed -    inside of the specified channel. -    """ - +    """Checks if the command was executed inside of the specified channel."""      check = ctx.channel.id == channel_id      log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "                f"The result of the in_channel check was {check}.") diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 94a8b36ed..5058d42fc 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -1,9 +1,9 @@  import asyncio  import contextlib  from io import BytesIO -from typing import Sequence, Union +from typing import Optional, Sequence, Union -from discord import Embed, File, Message, TextChannel, Webhook +from discord import Client, Embed, File, Member, Message, Reaction, TextChannel, Webhook  from discord.abc import Snowflake  from discord.errors import HTTPException @@ -17,42 +17,18 @@ async def wait_for_deletion(      user_ids: Sequence[Snowflake],      deletion_emojis: Sequence[str] = (Emojis.cross_mark,),      timeout: float = 60 * 5, -    attach_emojis=True, -    client=None -): -    """ -    Waits for up to `timeout` seconds for a reaction by -    any of the specified `user_ids` to delete the message. - -    Args: -        message (Message): -            The message that should be monitored for reactions -            and possibly deleted. Must be a message sent on a -            guild since access to the bot instance is required. - -        user_ids (Sequence[Snowflake]): -            A sequence of users that are allowed to delete -            this message. - -    Kwargs: -        deletion_emojis (Sequence[str]): -            A sequence of emojis that are considered deletion -            emojis. - -        timeout (float): -            A positive float denoting the maximum amount of -            time to wait for a deletion reaction. - -        attach_emojis (bool): -            Whether to attach the given `deletion_emojis` -            to the message in the given `context` - -        client (Optional[discord.Client]): -            The client instance handling the original command. -            If not given, will take the client from the guild -            of the message. +    attach_emojis: bool = True, +    client: Optional[Client] = None +) -> None:      """ +    Waits for up to `timeout` seconds for a reaction by any of the specified `user_ids` to delete the message. + +    An `attach_emojis` bool may be specified to determine whether to attach the given +    `deletion_emojis` to the message in the given `context` +    A `client` instance may be optionally specified, otherwise client will be taken from the +    guild of the message. +    """      if message.guild is None and client is None:          raise ValueError("Message must be sent on a guild") @@ -62,7 +38,8 @@ async def wait_for_deletion(          for emoji in deletion_emojis:              await message.add_reaction(emoji) -    def check(reaction, user): +    def check(reaction: Reaction, user: Member) -> bool: +        """Check that the deletion emoji is reacted by the approprite user."""          return (              reaction.message.id == message.id              and reaction.emoji in deletion_emojis @@ -70,25 +47,17 @@ async def wait_for_deletion(          )      with contextlib.suppress(asyncio.TimeoutError): -        await bot.wait_for( -            'reaction_add', -            check=check, -            timeout=timeout -        ) +        await bot.wait_for('reaction_add', check=check, timeout=timeout)          await message.delete() -async def send_attachments(message: Message, destination: Union[TextChannel, Webhook]): +async def send_attachments(message: Message, destination: Union[TextChannel, Webhook]) -> None:      """      Re-uploads each attachment in a message to the given channel or webhook.      Each attachment is sent as a separate message to more easily comply with the 8 MiB request size limit.      If attachments are too large, they are instead grouped into a single embed which links to them. - -    :param message: the message whose attachments to re-upload -    :param destination: the channel in which to re-upload the attachments      """ -      large = []      for attachment in message.attachments:          try: diff --git a/bot/utils/moderation.py b/bot/utils/moderation.py index c1eb98dd6..9ea2db07c 100644 --- a/bot/utils/moderation.py +++ b/bot/utils/moderation.py @@ -21,8 +21,8 @@ async def post_infraction(      expires_at: datetime = None,      hidden: bool = False,      active: bool = True, -): - +) -> Union[dict, None]: +    """Post infraction to the API."""      payload = {          "actor": ctx.message.author.id,          "hidden": hidden, @@ -35,9 +35,7 @@ async def post_infraction(          payload['expires_at'] = expires_at.isoformat()      try: -        response = await ctx.bot.api_client.post( -            'bot/infractions', json=payload -        ) +        response = await ctx.bot.api_client.post('bot/infractions', json=payload)      except ClientError:          log.exception("There was an error adding an infraction.")          await ctx.send(":x: There was an error adding the infraction.") diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index ded6401b0..9975b04e0 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -2,12 +2,13 @@ import asyncio  import contextlib  import logging  from abc import ABC, abstractmethod -from typing import Dict +from typing import Coroutine, Dict, Union  log = logging.getLogger(__name__)  class Scheduler(ABC): +    """Task scheduler."""      def __init__(self): @@ -15,24 +16,23 @@ class Scheduler(ABC):          self.scheduled_tasks: Dict[str, asyncio.Task] = {}      @abstractmethod -    async def _scheduled_task(self, task_object: dict): +    async def _scheduled_task(self, task_object: dict) -> None:          """ -        A coroutine which handles the scheduling. This is added to the scheduled tasks, -        and should wait the task duration, execute the desired code, and clean up the task. +        A coroutine which handles the scheduling. + +        This is added to the scheduled tasks, and should wait the task duration, execute the desired +        code, then clean up the task. +          For example, in Reminders this will wait for the reminder duration, send the reminder,          then make a site API request to delete the reminder from the database. - -        :param task_object:          """ -    def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict): +    def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict) -> None:          """          Schedules a task. -        :param loop: the asyncio event loop -        :param task_id: the ID of the task. -        :param task_data: the data of the task, passed to `Scheduler._scheduled_expiration`. -        """ +        `task_data` is passed to `Scheduler._scheduled_expiration` +        """          if task_id in self.scheduled_tasks:              return @@ -40,12 +40,8 @@ class Scheduler(ABC):          self.scheduled_tasks[task_id] = task -    def cancel_task(self, task_id: str): -        """ -        Un-schedules a task. -        :param task_id: the ID of the infraction in question -        """ - +    def cancel_task(self, task_id: str) -> None: +        """Un-schedules a task."""          task = self.scheduled_tasks.get(task_id)          if task is None: @@ -57,14 +53,8 @@ class Scheduler(ABC):          del self.scheduled_tasks[task_id] -def create_task(loop: asyncio.AbstractEventLoop, coro_or_future): -    """ -    Creates an asyncio.Task object from a coroutine or future object. - -    :param loop: the asyncio event loop. -    :param coro_or_future: the coroutine or future object to be scheduled. -    """ - +def create_task(loop: asyncio.AbstractEventLoop, coro_or_future: Union[Coroutine, asyncio.Future]) -> asyncio.Task: +    """Creates an asyncio.Task object from a coroutine or future object."""      task: asyncio.Task = asyncio.ensure_future(coro_or_future, loop=loop)      # Silently ignore exceptions in a callback (handles the CancelledError nonsense) @@ -72,6 +62,7 @@ def create_task(loop: asyncio.AbstractEventLoop, coro_or_future):      return task -def _silent_exception(future): +def _silent_exception(future: asyncio.Future) -> None: +    """Suppress future exception."""      with contextlib.suppress(Exception):          future.exception() diff --git a/bot/utils/time.py b/bot/utils/time.py index a330c9cd8..fe1c4e3ee 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -6,10 +6,9 @@ from dateutil.relativedelta import relativedelta  RFC1123_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" -def _stringify_time_unit(value: int, unit: str): +def _stringify_time_unit(value: int, unit: str) -> str:      """ -    Returns a string to represent a value and time unit, -    ensuring that it uses the right plural form of the unit. +    Returns a string to represent a value and time unit, ensuring that it uses the right plural form of the unit.      >>> _stringify_time_unit(1, "seconds")      "1 second" @@ -18,7 +17,6 @@ def _stringify_time_unit(value: int, unit: str):      >>> _stringify_time_unit(0, "minutes")      "less than a minute"      """ -      if value == 1:          return f"{value} {unit[:-1]}"      elif value == 0: @@ -27,18 +25,8 @@ def _stringify_time_unit(value: int, unit: str):          return f"{value} {unit}" -def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6): -    """ -    Returns a human-readable version of the relativedelta. - -    :param delta:      A dateutil.relativedelta.relativedelta object -    :param precision:  The smallest unit that should be included. -    :param max_units:  The maximum number of time-units to return. - -    :return:           A string like `4 days, 12 hours and 1 second`, -                       `1 minute`, or `less than a minute`. -    """ - +def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str: +    """Returns a human-readable version of the relativedelta."""      units = (          ("years", delta.years),          ("months", delta.months), @@ -73,19 +61,8 @@ def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units:      return humanized -def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max_units: int = 6): -    """ -    Takes a datetime and returns a human-readable string that -    describes how long ago that datetime was. - -    :param past_datetime:  A datetime.datetime object -    :param precision:      The smallest unit that should be included. -    :param max_units:      The maximum number of time-units to return. - -    :return:               A string like `4 days, 12 hours and 1 second ago`, -                           `1 minute ago`, or `less than a minute ago`. -    """ - +def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max_units: int = 6) -> str: +    """Takes a datetime and returns a human-readable string that describes how long ago that datetime was."""      now = datetime.datetime.utcnow()      delta = abs(relativedelta(now, past_datetime)) @@ -94,20 +71,17 @@ def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max      return f"{humanized} ago" -def parse_rfc1123(time_str): +def parse_rfc1123(time_str: str) -> datetime.datetime: +    """Parse RFC1123 time string into datetime."""      return datetime.datetime.strptime(time_str, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc)  # Hey, this could actually be used in the off_topic_names and reddit cogs :) -async def wait_until(time: datetime.datetime): -    """ -    Wait until a given time. - -    :param time: A datetime.datetime object to wait until. -    """ - +async def wait_until(time: datetime.datetime) -> None: +    """Wait until a given time."""      delay = time - datetime.datetime.utcnow()      delay_seconds = delay.total_seconds() +    # Incorporate a small delay so we don't rapid-fire the event due to time precision errors      if delay_seconds > 1.0:          await asyncio.sleep(delay_seconds) | 
