diff options
| -rw-r--r-- | Pipfile.lock | 80 | ||||
| -rw-r--r-- | bot/__main__.py | 21 | ||||
| -rw-r--r-- | bot/bot.py | 49 | ||||
| -rw-r--r-- | bot/constants.py | 25 | ||||
| -rw-r--r-- | bot/exts/filters/antispam.py | 3 | ||||
| -rw-r--r-- | bot/exts/filters/filtering.py | 40 | ||||
| -rw-r--r-- | bot/exts/fun/duck_pond.py | 101 | ||||
| -rw-r--r-- | bot/exts/help_channels.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/dm_relay.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/verification.py | 2 | ||||
| -rw-r--r-- | bot/rules/__init__.py | 1 | ||||
| -rw-r--r-- | bot/rules/everyone_ping.py | 41 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 3 | ||||
| -rw-r--r-- | bot/utils/redis_cache.py | 414 | ||||
| -rw-r--r-- | config-default.yml | 84 | ||||
| -rw-r--r-- | tests/bot/exts/fun/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/fun/test_duck_pond.py | 548 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 265 | ||||
| -rw-r--r-- | tests/helpers.py | 6 |
19 files changed, 224 insertions, 1463 deletions
diff --git a/Pipfile.lock b/Pipfile.lock index acda79f11..f75852081 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -135,44 +135,36 @@ }, "cffi": { "hashes": [ - "sha256:005f2bfe11b6745d726dbb07ace4d53f057de66e336ff92d61b8c7e9c8f4777d", - "sha256:09e96138280241bd355cd585148dec04dbbedb4f46128f340d696eaafc82dd7b", - "sha256:0b1ad452cc824665ddc682400b62c9e4f5b64736a2ba99110712fdee5f2505c4", - "sha256:0ef488305fdce2580c8b2708f22d7785ae222d9825d3094ab073e22e93dfe51f", - "sha256:15f351bed09897fbda218e4db5a3d5c06328862f6198d4fb385f3e14e19decb3", - "sha256:22399ff4870fb4c7ef19fff6eeb20a8bbf15571913c181c78cb361024d574579", - "sha256:23e5d2040367322824605bc29ae8ee9175200b92cb5483ac7d466927a9b3d537", - "sha256:2791f68edc5749024b4722500e86303a10d342527e1e3bcac47f35fbd25b764e", - "sha256:2f9674623ca39c9ebe38afa3da402e9326c245f0f5ceff0623dccdac15023e05", - "sha256:3363e77a6176afb8823b6e06db78c46dbc4c7813b00a41300a4873b6ba63b171", - "sha256:33c6cdc071ba5cd6d96769c8969a0531be2d08c2628a0143a10a7dcffa9719ca", - "sha256:3b8eaf915ddc0709779889c472e553f0d3e8b7bdf62dab764c8921b09bf94522", - "sha256:3cb3e1b9ec43256c4e0f8d2837267a70b0e1ca8c4f456685508ae6106b1f504c", - "sha256:3eeeb0405fd145e714f7633a5173318bd88d8bbfc3dd0a5751f8c4f70ae629bc", - "sha256:44f60519595eaca110f248e5017363d751b12782a6f2bd6a7041cba275215f5d", - "sha256:4d7c26bfc1ea9f92084a1d75e11999e97b62d63128bcc90c3624d07813c52808", - "sha256:529c4ed2e10437c205f38f3691a68be66c39197d01062618c55f74294a4a4828", - "sha256:6642f15ad963b5092d65aed022d033c77763515fdc07095208f15d3563003869", - "sha256:85ba797e1de5b48aa5a8427b6ba62cf69607c18c5d4eb747604b7302f1ec382d", - "sha256:8f0f1e499e4000c4c347a124fa6a27d37608ced4fe9f7d45070563b7c4c370c9", - "sha256:a624fae282e81ad2e4871bdb767e2c914d0539708c0f078b5b355258293c98b0", - "sha256:b0358e6fefc74a16f745afa366acc89f979040e0cbc4eec55ab26ad1f6a9bfbc", - "sha256:bbd2f4dfee1079f76943767fce837ade3087b578aeb9f69aec7857d5bf25db15", - "sha256:bf39a9e19ce7298f1bd6a9758fa99707e9e5b1ebe5e90f2c3913a47bc548747c", - "sha256:c11579638288e53fc94ad60022ff1b67865363e730ee41ad5e6f0a17188b327a", - "sha256:c150eaa3dadbb2b5339675b88d4573c1be3cb6f2c33a6c83387e10cc0bf05bd3", - "sha256:c53af463f4a40de78c58b8b2710ade243c81cbca641e34debf3396a9640d6ec1", - "sha256:cb763ceceae04803adcc4e2d80d611ef201c73da32d8f2722e9d0ab0c7f10768", - "sha256:cc75f58cdaf043fe6a7a6c04b3b5a0e694c6a9e24050967747251fb80d7bce0d", - "sha256:d80998ed59176e8cba74028762fbd9b9153b9afc71ea118e63bbf5d4d0f9552b", - "sha256:de31b5164d44ef4943db155b3e8e17929707cac1e5bd2f363e67a56e3af4af6e", - "sha256:e66399cf0fc07de4dce4f588fc25bfe84a6d1285cc544e67987d22663393926d", - "sha256:f0620511387790860b249b9241c2f13c3a80e21a73e0b861a2df24e9d6f56730", - "sha256:f4eae045e6ab2bb54ca279733fe4eb85f1effda392666308250714e01907f394", - "sha256:f92cdecb618e5fa4658aeb97d5eb3d2f47aa94ac6477c6daf0f306c5a3b9e6b1", - "sha256:f92f789e4f9241cd262ad7a555ca2c648a98178a953af117ef7fad46aa1d5591" - ], - "version": "==1.14.3" + "sha256:0da50dcbccd7cb7e6c741ab7912b2eff48e85af217d72b57f80ebc616257125e", + "sha256:12a453e03124069b6896107ee133ae3ab04c624bb10683e1ed1c1663df17c13c", + "sha256:15419020b0e812b40d96ec9d369b2bc8109cc3295eac6e013d3261343580cc7e", + "sha256:15a5f59a4808f82d8ec7364cbace851df591c2d43bc76bcbe5c4543a7ddd1bf1", + "sha256:23e44937d7695c27c66a54d793dd4b45889a81b35c0751ba91040fe825ec59c4", + "sha256:29c4688ace466a365b85a51dcc5e3c853c1d283f293dfcc12f7a77e498f160d2", + "sha256:57214fa5430399dffd54f4be37b56fe22cedb2b98862550d43cc085fb698dc2c", + "sha256:577791f948d34d569acb2d1add5831731c59d5a0c50a6d9f629ae1cefd9ca4a0", + "sha256:6539314d84c4d36f28d73adc1b45e9f4ee2a89cdc7e5d2b0a6dbacba31906798", + "sha256:65867d63f0fd1b500fa343d7798fa64e9e681b594e0a07dc934c13e76ee28fb1", + "sha256:672b539db20fef6b03d6f7a14b5825d57c98e4026401fce838849f8de73fe4d4", + "sha256:6843db0343e12e3f52cc58430ad559d850a53684f5b352540ca3f1bc56df0731", + "sha256:7057613efefd36cacabbdbcef010e0a9c20a88fc07eb3e616019ea1692fa5df4", + "sha256:76ada88d62eb24de7051c5157a1a78fd853cca9b91c0713c2e973e4196271d0c", + "sha256:837398c2ec00228679513802e3744d1e8e3cb1204aa6ad408b6aff081e99a487", + "sha256:8662aabfeab00cea149a3d1c2999b0731e70c6b5bac596d95d13f643e76d3d4e", + "sha256:95e9094162fa712f18b4f60896e34b621df99147c2cee216cfa8f022294e8e9f", + "sha256:99cc66b33c418cd579c0f03b77b94263c305c389cb0c6972dac420f24b3bf123", + "sha256:9b219511d8b64d3fa14261963933be34028ea0e57455baf6781fe399c2c3206c", + "sha256:ae8f34d50af2c2154035984b8b5fc5d9ed63f32fe615646ab435b05b132ca91b", + "sha256:b9aa9d8818c2e917fa2c105ad538e222a5bce59777133840b93134022a7ce650", + "sha256:bf44a9a0141a082e89c90e8d785b212a872db793a0080c20f6ae6e2a0ebf82ad", + "sha256:c0b48b98d79cf795b0916c57bebbc6d16bb43b9fc9b8c9f57f4cf05881904c75", + "sha256:da9d3c506f43e220336433dffe643fbfa40096d408cb9b7f2477892f369d5f82", + "sha256:e4082d832e36e7f9b2278bc774886ca8207346b99f278e54c9de4834f17232f7", + "sha256:e4b9b7af398c32e408c00eb4e0d33ced2f9121fd9fb978e6c1b57edd014a7d15", + "sha256:e613514a82539fc48291d01933951a13ae93b6b444a88782480be32245ed4afa", + "sha256:f5033952def24172e60493b68717792e3aebb387a8d186c43c020d9363ee7281" + ], + "version": "==1.14.2" }, "chardet": { "hashes": [ @@ -583,11 +575,11 @@ }, "sentry-sdk": { "hashes": [ - "sha256:96a0e494b243a81065ec7ab73457d16719fb955ed9e469c8e4577ba737bc836e", - "sha256:a698993f3abbe06e88e8a3c8b61c8a79c12f62e503f1a23eda30c3921f0525a9" + "sha256:1a086486ff9da15791f294f6e9915eb3747d161ef64dee2d038a4d0b4a369b24", + "sha256:45486deb031cea6bbb25a540d7adb4dd48cd8a1cc31e6a5ce9fb4f792a572e9a" ], "index": "pypi", - "version": "==0.17.7" + "version": "==0.17.6" }, "six": { "hashes": [ @@ -865,11 +857,11 @@ }, "identify": { "hashes": [ - "sha256:d7da7de6825568daa4449858ce328ecc0e1ada2554d972a6f4f90e736aaf499a", - "sha256:e4db4796b3b0cf4f9cb921da51430abffff2d4ba7d7c521184ed5252bd90d461" + "sha256:c770074ae1f19e08aadbda1c886bc6d0cb55ffdc503a8c0fe8699af2fc9664ae", + "sha256:d02d004568c5a01261839a05e91705e3e9f5c57a3551648f9b3fb2b9c62c0f62" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.5.4" + "version": "==1.5.3" }, "mccabe": { "hashes": [ diff --git a/bot/__main__.py b/bot/__main__.py index 8770ac31b..a07bc21d6 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,7 +1,9 @@ +import asyncio import logging import discord import sentry_sdk +from async_rediscache import RedisSession from discord.ext.commands import when_mentioned_or from sentry_sdk.integrations.aiohttp import AioHttpIntegration from sentry_sdk.integrations.logging import LoggingIntegration @@ -26,9 +28,28 @@ sentry_sdk.init( ] ) +# Create the redis session instance. +redis_session = RedisSession( + address=(constants.Redis.host, constants.Redis.port), + password=constants.Redis.password, + minsize=1, + maxsize=20, + use_fakeredis=constants.Redis.use_fakeredis, + global_namespace="bot", +) + +# Connect redis session to ensure it's connected before we try to access Redis +# from somewhere within the bot. We create the event loop in the same way +# discord.py normally does and pass it to the bot's __init__. +loop = asyncio.get_event_loop() +loop.run_until_complete(redis_session.connect()) + + # Instantiate the bot. allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] bot = Bot( + redis_session=redis_session, + loop=loop, command_prefix=when_mentioned_or(constants.Bot.prefix), activity=discord.Game(name="Commands: !help"), case_insensitive=True, diff --git a/bot/bot.py b/bot/bot.py index d25074fd9..b2e5237fe 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -6,9 +6,8 @@ from collections import defaultdict from typing import Dict, Optional import aiohttp -import aioredis import discord -import fakeredis.aioredis +from async_rediscache import RedisSession from discord.ext import commands from sentry_sdk import push_scope @@ -21,7 +20,7 @@ log = logging.getLogger('bot') class Bot(commands.Bot): """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, redis_session: RedisSession, **kwargs): if "connector" in kwargs: warnings.warn( "If login() is called (or the bot is started), the connector will be overwritten " @@ -31,9 +30,7 @@ class Bot(commands.Bot): super().__init__(*args, **kwargs) self.http_session: Optional[aiohttp.ClientSession] = None - self.redis_session: Optional[aioredis.Redis] = None - self.redis_ready = asyncio.Event() - self.redis_closed = False + self.redis_session = redis_session self.api_client = api.APIClient(loop=self.loop) self.filter_list_cache = defaultdict(dict) @@ -58,30 +55,6 @@ class Bot(commands.Bot): for item in full_cache: self.insert_item_into_filter_list_cache(item) - async def _create_redis_session(self) -> None: - """ - Create the Redis connection pool, and then open the redis event gate. - - If constants.Redis.use_fakeredis is True, we'll set up a fake redis pool instead - of attempting to communicate with a real Redis server. This is useful because it - means contributors don't necessarily need to get Redis running locally just - to run the bot. - - The fakeredis cache won't have persistence across restarts, but that - usually won't matter for local bot testing. - """ - if constants.Redis.use_fakeredis: - log.info("Using fakeredis instead of communicating with a real Redis server.") - self.redis_session = await fakeredis.aioredis.create_redis_pool() - else: - self.redis_session = await aioredis.create_redis_pool( - address=(constants.Redis.host, constants.Redis.port), - password=constants.Redis.password, - ) - - self.redis_closed = False - self.redis_ready.set() - def _recreate(self) -> None: """Re-create the connector, aiohttp session, the APIClient and the Redis session.""" # Use asyncio for DNS resolution instead of threads so threads aren't spammed. @@ -94,13 +67,10 @@ class Bot(commands.Bot): "The previous connector was not closed; it will remain open and be overwritten" ) - if self.redis_session and not self.redis_session.closed: - log.warning( - "The previous redis pool was not closed; it will remain open and be overwritten" - ) - - # Create the redis session - self.loop.create_task(self._create_redis_session()) + if self.redis_session.closed: + # If the RedisSession was somehow closed, we try to reconnect it + # here. Normally, this shouldn't happen. + self.loop.create_task(self.redis_session.connect()) # Use AF_INET as its socket family to prevent HTTPS related problems both locally # and in production. @@ -180,10 +150,7 @@ class Bot(commands.Bot): self.stats._transport.close() if self.redis_session: - self.redis_closed = True - self.redis_session.close() - self.redis_ready.clear() - await self.redis_session.wait_closed() + await self.redis_session.close() def insert_item_into_filter_list_cache(self, item: Dict[str, str]) -> None: """Add an item to the bots filter_list_cache.""" diff --git a/bot/constants.py b/bot/constants.py index 17f14fec0..0cb076d5c 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -217,6 +217,7 @@ class Filter(metaclass=YAMLGetter): filter_zalgo: bool filter_invites: bool filter_domains: bool + filter_everyone_ping: bool watch_regex: bool watch_rich_embeds: bool @@ -224,6 +225,7 @@ class Filter(metaclass=YAMLGetter): notify_user_zalgo: bool notify_user_invites: bool notify_user_domains: bool + notify_user_everyone_ping: bool ping_everyone: bool offensive_msg_delete_days: int @@ -252,7 +254,7 @@ class DuckPond(metaclass=YAMLGetter): section = "duck_pond" threshold: int - custom_emojis: List[int] + channel_blacklist: List[int] class Emojis(metaclass=YAMLGetter): @@ -292,20 +294,6 @@ class Emojis(metaclass=YAMLGetter): cross_mark: str check_mark: str - ducky_yellow: int - ducky_blurple: int - ducky_regal: int - ducky_camo: int - ducky_ninja: int - ducky_devil: int - ducky_tube: int - ducky_hunt: int - ducky_wizard: int - ducky_party: int - ducky_angel: int - ducky_maul: int - ducky_santa: int - upvotes: str comments: str user: str @@ -395,12 +383,14 @@ class Channels(metaclass=YAMLGetter): section = "guild" subsection = "channels" + admin_announcements: int admin_spam: int admins: int announcements: int attachment_log: int big_brother_logs: int bot_commands: int + change_log: int cooldown: int defcon: int dev_contrib: int @@ -412,9 +402,11 @@ class Channels(metaclass=YAMLGetter): how_to_get_help: int incidents: int incidents_archive: int + mailing_lists: int message_log: int meta: int mod_alerts: int + mod_announcements: int mod_log: int mod_spam: int mods: int @@ -423,7 +415,10 @@ class Channels(metaclass=YAMLGetter): off_topic_2: int organisation: int python_discussion: int + python_events: int + python_news: int reddit: int + staff_announcements: int talent_pool: int user_event_announcements: int user_log: int diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 2e7e32d9a..f2a2689e1 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -36,9 +36,6 @@ RULE_FUNCTION_MAPPING = { 'mentions': rules.apply_mentions, 'newlines': rules.apply_newlines, 'role_mentions': rules.apply_role_mentions, - # the everyone filter is temporarily disabled until - # it has been improved. - # 'everyone_ping': rules.apply_everyone_ping, } diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 2751ed7f6..b7eb41244 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -6,6 +6,7 @@ from typing import List, Mapping, Optional, Tuple, Union import dateutil import discord.errors +from async_rediscache import RedisCache from dateutil.relativedelta import relativedelta from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel from discord.ext.commands import Cog @@ -14,17 +15,22 @@ from discord.utils import escape_markdown from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import ( - Channels, Colours, - Filter, Icons, URLs + Channels, Colours, Filter, + Guild, Icons, URLs ) from bot.exts.moderation.modlog import ModLog -from bot.utils.redis_cache import RedisCache from bot.utils.regex import INVITE_RE from bot.utils.scheduling import Scheduler log = logging.getLogger(__name__) # Regular expressions +CODE_BLOCK_RE = re.compile( + r"(?P<delim>``?)[^`]+?(?P=delim)(?!`+)" # Inline codeblock + r"|```(.+?)```", # Multiline codeblock + re.DOTALL | re.MULTILINE +) +EVERYONE_PING_RE = re.compile(rf"@everyone|<@&{Guild.id}>|@here") SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) ZALGO_RE = re.compile(r"[\u0300-\u036F\u0489]") @@ -82,6 +88,19 @@ class Filtering(Cog): ), "schedule_deletion": False }, + "filter_everyone_ping": { + "enabled": Filter.filter_everyone_ping, + "function": self._has_everyone_ping, + "type": "filter", + "content_only": True, + "user_notification": Filter.notify_user_everyone_ping, + "notification_msg": ( + "Please don't try to ping `@everyone` or `@here`. " + f"Your message has been removed. {staff_mistake_str}" + ), + "schedule_deletion": False, + "ping_everyone": False + }, "watch_regex": { "enabled": Filter.watch_regex, "function": self._has_watch_regex_match, @@ -332,6 +351,9 @@ class Filtering(Cog): log.debug(message) + # Allow specific filters to override ping_everyone + ping_everyone = Filter.ping_everyone and _filter.get("ping_everyone", True) + # Send pretty mod log embed to mod-alerts await self.mod_log.send_log_message( icon_url=Icons.filtering, @@ -340,7 +362,7 @@ class Filtering(Cog): text=message, thumbnail=msg.author.avatar_url_as(static_format="png"), channel_id=Channels.mod_alerts, - ping_everyone=Filter.ping_everyone if not is_private else False, + ping_everyone=ping_everyone if not is_private else False, additional_embeds=additional_embeds, additional_embeds_msg=additional_embeds_msg ) @@ -528,6 +550,16 @@ class Filtering(Cog): return False return False + @staticmethod + async def _has_everyone_ping(text: str) -> bool: + """Determines if `msg` contains an @everyone or @here ping outside of a codeblock.""" + # First pass to avoid running re.sub on every message + if not EVERYONE_PING_RE.search(text): + return False + + content_without_codeblocks = CODE_BLOCK_RE.sub("", text) + return bool(EVERYONE_PING_RE.search(content_without_codeblocks)) + async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: """ Notify filtered_member about a moderation action with the reason str. diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index 7021069fa..6c2d22b9c 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -1,12 +1,14 @@ +import asyncio import logging from typing import Union import discord from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Cog +from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot +from bot.utils.checks import has_any_role from bot.utils.messages import send_attachments from bot.utils.webhooks import send_webhook @@ -21,6 +23,7 @@ class DuckPond(Cog): self.webhook_id = constants.Webhooks.duck_pond self.webhook = None self.bot.loop.create_task(self.fetch_webhook()) + self.relay_lock = None async def fetch_webhook(self) -> None: """Fetches the webhook object, so we can post to it.""" @@ -49,32 +52,32 @@ class DuckPond(Cog): return True return False + @staticmethod + def _is_duck_emoji(emoji: Union[str, discord.PartialEmoji, discord.Emoji]) -> bool: + """Check if the emoji is a valid duck emoji.""" + if isinstance(emoji, str): + return emoji == "🦆" + else: + return hasattr(emoji, "name") and emoji.name.startswith("ducky_") + async def count_ducks(self, message: Message) -> int: """ Count the number of ducks in the reactions of a specific message. Only counts ducks added by staff members. """ - duck_count = 0 - duck_reactors = [] + duck_reactors = set() + # iterate over all reactions for reaction in message.reactions: - async for user in reaction.users(): - - # Is the user a staff member and not already counted as reactor? - if not self.is_staff(user) or user.id in duck_reactors: - continue - - # Is the emoji a duck? - if hasattr(reaction.emoji, "id"): - if reaction.emoji.id in constants.DuckPond.custom_emojis: - duck_count += 1 - duck_reactors.append(user.id) - elif isinstance(reaction.emoji, str): - if reaction.emoji == "🦆": - duck_count += 1 - duck_reactors.append(user.id) - return duck_count + # check if the current reaction is a duck + if not self._is_duck_emoji(reaction.emoji): + continue + + # update the set of reactors with all staff reactors + duck_reactors |= {user.id async for user in reaction.users() if self.is_staff(user)} + + return len(duck_reactors) async def relay_message(self, message: Message) -> None: """Relays the message's content and attachments to the duck pond channel.""" @@ -103,18 +106,35 @@ class DuckPond(Cog): except discord.HTTPException: log.exception("Failed to send an attachment to the webhook") - await message.add_reaction("✅") + async def locked_relay(self, message: discord.Message) -> bool: + """Relay a message after obtaining the relay lock.""" + if self.relay_lock is None: + # Lazily load the lock to ensure it's created within the + # appropriate event loop. + self.relay_lock = asyncio.Lock() - @staticmethod - def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + async with self.relay_lock: + # check if the message has a checkmark after acquiring the lock + if await self.has_green_checkmark(message): + return False + + # relay the message + await self.relay_message(message) + + # add a green checkmark to indicate that the message was relayed + await message.add_reaction("✅") + return True + + def _payload_has_duckpond_emoji(self, emoji: discord.PartialEmoji) -> bool: """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" - if payload.emoji.is_custom_emoji(): - if payload.emoji.id in constants.DuckPond.custom_emojis: - return True - elif payload.emoji.name == "🦆": - return True + if emoji.is_unicode_emoji(): + # For unicode PartialEmojis, the `name` attribute is just the string + # representation of the emoji. This is what the helper method + # expects, as unicode emojis show up as just a `str` instance when + # inspecting the reactions attached to a message. + emoji = emoji.name - return False + return self._is_duck_emoji(emoji) @Cog.listener() async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: @@ -125,20 +145,24 @@ class DuckPond(Cog): amount of ducks specified in the config under duck_pond/threshold, it will send the message off to the duck pond. """ + # Was this reaction issued in a blacklisted channel? + if payload.channel_id in constants.DuckPond.channel_blacklist: + return + # Is the emoji in the reaction a duck? - if not self._payload_has_duckpond_emoji(payload): + if not self._payload_has_duckpond_emoji(payload.emoji): return channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) message = await channel.fetch_message(payload.message_id) member = discord.utils.get(message.guild.members, id=payload.user_id) - # Is the member a human and a staff member? - if not self.is_staff(member) or member.bot: + # Was the message sent by a human staff member? + if not self.is_staff(message.author) or message.author.bot: return - # Does the message already have a green checkmark? - if await self.has_green_checkmark(message): + # Is the reactor a human staff member? + if not self.is_staff(member) or member.bot: return # Time to count our ducks! @@ -146,7 +170,7 @@ class DuckPond(Cog): # If we've got more than the required amount of ducks, send the message to the duck_pond. if duck_count >= constants.DuckPond.threshold: - await self.relay_message(message) + await self.locked_relay(message) @Cog.listener() async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: @@ -160,6 +184,15 @@ class DuckPond(Cog): if duck_count >= constants.DuckPond.threshold: await message.add_reaction("✅") + @command(name="duckify", aliases=("duckpond", "pondify")) + @has_any_role(constants.Roles.admins) + async def duckify(self, ctx: Context, message: discord.Message) -> None: + """Relay a message to the duckpond, no ducks required!""" + if await self.locked_relay(message): + await ctx.message.add_reaction("🦆") + else: + await ctx.message.add_reaction("❌") + def setup(bot: Bot) -> None: """Load the DuckPond cog.""" diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py index 17142071f..9e33a6aba 100644 --- a/bot/exts/help_channels.py +++ b/bot/exts/help_channels.py @@ -9,11 +9,11 @@ from pathlib import Path import discord import discord.abc +from async_rediscache import RedisCache from discord.ext import commands from bot import constants from bot.bot import Bot -from bot.utils import RedisCache from bot.utils.scheduling import Scheduler log = logging.getLogger(__name__) diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 7a3fe49bb..14263e004 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -2,6 +2,7 @@ import logging from typing import Optional import discord +from async_rediscache import RedisCache from discord import Color from discord.ext import commands from discord.ext.commands import Cog @@ -9,7 +10,6 @@ from discord.ext.commands import Cog from bot import constants from bot.bot import Bot from bot.converters import UserMentionOrID -from bot.utils import RedisCache from bot.utils.checks import in_whitelist_check from bot.utils.messages import send_attachments from bot.utils.webhooks import send_webhook diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index 8ec68ac1e..6dad82d1e 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -5,6 +5,7 @@ from contextlib import suppress from datetime import datetime, timedelta import discord +from async_rediscache import RedisCache from discord.ext import tasks from discord.ext.commands import Cog, Context, command, group, has_any_role from discord.utils import snowflake_time @@ -14,7 +15,6 @@ from bot.bot import Bot from bot.decorators import has_no_roles, in_whitelist from bot.exts.moderation.modlog import ModLog from bot.utils.checks import InWhitelistCheckFailure, has_no_roles_check -from bot.utils.redis_cache import RedisCache log = logging.getLogger(__name__) diff --git a/bot/rules/__init__.py b/bot/rules/__init__.py index 8a69cadee..a01ceae73 100644 --- a/bot/rules/__init__.py +++ b/bot/rules/__init__.py @@ -10,4 +10,3 @@ from .links import apply as apply_links from .mentions import apply as apply_mentions from .newlines import apply as apply_newlines from .role_mentions import apply as apply_role_mentions -from .everyone_ping import apply as apply_everyone_ping diff --git a/bot/rules/everyone_ping.py b/bot/rules/everyone_ping.py deleted file mode 100644 index 89d9fe570..000000000 --- a/bot/rules/everyone_ping.py +++ /dev/null @@ -1,41 +0,0 @@ -import random -import re -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Embed, Member, Message - -from bot.constants import Colours, Guild, NEGATIVE_REPLIES - -# Generate regex for checking for pings: -guild_id = Guild.id -EVERYONE_RE_INLINE_CODE = re.compile(rf"^(?!`).*@everyone.*(?!`)$|^(?!`).*<@&{guild_id}>.*(?!`)$") -EVERYONE_RE_MULTILINE_CODE = re.compile(rf"^(?!```).*@everyone.*(?!```)$|^(?!```).*<@&{guild_id}>.*(?!```)$") - - -async def apply( - last_message: Message, - recent_messages: List[Message], - config: Dict[str, int], -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - """Detects if a user has sent an '@everyone' ping.""" - relevant_messages = tuple(msg for msg in recent_messages if msg.author == last_message.author) - - everyone_messages_count = 0 - for msg in relevant_messages: - num_everyone_pings_inline = len(re.findall(EVERYONE_RE_INLINE_CODE, msg.content)) - num_everyone_pings_multiline = len(re.findall(EVERYONE_RE_MULTILINE_CODE, msg.content)) - if num_everyone_pings_inline and num_everyone_pings_multiline: - everyone_messages_count += 1 - - if everyone_messages_count > config["max"]: - # Send the channel an embed giving the user more info: - embed_text = f"Please don't try to ping {last_message.guild.member_count:,} people." - embed = Embed(title=random.choice(NEGATIVE_REPLIES), description=embed_text, colour=Colours.soft_red) - await last_message.channel.send(embed=embed) - - return ( - "pinged the everyone role", - (last_message.author,), - relevant_messages, - ) - return None diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 3e93fcb06..60170a88f 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,4 @@ from bot.utils.helpers import CogABCMeta, find_nth_occurrence, pad_base64 -from bot.utils.redis_cache import RedisCache from bot.utils.services import send_to_paste_service -__all__ = ['RedisCache', 'CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service'] +__all__ = ['CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service'] diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py deleted file mode 100644 index 52b689b49..000000000 --- a/bot/utils/redis_cache.py +++ /dev/null @@ -1,414 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -from functools import partialmethod -from typing import Any, Dict, ItemsView, Optional, Tuple, Union - -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# Type aliases -RedisKeyType = Union[str, int] -RedisValueType = Union[str, int, float, bool] -RedisKeyOrValue = Union[RedisKeyType, RedisValueType] - -# Prefix tuples -_PrefixTuple = Tuple[Tuple[str, Any], ...] -_VALUE_PREFIXES = ( - ("f|", float), - ("i|", int), - ("s|", str), - ("b|", bool), -) -_KEY_PREFIXES = ( - ("i|", int), - ("s|", str), -) - - -class NoBotInstanceError(RuntimeError): - """Raised when RedisCache is created without an available bot instance on the owner class.""" - - -class NoNamespaceError(RuntimeError): - """Raised when RedisCache has no namespace, for example if it is not assigned to a class attribute.""" - - -class NoParentInstanceError(RuntimeError): - """Raised when the parent instance is available, for example if called by accessing the parent class directly.""" - - -class RedisCache: - """ - A simplified interface for a Redis connection. - - We implement several convenient methods that are fairly similar to have a dict - behaves, and should be familiar to Python users. The biggest difference is that - all the public methods in this class are coroutines, and must be awaited. - - Because of limitations in Redis, this cache will only accept strings and integers for keys, - and strings, integers, floats and booleans for values. - - Please note that this class MUST be created as a class attribute, and that that class - must also contain an attribute with an instance of our Bot. See `__get__` and `__set_name__` - for more information about how this works. - - Simple example for how to use this: - - class SomeCog(Cog): - # To initialize a valid RedisCache, just add it as a class attribute here. - # Do not add it to the __init__ method or anywhere else, it MUST be a class - # attribute. Do not pass any parameters. - cache = RedisCache() - - async def my_method(self): - - # Now we're ready to use the RedisCache. - # One thing to note here is that this will not work unless - # we access self.cache through an _instance_ of this class. - # - # For example, attempting to use SomeCog.cache will _not_ work, - # you _must_ instantiate the class first and use that instance. - # - # Now we can store some stuff in the cache just by doing this. - # This data will persist through restarts! - await self.cache.set("key", "value") - - # To get the data, simply do this. - value = await self.cache.get("key") - - # Other methods work more or less like a dictionary. - # Checking if something is in the cache - await self.cache.contains("key") - - # iterating the cache - async for key, value in self.cache.items(): - print(value) - - # We can even iterate in a comprehension! - consumed = [value async for key, value in self.cache.items()] - """ - - _namespaces = [] - - def __init__(self) -> None: - """Initialize the RedisCache.""" - self._namespace = None - self.bot = None - self._increment_lock = None - - def _set_namespace(self, namespace: str) -> None: - """Try to set the namespace, but do not permit collisions.""" - log.trace(f"RedisCache setting namespace to {namespace}") - self._namespaces.append(namespace) - self._namespace = namespace - - @staticmethod - def _to_typestring(key_or_value: RedisKeyOrValue, prefixes: _PrefixTuple) -> str: - """Turn a valid Redis type into a typestring.""" - for prefix, _type in prefixes: - # Convert bools into integers before storing them. - if type(key_or_value) is bool: - bool_int = int(key_or_value) - return f"{prefix}{bool_int}" - - # isinstance is a bad idea here, because isintance(False, int) == True. - if type(key_or_value) is _type: - return f"{prefix}{key_or_value}" - - raise TypeError(f"RedisCache._to_typestring only supports the following: {prefixes}.") - - @staticmethod - def _from_typestring(key_or_value: Union[bytes, str], prefixes: _PrefixTuple) -> RedisKeyOrValue: - """Deserialize a typestring into a valid Redis type.""" - # Stuff that comes out of Redis will be bytestrings, so let's decode those. - if isinstance(key_or_value, bytes): - key_or_value = key_or_value.decode('utf-8') - - # Now we convert our unicode string back into the type it originally was. - for prefix, _type in prefixes: - if key_or_value.startswith(prefix): - - # For booleans, we need special handling because bool("False") is True. - if prefix == "b|": - value = key_or_value[len(prefix):] - return bool(int(value)) - - # Otherwise we can just convert normally. - return _type(key_or_value[len(prefix):]) - raise TypeError(f"RedisCache._from_typestring only supports the following: {prefixes}.") - - # Add some nice partials to call our generic typestring converters. - # These are basically methods that will fill in some of the parameters for you, so that - # any call to _key_to_typestring will be like calling _to_typestring with the two parameters - # at `prefixes` and `types_string` pre-filled. - # - # See https://docs.python.org/3/library/functools.html#functools.partialmethod - _key_to_typestring = partialmethod(_to_typestring, prefixes=_KEY_PREFIXES) - _value_to_typestring = partialmethod(_to_typestring, prefixes=_VALUE_PREFIXES) - _key_from_typestring = partialmethod(_from_typestring, prefixes=_KEY_PREFIXES) - _value_from_typestring = partialmethod(_from_typestring, prefixes=_VALUE_PREFIXES) - - def _dict_from_typestring(self, dictionary: Dict) -> Dict: - """Turns all contents of a dict into valid Redis types.""" - return {self._key_from_typestring(key): self._value_from_typestring(value) for key, value in dictionary.items()} - - def _dict_to_typestring(self, dictionary: Dict) -> Dict: - """Turns all contents of a dict into typestrings.""" - return {self._key_to_typestring(key): self._value_to_typestring(value) for key, value in dictionary.items()} - - async def _validate_cache(self) -> None: - """Validate that the RedisCache is ready to be used.""" - if self._namespace is None: - error_message = ( - "Critical error: RedisCache has no namespace. " - "This object must be initialized as a class attribute." - ) - log.error(error_message) - raise NoNamespaceError(error_message) - - if self.bot is None: - error_message = ( - "Critical error: RedisCache has no `Bot` instance. " - "This happens when the class RedisCache was created in doesn't " - "have a Bot instance. Please make sure that you're instantiating " - "the RedisCache inside a class that has a Bot instance attribute." - ) - log.error(error_message) - raise NoBotInstanceError(error_message) - - if not self.bot.redis_closed: - await self.bot.redis_ready.wait() - - def __set_name__(self, owner: Any, attribute_name: str) -> None: - """ - Set the namespace to Class.attribute_name. - - Called automatically when this class is constructed inside a class as an attribute. - - This class MUST be created as a class attribute in a class, otherwise it will raise - exceptions whenever a method is used. This is because it uses this method to create - a namespace like `MyCog.my_class_attribute` which is used as a hash name when we store - stuff in Redis, to prevent collisions. - """ - self._set_namespace(f"{owner.__name__}.{attribute_name}") - - def __get__(self, instance: RedisCache, owner: Any) -> RedisCache: - """ - This is called if the RedisCache is a class attribute, and is accessed. - - The class this object is instantiated in must contain an attribute with an - instance of Bot. This is because Bot contains our redis_session, which is - the mechanism by which we will communicate with the Redis server. - - Any attempt to use RedisCache in a class that does not have a Bot instance - will fail. It is mostly intended to be used inside of a Cog, although theoretically - it should work in any class that has a Bot instance. - """ - if self.bot: - return self - - if self._namespace is None: - error_message = "RedisCache must be a class attribute." - log.error(error_message) - raise NoNamespaceError(error_message) - - if instance is None: - error_message = ( - "You must access the RedisCache instance through the cog instance " - "before accessing it using the cog's class object." - ) - log.error(error_message) - raise NoParentInstanceError(error_message) - - for attribute in vars(instance).values(): - if isinstance(attribute, Bot): - self.bot = attribute - return self - else: - error_message = ( - "Critical error: RedisCache has no `Bot` instance. " - "This happens when the class RedisCache was created in doesn't " - "have a Bot instance. Please make sure that you're instantiating " - "the RedisCache inside a class that has a Bot instance attribute." - ) - log.error(error_message) - raise NoBotInstanceError(error_message) - - def __repr__(self) -> str: - """Return a beautiful representation of this object instance.""" - return f"RedisCache(namespace={self._namespace!r})" - - async def set(self, key: RedisKeyType, value: RedisValueType) -> None: - """Store an item in the Redis cache.""" - await self._validate_cache() - - # Convert to a typestring and then set it - key = self._key_to_typestring(key) - value = self._value_to_typestring(value) - - log.trace(f"Setting {key} to {value}.") - await self.bot.redis_session.hset(self._namespace, key, value) - - async def get(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> Optional[RedisValueType]: - """Get an item from the Redis cache.""" - await self._validate_cache() - key = self._key_to_typestring(key) - - log.trace(f"Attempting to retrieve {key}.") - value = await self.bot.redis_session.hget(self._namespace, key) - - if value is None: - log.trace(f"Value not found, returning default value {default}") - return default - else: - value = self._value_from_typestring(value) - log.trace(f"Value found, returning value {value}") - return value - - async def delete(self, key: RedisKeyType) -> None: - """ - Delete an item from the Redis cache. - - If we try to delete a key that does not exist, it will simply be ignored. - - See https://redis.io/commands/hdel for more info on how this works. - """ - await self._validate_cache() - key = self._key_to_typestring(key) - - log.trace(f"Attempting to delete {key}.") - return await self.bot.redis_session.hdel(self._namespace, key) - - async def contains(self, key: RedisKeyType) -> bool: - """ - Check if a key exists in the Redis cache. - - Return True if the key exists, otherwise False. - """ - await self._validate_cache() - key = self._key_to_typestring(key) - exists = await self.bot.redis_session.hexists(self._namespace, key) - - log.trace(f"Testing if {key} exists in the RedisCache - Result is {exists}") - return exists - - async def items(self) -> ItemsView: - """ - Fetch all the key/value pairs in the cache. - - Returns a normal ItemsView, like you would get from dict.items(). - - Keep in mind that these items are just a _copy_ of the data in the - RedisCache - any changes you make to them will not be reflected - into the RedisCache itself. If you want to change these, you need - to make a .set call. - - Example: - items = await my_cache.items() - for key, value in items: - # Iterate like a normal dictionary - """ - await self._validate_cache() - items = self._dict_from_typestring( - await self.bot.redis_session.hgetall(self._namespace) - ).items() - - log.trace(f"Retrieving all key/value pairs from cache, total of {len(items)} items.") - return items - - async def length(self) -> int: - """Return the number of items in the Redis cache.""" - await self._validate_cache() - number_of_items = await self.bot.redis_session.hlen(self._namespace) - log.trace(f"Returning length. Result is {number_of_items}.") - return number_of_items - - async def to_dict(self) -> Dict: - """Convert to dict and return.""" - return {key: value for key, value in await self.items()} - - async def clear(self) -> None: - """Deletes the entire hash from the Redis cache.""" - await self._validate_cache() - log.trace("Clearing the cache of all key/value pairs.") - await self.bot.redis_session.delete(self._namespace) - - async def pop(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> RedisValueType: - """Get the item, remove it from the cache, and provide a default if not found.""" - log.trace(f"Attempting to pop {key}.") - value = await self.get(key, default) - - log.trace( - f"Attempting to delete item with key '{key}' from the cache. " - "If this key doesn't exist, nothing will happen." - ) - await self.delete(key) - - return value - - async def update(self, items: Dict[RedisKeyType, RedisValueType]) -> None: - """ - Update the Redis cache with multiple values. - - This works exactly like dict.update from a normal dictionary. You pass - a dictionary with one or more key/value pairs into this method. If the keys - do not exist in the RedisCache, they are created. If they do exist, the values - are updated with the new ones from `items`. - - Please note that keys and the values in the `items` dictionary - must consist of valid RedisKeyTypes and RedisValueTypes. - """ - await self._validate_cache() - log.trace(f"Updating the cache with the following items:\n{items}") - await self.bot.redis_session.hmset_dict(self._namespace, self._dict_to_typestring(items)) - - async def increment(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None: - """ - Increment the value by `amount`. - - This works for both floats and ints, but will raise a TypeError - if you try to do it for any other type of value. - - This also supports negative amounts, although it would provide better - readability to use .decrement() for that. - """ - log.trace(f"Attempting to increment/decrement the value with the key {key} by {amount}.") - - # We initialize the lock here, because we need to ensure we get it - # running on the same loop as the calling coroutine. - # - # If we initialized the lock in the __init__, the loop that the coroutine this method - # would be called from might not exist yet, and so the lock would be on a different - # loop, which would raise RuntimeErrors. - if self._increment_lock is None: - self._increment_lock = asyncio.Lock() - - # Since this has several API calls, we need a lock to prevent race conditions - async with self._increment_lock: - value = await self.get(key) - - # Can't increment a non-existing value - if value is None: - error_message = "The provided key does not exist!" - log.error(error_message) - raise KeyError(error_message) - - # If it does exist, and it's an int or a float, increment and set it. - if isinstance(value, int) or isinstance(value, float): - value += amount - await self.set(key, value) - else: - error_message = "You may only increment or decrement values that are integers or floats." - log.error(error_message) - raise TypeError(error_message) - - async def decrement(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None: - """ - Decrement the value by `amount`. - - Basically just does the opposite of .increment. - """ - await self.increment(key, -amount) diff --git a/config-default.yml b/config-default.yml index 58651f548..c809a7340 100644 --- a/config-default.yml +++ b/config-default.yml @@ -62,20 +62,6 @@ style: cross_mark: "\u274C" check_mark: "\u2705" - ducky_yellow: &DUCKY_YELLOW 574951975574175744 - ducky_blurple: &DUCKY_BLURPLE 574951975310065675 - ducky_regal: &DUCKY_REGAL 637883439185395712 - ducky_camo: &DUCKY_CAMO 637914731566596096 - ducky_ninja: &DUCKY_NINJA 637923502535606293 - ducky_devil: &DUCKY_DEVIL 637925314982576139 - ducky_tube: &DUCKY_TUBE 637881368008851456 - ducky_hunt: &DUCKY_HUNT 639355090909528084 - ducky_wizard: &DUCKY_WIZARD 639355996954689536 - ducky_party: &DUCKY_PARTY 639468753440210977 - ducky_angel: &DUCKY_ANGEL 640121935610511361 - ducky_maul: &DUCKY_MAUL 640137724958867467 - ducky_santa: &DUCKY_SANTA 655360331002019870 - # emotes used for #reddit upvotes: "<:reddit_upvotes:755845219890757644>" comments: "<:reddit_comments:755845255001014384>" @@ -144,9 +130,14 @@ guild: modmail: 714494672835444826 channels: - announcements: 354619224620138496 - user_event_announcements: &USER_EVENT_A 592000283102674944 - python_news: &PYNEWS_CHANNEL 704372456592506880 + # Public announcement and news channels + change_log: &CHANGE_LOG 748238795236704388 + announcements: &ANNOUNCEMENTS 354619224620138496 + python_news: &PYNEWS_CHANNEL 704372456592506880 + python_events: &PYEVENTS_CHANNEL 729674110270963822 + mailing_lists: &MAILING_LISTS 704372456592506880 + reddit: &REDDIT_CHANNEL 458224812528238616 + user_event_announcements: &USER_EVENT_A 592000283102674944 # Development dev_contrib: &DEV_CONTRIB 635950537262759947 @@ -177,7 +168,6 @@ guild: # Special bot_commands: &BOT_CMD 267659945086812160 esoteric: 470884583684964352 - reddit: 458224812528238616 verification: 352442727016693763 # Staff @@ -192,6 +182,12 @@ guild: mod_spam: &MOD_SPAM 620607373828030464 organisation: &ORGANISATION 551789653284356126 staff_lounge: &STAFF_LOUNGE 464905259261755392 + duck_pond: &DUCK_POND 637820308341915648 + + # Staff announcement channels + staff_announcements: &STAFF_ANNOUNCEMENTS 464033278631084042 + mod_announcements: &MOD_ANNOUNCEMENTS 372115205867700225 + admin_announcements: &ADMIN_ANNOUNCEMENTS 749736155569848370 # Voice admins_voice: &ADMINS_VOICE 500734494840717332 @@ -275,17 +271,19 @@ guild: filter: # What do we filter? - filter_zalgo: false - filter_invites: true - filter_domains: true - watch_regex: true - watch_rich_embeds: true + filter_zalgo: false + filter_invites: true + filter_domains: true + filter_everyone_ping: true + watch_regex: true + watch_rich_embeds: true # Notify user on filter? # Notifications are not expected for "watchlist" type filters - notify_user_zalgo: false - notify_user_invites: true - notify_user_domains: false + notify_user_zalgo: false + notify_user_invites: true + notify_user_domains: false + notify_user_everyone_ping: true # Filter configuration ping_everyone: true @@ -391,12 +389,6 @@ anti_spam: interval: 10 max: 3 - # The everyone ping filter is temporarily disabled - # until we've fixed a couple of bugs. - # everyone_ping: - # interval: 10 - # max: 0 - reddit: subreddits: @@ -467,21 +459,19 @@ sync: max_diff: 10 duck_pond: - threshold: 5 - custom_emojis: - - *DUCKY_YELLOW - - *DUCKY_BLURPLE - - *DUCKY_CAMO - - *DUCKY_DEVIL - - *DUCKY_NINJA - - *DUCKY_REGAL - - *DUCKY_TUBE - - *DUCKY_HUNT - - *DUCKY_WIZARD - - *DUCKY_PARTY - - *DUCKY_ANGEL - - *DUCKY_MAUL - - *DUCKY_SANTA + threshold: 4 + channel_blacklist: + - *ANNOUNCEMENTS + - *PYNEWS_CHANNEL + - *PYEVENTS_CHANNEL + - *MAILING_LISTS + - *REDDIT_CHANNEL + - *USER_EVENT_A + - *DUCK_POND + - *CHANGE_LOG + - *STAFF_ANNOUNCEMENTS + - *MOD_ANNOUNCEMENTS + - *ADMIN_ANNOUNCEMENTS python_news: mail_lists: diff --git a/tests/bot/exts/fun/__init__.py b/tests/bot/exts/fun/__init__.py deleted file mode 100644 index e69de29bb..000000000 --- a/tests/bot/exts/fun/__init__.py +++ /dev/null diff --git a/tests/bot/exts/fun/test_duck_pond.py b/tests/bot/exts/fun/test_duck_pond.py deleted file mode 100644 index 704b08066..000000000 --- a/tests/bot/exts/fun/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.exts.fun import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.exts.fun.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger(MODULE_PATH) - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger(MODULE_PATH) - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger(MODULE_PATH) - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py deleted file mode 100644 index a2f0fe55d..000000000 --- a/tests/bot/utils/test_redis_cache.py +++ /dev/null @@ -1,265 +0,0 @@ -import asyncio -import unittest - -import fakeredis.aioredis - -from bot.utils import RedisCache -from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError -from tests import helpers - - -class RedisCacheTests(unittest.IsolatedAsyncioTestCase): - """Tests the RedisCache class from utils.redis_dict.py.""" - - async def asyncSetUp(self): # noqa: N802 - """Sets up the objects that only have to be initialized once.""" - self.bot = helpers.MockBot() - self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - - # Okay, so this is necessary so that we can create a clean new - # class for every test method, and we want that because it will - # ensure we get a fresh loop, which is necessary for test_increment_lock - # to be able to pass. - class DummyCog: - """A dummy cog, for dummies.""" - - redis = RedisCache() - - def __init__(self, bot: helpers.MockBot): - self.bot = bot - - self.cog = DummyCog(self.bot) - - await self.cog.redis.clear() - - def test_class_attribute_namespace(self): - """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") - - async def test_class_attribute_required(self): - """Test that errors are raised when not assigned as a class attribute.""" - bad_cache = RedisCache() - self.assertIs(bad_cache._namespace, None) - - with self.assertRaises(RuntimeError): - await bad_cache.set("test", "me_up_deadman") - - async def test_set_get_item(self): - """Test that users can set and get items from the RedisDict.""" - test_cases = ( - ('favorite_fruit', 'melon'), - ('favorite_number', 86), - ('favorite_fraction', 86.54), - ('favorite_boolean', False), - ('other_boolean', True), - ) - - # Test that we can get and set different types. - for test in test_cases: - await self.cog.redis.set(*test) - self.assertEqual(await self.cog.redis.get(test[0]), test[1]) - - # Test that .get allows a default value - self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") - - async def test_set_item_type(self): - """Test that .set rejects keys and values that are not permitted.""" - fruits = ["lemon", "melon", "apple"] - - with self.assertRaises(TypeError): - await self.cog.redis.set(fruits, "nice") - - with self.assertRaises(TypeError): - await self.cog.redis.set(4.23, "nice") - - async def test_delete_item(self): - """Test that .delete allows us to delete stuff from the RedisCache.""" - # Add an item and verify that it gets added - await self.cog.redis.set("internet", "firetruck") - self.assertEqual(await self.cog.redis.get("internet"), "firetruck") - - # Delete that item and verify that it gets deleted - await self.cog.redis.delete("internet") - self.assertIs(await self.cog.redis.get("internet"), None) - - async def test_contains(self): - """Test that we can check membership with .contains.""" - await self.cog.redis.set('favorite_country', "Burkina Faso") - - self.assertIs(await self.cog.redis.contains('favorite_country'), True) - self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) - - async def test_items(self): - """Test that the RedisDict can be iterated.""" - # Set up our test cases in the Redis cache - test_cases = [ - ('favorite_turtle', 'Donatello'), - ('second_favorite_turtle', 'Leonardo'), - ('third_favorite_turtle', 'Raphael'), - ] - for key, value in test_cases: - await self.cog.redis.set(key, value) - - # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item for item in await self.cog.redis.items()] - - # These sequences are probably in the same order now, but probably - # isn't good enough for tests. Let's not rely on .hgetall always - # returning things in sequence, and just sort both lists to be safe. - redis_items = sorted(redis_items) - test_cases = sorted(test_cases) - - # If these are equal now, everything works fine. - self.assertSequenceEqual(test_cases, redis_items) - - async def test_length(self): - """Test that we can get the correct .length from the RedisDict.""" - await self.cog.redis.set('one', 1) - await self.cog.redis.set('two', 2) - await self.cog.redis.set('three', 3) - self.assertEqual(await self.cog.redis.length(), 3) - - await self.cog.redis.set('four', 4) - self.assertEqual(await self.cog.redis.length(), 4) - - async def test_to_dict(self): - """Test that the .to_dict method returns a workable dictionary copy.""" - copy = await self.cog.redis.to_dict() - local_copy = {key: value for key, value in await self.cog.redis.items()} - self.assertIs(type(copy), dict) - self.assertDictEqual(copy, local_copy) - - async def test_clear(self): - """Test that the .clear method removes the entire hash.""" - await self.cog.redis.set('teddy', 'with me') - await self.cog.redis.set('in my dreams', 'you have a weird hat') - self.assertEqual(await self.cog.redis.length(), 2) - - await self.cog.redis.clear() - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_pop(self): - """Test that we can .pop an item from the RedisDict.""" - await self.cog.redis.set('john', 'was afraid') - - self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') - self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_update(self): - """Test that we can .update the RedisDict with multiple items.""" - await self.cog.redis.set("reckfried", "lona") - await self.cog.redis.set("bel air", "prince") - await self.cog.redis.update({ - "reckfried": "jona", - "mega": "hungry, though", - }) - - result = { - "reckfried": "jona", - "bel air": "prince", - "mega": "hungry, though", - } - self.assertDictEqual(await self.cog.redis.to_dict(), result) - - def test_typestring_conversion(self): - """Test the typestring-related helper functions.""" - conversion_tests = ( - (12, "i|12"), - (12.4, "f|12.4"), - ("cowabunga", "s|cowabunga"), - ) - - # Test conversion to typestring - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) - - # Test conversion from typestrings - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) - - # Test that exceptions are raised on invalid input - with self.assertRaises(TypeError): - self.cog.redis._value_to_typestring(["internet"]) - self.cog.redis._value_from_typestring("o|firedog") - - async def test_increment_decrement(self): - """Test .increment and .decrement methods.""" - await self.cog.redis.set("entropic", 5) - await self.cog.redis.set("disentropic", 12.5) - - # Test default increment - await self.cog.redis.increment("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 6) - - # Test default decrement - await self.cog.redis.decrement("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # Test float increment with float - await self.cog.redis.increment("disentropic", 2.0) - self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) - - # Test float increment with int - await self.cog.redis.increment("disentropic", 2) - self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) - - # Test negative increments, because why not. - await self.cog.redis.increment("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 0) - - # Negative decrements? Sure. - await self.cog.redis.decrement("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # What about if we use a negative float to decrement an int? - # This should convert the type into a float. - await self.cog.redis.decrement("entropic", -2.5) - self.assertEqual(await self.cog.redis.get("entropic"), 7.5) - - # Let's test that they raise the right errors - with self.assertRaises(KeyError): - await self.cog.redis.increment("doesn't_exist!") - - await self.cog.redis.set("stringthing", "stringthing") - with self.assertRaises(TypeError): - await self.cog.redis.increment("stringthing") - - async def test_increment_lock(self): - """Test that we can't produce a race condition in .increment.""" - await self.cog.redis.set("test_key", 0) - tasks = [] - - # Increment this a lot in different tasks - for _ in range(100): - task = asyncio.create_task( - self.cog.redis.increment("test_key", 1) - ) - tasks.append(task) - await asyncio.gather(*tasks) - - # Confirm that the value has been incremented the exact right number of times. - value = await self.cog.redis.get("test_key") - self.assertEqual(value, 100) - - async def test_exceptions_raised(self): - """Testing that the various RuntimeErrors are reachable.""" - class MyCog: - cache = RedisCache() - - def __init__(self): - self.other_cache = RedisCache() - - cog = MyCog() - - # Raises "No Bot instance" - with self.assertRaises(NoBotInstanceError): - await cog.cache.get("john") - - # Raises "RedisCache has no namespace" - with self.assertRaises(NoNamespaceError): - await cog.other_cache.get("was") - - # Raises "You must access the RedisCache instance through the cog instance" - with self.assertRaises(NoParentInstanceError): - await MyCog.cache.get("afraid") diff --git a/tests/helpers.py b/tests/helpers.py index facc4e1af..e47fdf28f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -308,7 +308,11 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ - spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) + spec_set = Bot( + command_prefix=unittest.mock.MagicMock(), + loop=_get_mock_loop(), + redis_session=unittest.mock.MagicMock(), + ) additional_spec_asyncs = ("wait_for", "redis_ready") def __init__(self, **kwargs) -> None: |