aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar SebastiaanZ <[email protected]>2019-04-23 16:58:21 +0200
committerGravatar SebastiaanZ <[email protected]>2019-06-27 22:03:44 +0200
commit54314b77a0687fd3d93ef9663ecd0555cc0f2b7b (patch)
treec35f110efce64219e816e2165140cab7ce5584c5
parentMerge pull request #372 from python-discord/snekbox (diff)
Adding the WatchChannel ABC
-rw-r--r--bot/cogs/watchchannels/watchchannel.py341
1 files changed, 341 insertions, 0 deletions
diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py
new file mode 100644
index 000000000..856211139
--- /dev/null
+++ b/bot/cogs/watchchannels/watchchannel.py
@@ -0,0 +1,341 @@
+import asyncio
+import datetime
+import logging
+import re
+from abc import ABC, abstractmethod
+from collections import defaultdict, deque
+from typing import Optional
+
+import aiohttp
+import discord
+from discord import Color, Embed, Message, errors
+from discord.ext.commands import Bot, Context
+
+from bot.constants import (
+ BigBrother as BigBrotherConfig, Guild as GuildConfig
+)
+from bot.pagination import LinePaginator
+from bot.utils import messages
+from bot.utils.time import time_since
+
+log = logging.getLogger(__name__)
+
+URL_RE = re.compile(r"(https?://[^\s]+)")
+
+
+class WatchChannel(ABC):
+ """
+ Base class for WatchChannels
+
+ Abstracts the basic functionality for watchchannels in
+ a granular manner to allow for easy overwritting of
+ methods in the child class.
+ """
+
+ @abstractmethod
+ def __init__(self, bot: Bot) -> None:
+ """
+ abstractmethod for __init__ which should still be called with super().
+
+ Note: Some of the attributes below need to be overwritten in the
+ __init__ of the child after the super().__init__(*args, **kwargs)
+ call.
+ """
+
+ self.bot = bot
+
+ # These attributes need to be overwritten in the child class
+ self.destination = None # Channels.big_brother_logs
+ self.webhook_id = None # Some t.b.d. constant
+ self.api_endpoint = None # 'bot/infractions'
+ self.api_default_params = None # {'active': 'true', 'type': 'watch'}
+
+ # These attributes can be left as they are in the child class
+ self._consume_task = None
+ self.watched_users = defaultdict(dict)
+ self.message_queue = defaultdict(lambda: defaultdict(deque))
+ self.consumption_queue = None
+ self.retries = 5
+ self.retry_delay = 10
+ self.channel = None
+ self.webhook = None
+ self.message_history = [None, None, 0]
+
+ self._start = self.bot.loop.create_task(self.start_watchchannel())
+
+ @property
+ def consuming_messages(self) -> bool:
+ """Checks if a consumption task is currently running."""
+
+ if self._consume_task is None:
+ return False
+
+ if self._consume_task.done():
+ exc = self._consume_task.exception()
+ if exc:
+ log.exception(
+ f"{self.__class__.__name__} consume task has failed with:",
+ exc_info=exc
+ )
+ return False
+
+ return True
+
+ # @Cog.listener()
+ async def start_watchchannel(self) -> None:
+ """Retrieves watched users from the API."""
+
+ await self.bot.wait_until_ready()
+
+ if await self.initialize_channel() and await self.fetch_user_cache():
+ log.trace(f"Started the {self.__class__.__name__} WatchChannel")
+ else:
+ log.error(f"Failed to start the {self.__class__.__name__} WatchChannel")
+
+ async def initialize_channel(self) -> bool:
+ """
+ Checks if channel and webhook are set; if not, tries to initialize them.
+
+ Since the internal channel cache may not be available directly after `ready`,
+ this function will retry to get the channel a number of times. If both the
+ channel and webhook were initialized successfully. this function will return
+ `True`.
+ """
+
+ if self.channel is None:
+ for attempt in range(1, self.retries + 1):
+ self.channel = self.bot.get_channel(self.destination)
+
+ if self.channel is None:
+ log.error(f"Failed to get the {self.__class__.__name__} channel; cannot watch users")
+ if attempt < self.initialization_retries:
+ log.error(
+ f"Attempt {attempt}/{self.retries}; "
+ f"Retrying in {self.retry_delay} seconds...")
+ await asyncio.sleep(self.retry_delay)
+ else:
+ log.trace(f"Retrieved the TextChannel for {self.__class__.__name__}")
+ break
+ else:
+ log.error(f"Cannot find channel with id `{self.destination}`; cannot watch users")
+ return False
+
+ if self.webhook is None:
+ self.webhook = await self.bot.get_webhook_info(self.webhook_id) # This is `fetch_webhook` in current
+ if self.webhook is None:
+ log.error(f"Cannot find webhook with id `{self.webhook_id}`; cannot watch users")
+ return False
+ log.trace(f"Retrieved the webhook for {self.__class__.__name__}")
+
+ log.trace(f"WatchChannel for {self.__class__.__name__} is fully initialized")
+ return True
+
+ async def fetch_user_cache(self) -> bool:
+ """
+ Fetches watched users from the API and updates the watched user cache accordingly.
+
+ This function returns `True` if the update succeeded.
+ """
+
+ try:
+ data = await self.bot.api_client.get(
+ self.api_endpoint,
+ params=self.api_default_params
+ )
+ except aiohttp.ClientResponseError as e:
+ log.exception(
+ f"Failed to fetch {self.__class__.__name__} watched users from API",
+ exc_info=e
+ )
+ return False
+
+ self.watched_users = defaultdict(dict)
+
+ for entry in data:
+ user_id = entry.pop('user')
+ self.watched_users[user_id] = entry
+
+ return True
+
+ # @Cog.listener()
+ async def on_message(self, msg: Message):
+ """Queues up messages sent by watched users."""
+
+ if msg.author.id in self.watched_users:
+ if not self.consuming_messages:
+ self._consume_task = self.bot.loop.create_task(self.consume_messages())
+
+ log.trace(f"Received message: {msg.content} ({len(msg.attachments)} attachments)")
+ self.message_queue[msg.author.id][msg.channel.id].append(msg)
+
+ async def consume_messages(self, delay_consumption: bool = True):
+ """Consumes the message queues to log watched users' messages."""
+
+ if delay_consumption:
+ log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue")
+ await asyncio.sleep(1)
+
+ log.trace(f"{self.__class__.__name__} started consuming the message queue")
+
+ # Prevent losing a partly processed consumption queue after Task failure
+ if not self.consumption_queue:
+ self.consumption_queue = self.message_queue.copy()
+ self.message_queue.clear()
+
+ for user_channel_queues in self.consumption_queue.values():
+ for channel_queue in user_channel_queues.values():
+ while channel_queue:
+ msg = channel_queue.popleft()
+
+ log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)")
+ await self.relay_message(msg)
+
+ self.consumption_queue.clear()
+
+ if self.message_queue:
+ log.trace("Channel queue not empty: Continuing consuming queues")
+ self._consume_task = self.bot.loop.create_task(
+ self.consume_messages(delay_consumption=False)
+ )
+ else:
+ log.trace("Done consuming messages.")
+
+ async def webhook_send(
+ self, content: Optional[str] = None, username: Optional[str] = None,
+ avatar_url: Optional[str] = None, embed: Optional[Embed] = None,
+ ):
+ """Sends a message to the webhook with the specified kwargs."""
+
+ try:
+ await self.webhook.send(content=content, username=username, avatar_url=avatar_url, embed=embed)
+ except discord.HTTPException as exc:
+ log.exception(
+ f"Failed to send message to {self.__class__.__name__} webhook",
+ exc_info=exc
+ )
+
+ async def relay_message(self, msg: Message) -> None:
+ """Relays the message to the relevant WatchChannel"""
+
+ last_author, last_channel, count = self.message_history
+ limit = BigBrotherConfig.header_message_limit
+
+ if msg.author.id != last_author or msg.channel.id != last_channel or count >= limit:
+ self.message_history = [msg.author.id, msg.channel.id, 0]
+
+ await self.send_header(msg)
+
+ cleaned_content = msg.clean_content
+
+ if cleaned_content:
+ media_urls = {embed.url for embed in msg.embeds if embed.type in ("image", "video")}
+ for url in URL_RE.findall(cleaned_content):
+ if url not in media_urls:
+ cleaned_content = cleaned_content.replace(url, f"`{url}`")
+ await self.webhook_send(
+ cleaned_content,
+ username=msg.author.display_name,
+ avatar_url=msg.author.avatar_url
+ )
+
+ if msg.attachments:
+ try:
+ await messages.send_attachments(msg, self.webhook)
+ except (errors.Forbidden, errors.NotFound):
+ e = Embed(
+ description=":x: **This message contained an attachment, but it could not be retrieved**",
+ color=Color.red()
+ )
+ await self.webhook_send(
+ embed=e,
+ username=msg.author.display_name,
+ avatar_url=msg.author.avatar_url
+ )
+ except discord.HTTPException as exc:
+ log.exception(
+ f"Failed to send an attachment to {self.__class__.__name__} webhook",
+ exc_info=exc
+ )
+
+ self.message_history[2] += 1
+
+ async def send_header(self, msg):
+ """Sends an header embed to the WatchChannel"""
+
+ user_id = msg.author.id
+
+ guild = self.bot.get_guild(GuildConfig.id)
+ actor = guild.get_member(self.watched_users[user_id]['actor'])
+ actor = actor.display_name if actor else self.watched_users[user_id]['actor']
+
+ inserted_at = self.watched_users[user_id]['inserted_at']
+ date_time = datetime.datetime.strptime(
+ inserted_at,
+ "%Y-%m-%dT%H:%M:%S.%fZ"
+ ).replace(tzinfo=None)
+ time_delta = time_since(date_time, precision="minutes", max_units=1)
+
+ reason = self.watched_users[user_id]['reason']
+
+ embed = Embed(description=(
+ f"{msg.author.mention} in [#{msg.channel.name}]({msg.jump_url})\n"
+ ))
+ embed.set_footer(text=(
+ f"Added {time_delta} by {actor} | "
+ f"Reason: {reason}"
+ ))
+ await self.webhook_send(
+ embed=embed,
+ username=msg.author.display_name,
+ avatar_url=msg.author.avatar_url
+ )
+
+ async def list_watched_users(self, ctx: Context, update_cache: bool = False) -> None:
+ """
+ Gives an overview of the watched user list for this channel.
+
+ The optional kwarg `update_cache` specifies whether the cache should
+ be refreshed by polling the API.
+ """
+
+ if update_cache:
+ if not await self.fetch_user_cache():
+ e = Embed(
+ description=f":x: **Failed to update {self.__class__.__name__} user cache**",
+ color=Color.red()
+ )
+ ctx.send(embed=e)
+
+ lines = []
+ for user_id, user_data in self.watched_users.items():
+ inserted_at = user_data['inserted_at']
+ date_time = datetime.datetime.strptime(
+ inserted_at,
+ "%Y-%m-%dT%H:%M:%S.%fZ"
+ ).replace(tzinfo=None)
+ time_delta = time_since(date_time, precision="minutes", max_units=1)
+ lines.append(f"• <@{user_id}> (added {time_delta})")
+
+ await LinePaginator.paginate(
+ lines or ("There's nothing here yet.",),
+ ctx,
+ Embed(
+ title=f"{self.__class__.__name__} watched users ({'updated' if update_cache else 'cached'})",
+ color=Color.blue()
+ ),
+ empty=False
+ )
+
+ def cog_unload(self):
+ """Takes care of unloading the cog and cancelling the consumption task."""
+
+ log.trace(f"Unloading {self.__class__._name__} cog")
+ if not self._consume_task.done():
+ self._consume_task.cancel()
+ try:
+ self._consume_task.result()
+ except asyncio.CancelledError as e:
+ log.exception(
+ f"The {self.__class__._name__} consume task was cancelled. Messages may be lost.",
+ exc_info=e
+ )