diff options
author | 2020-10-21 09:58:50 -0700 | |
---|---|---|
committer | 2020-10-21 09:58:50 -0700 | |
commit | 05aeab8c9760b0c683c43227152cdf3c9f275a1b (patch) | |
tree | 8df13e212a275995e959297882ff632ac23e43a0 | |
parent | Relay python-dev to mailing lists channel (diff) | |
parent | Merge master and fix LICENSE-THIRD-PARTY conflict (diff) |
Merge pull request #1113 - cache silences
Cache silences to enable rescheduling and revoking reaction permissions
-rw-r--r-- | LICENSE-THIRD-PARTY | 72 | ||||
-rw-r--r-- | bot/exts/moderation/silence.py | 202 | ||||
-rw-r--r-- | bot/exts/utils/reminders.py | 8 | ||||
-rw-r--r-- | tests/_autospec.py | 64 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 587 | ||||
-rw-r--r-- | tests/helpers.py | 21 |
6 files changed, 689 insertions, 265 deletions
diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY index 3349d7c05..eacd9b952 100644 --- a/LICENSE-THIRD-PARTY +++ b/LICENSE-THIRD-PARTY @@ -1,14 +1,13 @@ -BSD 3-Clause License - +--------------------------------------------------------------------------------------------------- + BSD 3-Clause License Applies to: -- _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL in bot/cogs/codeblock/parsing.py - -- Copyright (c) 2008-Present, IPython Development Team -- Copyright (c) 2001-2007, Fernando Perez <[email protected]> -- Copyright (c) 2001, Janko Hauser <[email protected]> -- Copyright (c) 2001, Nathaniel Gray <[email protected]> - -All rights reserved. + - Copyright (c) 2008-Present, IPython Development Team + Copyright (c) 2001-2007, Fernando Perez <[email protected]> + Copyright (c) 2001, Janko Hauser <[email protected]> + Copyright (c) 2001, Nathaniel Gray <[email protected]> + All rights reserved. + - bot/exts/info/codeblock/_parsing.py: _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL +--------------------------------------------------------------------------------------------------- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -34,3 +33,56 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--------------------------------------------------------------------------------------------------- + PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +Applies to: + - Copyright © 2001-2020 Python Software Foundation. All rights reserved. + - tests/_autospec.py: _decoration_helper +--------------------------------------------------------------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation; +All Rights Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index ac0c1c85e..e6712b3b6 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -1,8 +1,11 @@ -import asyncio +import json import logging from contextlib import suppress +from datetime import datetime, timedelta, timezone +from operator import attrgetter from typing import Optional +from async_rediscache import RedisCache from discord import TextChannel from discord.ext import commands, tasks from discord.ext.commands import Context @@ -10,10 +13,25 @@ from discord.ext.commands import Context from bot.bot import Bot from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles from bot.converters import HushDurationConverter +from bot.utils.lock import LockedResourceError, lock_arg from bot.utils.scheduling import Scheduler log = logging.getLogger(__name__) +LOCK_NAMESPACE = "silence" + +MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced." +MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely." +MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)." + +MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced." +MSG_UNSILENCE_MANUAL = ( + f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " + f"set manually or the cache was prematurely cleared. " + f"Please edit the overwrites manually to unsilence." +) +MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel." + class SilenceNotifier(tasks.Loop): """Loop notifier for posting notices to `alert_channel` containing added channels.""" @@ -56,25 +74,32 @@ class SilenceNotifier(tasks.Loop): class Silence(commands.Cog): """Commands for stopping channel messages for `verified` role in a channel.""" + # Maps muted channel IDs to their previous overwrites for send_message and add_reactions. + # Overwrites are stored as JSON. + previous_overwrites = RedisCache() + + # Maps muted channel IDs to POSIX timestamps of when they'll be unsilenced. + # A timestamp equal to -1 means it's indefinite. + unsilence_timestamps = RedisCache() + def __init__(self, bot: Bot): self.bot = bot self.scheduler = Scheduler(self.__class__.__name__) - self.muted_channels = set() - self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) - self._get_instance_vars_event = asyncio.Event() + self._init_task = self.bot.loop.create_task(self._async_init()) - async def _get_instance_vars(self) -> None: - """Get instance variables after they're available to get from the guild.""" + async def _async_init(self) -> None: + """Set instance attributes once the guild is available and reschedule unsilences.""" await self.bot.wait_until_guild_available() + guild = self.bot.get_guild(Guild.id) self._verified_role = guild.get_role(Roles.verified) self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) - self._mod_log_channel = self.bot.get_channel(Channels.mod_log) - self.notifier = SilenceNotifier(self._mod_log_channel) - self._get_instance_vars_event.set() + self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log)) + await self._reschedule() @commands.command(aliases=("hush",)) + @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True) async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: """ Silence the current channel for `duration` minutes or `forever`. @@ -82,18 +107,25 @@ class Silence(commands.Cog): Duration is capped at 15 minutes, passing forever makes the silence indefinite. Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. """ - await self._get_instance_vars_event.wait() - log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") - if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): - await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") - return - if duration is None: - await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") + await self._init_task + + channel_info = f"#{ctx.channel} ({ctx.channel.id})" + log.debug(f"{ctx.author} is silencing channel {channel_info}.") + + if not await self._set_silence_overwrites(ctx.channel): + log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.") + await ctx.send(MSG_SILENCE_FAIL) return - await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") + await self._schedule_unsilence(ctx, duration) - self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) + if duration is None: + self.notifier.add_channel(ctx.channel) + log.info(f"Silenced {channel_info} indefinitely.") + await ctx.send(MSG_SILENCE_PERMANENT) + else: + log.info(f"Silenced {channel_info} for {duration} minute(s).") + await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration)) @commands.command(aliases=("unhush",)) async def unsilence(self, ctx: Context) -> None: @@ -102,61 +134,115 @@ class Silence(commands.Cog): If the channel was silenced indefinitely, notifications for the channel will stop. """ - await self._get_instance_vars_event.wait() + await self._init_task log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") - if not await self._unsilence(ctx.channel): - await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") + await self._unsilence_wrapper(ctx.channel) + + @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True) + async def _unsilence_wrapper(self, channel: TextChannel) -> None: + """Unsilence `channel` and send a success/failure message.""" + if not await self._unsilence(channel): + overwrite = channel.overwrites_for(self._verified_role) + if overwrite.send_messages is False or overwrite.add_reactions is False: + await channel.send(MSG_UNSILENCE_MANUAL) + else: + await channel.send(MSG_UNSILENCE_FAIL) else: - await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") + await channel.send(MSG_UNSILENCE_SUCCESS) - async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: - """ - Silence `channel` for `self._verified_role`. + async def _set_silence_overwrites(self, channel: TextChannel) -> bool: + """Set silence permission overwrites for `channel` and return True if successful.""" + overwrite = channel.overwrites_for(self._verified_role) + prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) - If `persistent` is `True` add `channel` to notifier. - `duration` is only used for logging; if None is passed `persistent` should be True to not log None. - Return `True` if channel permissions were changed, `False` otherwise. - """ - current_overwrite = channel.overwrites_for(self._verified_role) - if current_overwrite.send_messages is False: - log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") + if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()): return False - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) - self.muted_channels.add(channel) - if persistent: - log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") - self.notifier.add_channel(channel) - return True - - log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") + + overwrite.update(send_messages=False, add_reactions=False) + await channel.set_permissions(self._verified_role, overwrite=overwrite) + await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites)) + return True + async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None: + """Schedule `ctx.channel` to be unsilenced if `duration` is not None.""" + if duration is None: + await self.unsilence_timestamps.set(ctx.channel.id, -1) + else: + self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) + unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration) + await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp()) + async def _unsilence(self, channel: TextChannel) -> bool: """ Unsilence `channel`. - Check if `channel` is silenced through a `PermissionOverwrite`, - if it is unsilence it and remove it from the notifier. + If `channel` has a silence task scheduled or has its previous overwrites cached, unsilence + it, cancel the task, and remove it from the notifier. Notify admins if it has a task but + not cached overwrites. + Return `True` if channel permissions were changed, `False` otherwise. """ - current_overwrite = channel.overwrites_for(self._verified_role) - if current_overwrite.send_messages is False: - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) - log.info(f"Unsilenced channel #{channel} ({channel.id}).") - self.scheduler.cancel(channel.id) - self.notifier.remove_channel(channel) - self.muted_channels.discard(channel) - return True - log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") - return False + prev_overwrites = await self.previous_overwrites.get(channel.id) + if channel.id not in self.scheduler and prev_overwrites is None: + log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") + return False + + overwrite = channel.overwrites_for(self._verified_role) + if prev_overwrites is None: + log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.") + overwrite.update(send_messages=None, add_reactions=None) + else: + overwrite.update(**json.loads(prev_overwrites)) + + await channel.set_permissions(self._verified_role, overwrite=overwrite) + log.info(f"Unsilenced channel #{channel} ({channel.id}).") + + self.scheduler.cancel(channel.id) + self.notifier.remove_channel(channel) + await self.previous_overwrites.delete(channel.id) + await self.unsilence_timestamps.delete(channel.id) + + if prev_overwrites is None: + await self._mod_alerts_channel.send( + f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing " + f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` " + f"overwrites for {self._verified_role.mention} are at their desired values." + ) + + return True + + async def _reschedule(self) -> None: + """Reschedule unsilencing of active silences and add permanent ones to the notifier.""" + for channel_id, timestamp in await self.unsilence_timestamps.items(): + channel = self.bot.get_channel(channel_id) + if channel is None: + log.info(f"Can't reschedule silence for {channel_id}: channel not found.") + continue + + if timestamp == -1: + log.info(f"Adding permanent silence for #{channel} ({channel.id}) to the notifier.") + self.notifier.add_channel(channel) + continue + + dt = datetime.fromtimestamp(timestamp, tz=timezone.utc) + delta = (dt - datetime.now(tz=timezone.utc)).total_seconds() + if delta <= 0: + # Suppress the error since it's not being invoked by a user via the command. + with suppress(LockedResourceError): + await self._unsilence_wrapper(channel) + else: + log.info(f"Rescheduling silence for #{channel} ({channel.id}).") + self.scheduler.schedule_later(delta, channel_id, self._unsilence_wrapper(channel)) def cog_unload(self) -> None: - """Send alert with silenced channels and cancel scheduled tasks on unload.""" - self.scheduler.cancel_all() - if self.muted_channels: - channels_string = ''.join(channel.mention for channel in self.muted_channels) - message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" - asyncio.create_task(self._mod_alerts_channel.send(message)) + """Cancel the init task and scheduled tasks.""" + # It's important to wait for _init_task (specifically for _reschedule) to be cancelled + # before cancelling scheduled tasks. Otherwise, it's possible for _reschedule to schedule + # more tasks after cancel_all has finished, despite _init_task.cancel being called first. + # This is cause cancel() on its own doesn't block until the task is cancelled. + self._init_task.cancel() + self._init_task.add_done_callback(lambda _: self.scheduler.cancel_all()) # This cannot be static (must have a __func__ attribute). async def cog_check(self, ctx: Context) -> bool: diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index bf4e24661..3113a1149 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -23,7 +23,7 @@ from bot.utils.time import humanize_delta log = logging.getLogger(__name__) -NAMESPACE = "reminder" # Used for the mutually_exclusive decorator; constant to prevent typos +LOCK_NAMESPACE = "reminder" WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 @@ -170,7 +170,7 @@ class Reminders(Cog): log.trace(f"Scheduling new task #{reminder['id']}") self.schedule_reminder(reminder) - @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True) + @lock_arg(LOCK_NAMESPACE, "reminder", itemgetter("id"), raise_error=True) async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: """Send the reminder.""" is_valid, user, channel = self.ensure_valid_reminder(reminder) @@ -378,7 +378,7 @@ class Reminders(Cog): mention_ids = [mention.id for mention in mentions] await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) - @lock_arg(NAMESPACE, "id_", raise_error=True) + @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True) async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: """Edits a reminder with the given payload, then sends a confirmation message.""" if not await self._can_modify(ctx, id_): @@ -398,7 +398,7 @@ class Reminders(Cog): await self._reschedule_reminder(reminder) @remind_group.command("delete", aliases=("remove", "cancel")) - @lock_arg(NAMESPACE, "id_", raise_error=True) + @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True) async def delete_reminder(self, ctx: Context, id_: int) -> None: """Delete one of your active reminders.""" if not await self._can_modify(ctx, id_): diff --git a/tests/_autospec.py b/tests/_autospec.py new file mode 100644 index 000000000..ee2fc1973 --- /dev/null +++ b/tests/_autospec.py @@ -0,0 +1,64 @@ +import contextlib +import functools +import unittest.mock +from typing import Callable + + [email protected](unittest.mock._patch.decoration_helper) +def _decoration_helper(self, patched, args, keywargs): + """Skips adding patchings as args if their `dont_pass` attribute is True.""" + # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added. + extra_args = [] + with contextlib.ExitStack() as exit_stack: + for patching in patched.patchings: + arg = exit_stack.enter_context(patching) + if not getattr(patching, "dont_pass", False): + # Only add the patching as an arg if dont_pass is False. + if patching.attribute_name is not None: + keywargs.update(arg) + elif patching.new is unittest.mock.DEFAULT: + extra_args.append(arg) + + args += tuple(extra_args) + yield args, keywargs + + [email protected](unittest.mock._patch.copy) +def _copy(self): + """Copy the `dont_pass` attribute along with the standard copy operation.""" + patcher_copy = _copy.original(self) + patcher_copy.dont_pass = getattr(self, "dont_pass", False) + return patcher_copy + + +# Monkey-patch the patcher class :) +_copy.original = unittest.mock._patch.copy +unittest.mock._patch.copy = _copy +unittest.mock._patch.decoration_helper = _decoration_helper + + +def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable: + """ + Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + + If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object. + """ + # Caller's kwargs should take priority and overwrite the defaults. + kwargs = dict(spec_set=True, autospec=True) + kwargs.update(patch_kwargs) + + # Import the target if it's a string. + # This is to support both object and string targets like patch.multiple. + if type(target) is str: + target = unittest.mock._importer(target) + + def decorator(func): + for attribute in attributes: + patcher = unittest.mock.patch.object(target, attribute, **kwargs) + if not pass_mocks: + # A custom attribute to keep track of which patchings should be skipped. + patcher.dont_pass = True + func = patcher(func) + return func + return decorator diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 3c2d52ae0..104293d8e 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,23 +1,49 @@ +import asyncio import unittest +from datetime import datetime, timezone from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock +from async_rediscache import RedisSession from discord import PermissionOverwrite -from bot.constants import Channels, Emojis, Guild, Roles -from bot.exts.moderation.silence import Silence, SilenceNotifier -from tests.helpers import MockBot, MockContext, MockTextChannel +from bot.constants import Channels, Guild, Roles +from bot.exts.moderation import silence +from tests.helpers import MockBot, MockContext, MockTextChannel, autospec + +redis_session = None +redis_loop = asyncio.get_event_loop() + + +def setUpModule(): # noqa: N802 + """Create and connect to the fakeredis session.""" + global redis_session + redis_session = RedisSession(use_fakeredis=True) + redis_loop.run_until_complete(redis_session.connect()) + + +def tearDownModule(): # noqa: N802 + """Close the fakeredis session.""" + if redis_session: + redis_loop.run_until_complete(redis_session.close()) + + +# Have to subclass it because builtins can't be patched. +class PatchedDatetime(datetime): + """A datetime object with a mocked now() function.""" + + now = mock.create_autospec(datetime, "now") class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.alert_channel = MockTextChannel() - self.notifier = SilenceNotifier(self.alert_channel) + self.notifier = silence.SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() self.notifier.start = self.notifier_start_mock = Mock() def test_add_channel_adds_channel(self): - """Channel in FirstHash with current loop is added to internal set.""" + """Channel is added to `_silenced_channels` with the current loop.""" channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.add_channel(channel) @@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.notifier_start_mock.assert_not_called() def test_remove_channel_removes_channel(self): - """Channel in FirstHash is removed from `_silenced_channels`.""" + """Channel is removed from `_silenced_channels`.""" channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.remove_channel(channel) @@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): with self.subTest(current_loop=current_loop): with mock.patch.object(self.notifier, "_current_loop", new=current_loop): await self.notifier._notifier() - self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") + self.alert_channel.send.assert_called_once_with( + f"<@&{Roles.moderators}> currently silenced channels: " + ) self.alert_channel.send.reset_mock() async def test_notifier_skips_alert(self): @@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.alert_channel.send.assert_not_called() -class SilenceTests(unittest.IsolatedAsyncioTestCase): +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceCogTests(unittest.IsolatedAsyncioTestCase): + """Tests for the general functionality of the Silence cog.""" + + @autospec(silence, "Scheduler", pass_mocks=False) def setUp(self) -> None: self.bot = MockBot() - self.cog = Silence(self.bot) - self.ctx = MockContext() - self.cog._verified_role = None - # Set event so command callbacks can continue. - self.cog._get_instance_vars_event.set() + self.cog = silence.Silence(self.bot) - async def test_instance_vars_got_guild(self): + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_got_guild(self): """Bot got guild after it became available.""" - await self.cog._get_instance_vars() - self.bot.wait_until_guild_available.assert_called_once() + await self.cog._async_init() + self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) - async def test_instance_vars_got_role(self): + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_got_role(self): """Got `Roles.verified` role from guild.""" - await self.cog._get_instance_vars() guild = self.bot.get_guild() - guild.get_role.assert_called_once_with(Roles.verified) + guild.get_role.side_effect = lambda id_: Mock(id=id_) - async def test_instance_vars_got_channels(self): + await self.cog._async_init() + self.assertEqual(self.cog._verified_role.id, Roles.verified) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_got_channels(self): """Got channels from bot.""" - await self.cog._get_instance_vars() - self.bot.get_channel.called_once_with(Channels.mod_alerts) - self.bot.get_channel.called_once_with(Channels.mod_log) + self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + + await self.cog._async_init() + self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) - @mock.patch("bot.exts.moderation.silence.SilenceNotifier") - async def test_instance_vars_got_notifier(self, notifier): + @autospec(silence, "SilenceNotifier") + async def test_async_init_got_notifier(self, notifier): """Notifier was started with channel.""" - mod_log = MockTextChannel() - self.bot.get_channel.side_effect = (None, mod_log) - await self.cog._get_instance_vars() - notifier.assert_called_once_with(mod_log) - self.bot.get_channel.side_effect = None - - async def test_silence_sent_correct_discord_message(self): - """Check if proper message was sent when called with duration in channel with previous state.""" + self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + + await self.cog._async_init() + notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) + self.assertEqual(self.cog.notifier, notifier.return_value) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_rescheduled(self): + """`_reschedule_` coroutine was awaited.""" + self.cog._reschedule = mock.create_autospec(self.cog._reschedule) + await self.cog._async_init() + self.cog._reschedule.assert_awaited_once_with() + + def test_cog_unload_cancelled_tasks(self): + """The init task was cancelled.""" + self.cog._init_task = asyncio.Future() + self.cog.cog_unload() + + # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. + self.assertTrue(self.cog._init_task.cancelled()) + + @autospec("discord.ext.commands", "has_any_role") + @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) + async def test_cog_check(self, role_check): + """Role check was called with `MODERATION_ROLES`""" + ctx = MockContext() + role_check.return_value.predicate = mock.AsyncMock() + + await self.cog.cog_check(ctx) + role_check.assert_called_once_with(*(1, 2, 3)) + role_check.return_value.predicate.assert_awaited_once_with(ctx) + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class RescheduleTests(unittest.IsolatedAsyncioTestCase): + """Tests for the rescheduling of cached unsilences.""" + + @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) + def setUp(self): + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) + + with mock.patch.object(self.cog, "_reschedule", autospec=True): + asyncio.run(self.cog._async_init()) # Populate instance attributes. + + async def test_skipped_missing_channel(self): + """Did nothing because the channel couldn't be retrieved.""" + self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] + self.bot.get_channel.return_value = None + + await self.cog._reschedule() + + self.cog.notifier.add_channel.assert_not_called() + self.cog._unsilence_wrapper.assert_not_called() + self.cog.scheduler.schedule_later.assert_not_called() + + async def test_added_permanent_to_notifier(self): + """Permanently silenced channels were added to the notifier.""" + channels = [MockTextChannel(id=123), MockTextChannel(id=456)] + self.bot.get_channel.side_effect = channels + self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] + + await self.cog._reschedule() + + self.cog.notifier.add_channel.assert_any_call(channels[0]) + self.cog.notifier.add_channel.assert_any_call(channels[1]) + + self.cog._unsilence_wrapper.assert_not_called() + self.cog.scheduler.schedule_later.assert_not_called() + + async def test_unsilenced_expired(self): + """Unsilenced expired silences.""" + channels = [MockTextChannel(id=123), MockTextChannel(id=456)] + self.bot.get_channel.side_effect = channels + self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] + + await self.cog._reschedule() + + self.cog._unsilence_wrapper.assert_any_call(channels[0]) + self.cog._unsilence_wrapper.assert_any_call(channels[1]) + + self.cog.notifier.add_channel.assert_not_called() + self.cog.scheduler.schedule_later.assert_not_called() + + @mock.patch.object(silence, "datetime", new=PatchedDatetime) + async def test_rescheduled_active(self): + """Rescheduled active silences.""" + channels = [MockTextChannel(id=123), MockTextChannel(id=456)] + self.bot.get_channel.side_effect = channels + self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] + silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc) + + self.cog._unsilence_wrapper = mock.MagicMock() + unsilence_return = self.cog._unsilence_wrapper.return_value + + await self.cog._reschedule() + + # Yuck. + calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)] + self.cog.scheduler.schedule_later.assert_has_calls(calls) + + unsilence_calls = [mock.call(channel) for channel in channels] + self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls) + + self.cog.notifier.add_channel.assert_not_called() + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceTests(unittest.IsolatedAsyncioTestCase): + """Tests for the silence command and its related helper methods.""" + + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + self.cog._init_task = asyncio.Future() + self.cog._init_task.set_result(None) + + # Avoid unawaited coroutine warnings. + self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() + + asyncio.run(self.cog._async_init()) # Populate instance attributes. + + self.channel = MockTextChannel() + self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) + self.channel.overwrites_for.return_value = self.overwrite + + async def test_sent_correct_message(self): + """Appropriate failure/success message was sent by the command.""" test_cases = ( - (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), - (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), - (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), + (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), + (None, silence.MSG_SILENCE_PERMANENT, True,), + (5, silence.MSG_SILENCE_FAIL, False,), ) - for duration, result_message, _silence_patch_return in test_cases: - with self.subTest( - silence_duration=duration, - result_message=result_message, - starting_unsilenced_state=_silence_patch_return - ): - with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): - await self.cog.silence(self.cog, self.ctx, duration) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_unsilence_sent_correct_discord_message(self): - """Check if proper message was sent when unsilencing channel.""" - test_cases = ( - (True, f"{Emojis.check_mark} unsilenced current channel."), - (False, f"{Emojis.cross_mark} current channel was not silenced.") + for duration, message, was_silenced in test_cases: + ctx = MockContext() + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): + with self.subTest(was_silenced=was_silenced, message=message, duration=duration): + await self.cog.silence.callback(self.cog, ctx, duration) + ctx.send.assert_called_once_with(message) + + async def test_skipped_already_silenced(self): + """Permissions were not set and `False` was returned for an already silenced channel.""" + subtests = ( + (False, PermissionOverwrite(send_messages=False, add_reactions=False)), + (True, PermissionOverwrite(send_messages=True, add_reactions=True)), + (True, PermissionOverwrite(send_messages=False, add_reactions=False)), ) - for _unsilence_patch_return, result_message in test_cases: - with self.subTest( - starting_silenced_state=_unsilence_patch_return, - result_message=result_message - ): - with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): - await self.cog.unsilence(self.cog, self.ctx) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_silence_private_for_false(self): - """Permissions are not set and `False` is returned in an already silenced channel.""" - perm_overwrite = Mock(send_messages=False) - channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - - self.assertFalse(await self.cog._silence(channel, True, None)) - channel.set_permissions.assert_not_called() - async def test_silence_private_silenced_channel(self): - """Channel had `send_message` permissions revoked.""" - channel = MockTextChannel() - self.assertTrue(await self.cog._silence(channel, False, None)) - channel.set_permissions.assert_called_once() - self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) + for contains, overwrite in subtests: + with self.subTest(contains=contains, overwrite=overwrite): + self.cog.scheduler.__contains__.return_value = contains + channel = MockTextChannel() + channel.overwrites_for.return_value = overwrite + + self.assertFalse(await self.cog._set_silence_overwrites(channel)) + channel.set_permissions.assert_not_called() + + async def test_silenced_channel(self): + """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) + self.assertFalse(self.overwrite.send_messages) + self.assertFalse(self.overwrite.add_reactions) + self.channel.set_permissions.assert_awaited_once_with( + self.cog._verified_role, + overwrite=self.overwrite + ) - async def test_silence_private_preserves_permissions(self): - """Previous permissions were preserved when channel was silenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite() - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._silence(channel, False, None) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - async def test_silence_private_notifier(self): - """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" - channel = MockTextChannel() - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=True): - await self.cog._silence(channel, True, None) - self.cog.notifier.add_channel.assert_called_once() - - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=False): - await self.cog._silence(channel, False, None) - self.cog.notifier.add_channel.assert_not_called() - - async def test_silence_private_added_muted_channel(self): - """Channel was added to `muted_channels` on silence.""" + async def test_preserved_other_overwrites(self): + """Channel's other unrelated overwrites were not changed.""" + prev_overwrite_dict = dict(self.overwrite) + await self.cog._set_silence_overwrites(self.channel) + new_overwrite_dict = dict(self.overwrite) + + # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. + del prev_overwrite_dict['send_messages'] + del prev_overwrite_dict['add_reactions'] + del new_overwrite_dict['send_messages'] + del new_overwrite_dict['add_reactions'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + async def test_temp_not_added_to_notifier(self): + """Channel was not added to notifier if a duration was set for the silence.""" + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): + await self.cog.silence.callback(self.cog, MockContext(), 15) + self.cog.notifier.add_channel.assert_not_called() + + async def test_indefinite_added_to_notifier(self): + """Channel was added to notifier if a duration was not set for the silence.""" + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): + await self.cog.silence.callback(self.cog, MockContext(), None) + self.cog.notifier.add_channel.assert_called_once() + + async def test_silenced_not_added_to_notifier(self): + """Channel was not added to the notifier if it was already silenced.""" + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): + await self.cog.silence.callback(self.cog, MockContext(), 15) + self.cog.notifier.add_channel.assert_not_called() + + async def test_cached_previous_overwrites(self): + """Channel's previous overwrites were cached.""" + overwrite_json = '{"send_messages": true, "add_reactions": false}' + await self.cog._set_silence_overwrites(self.channel) + self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + + @autospec(silence, "datetime") + async def test_cached_unsilence_time(self, datetime_mock): + """The UTC POSIX timestamp for the unsilence was cached.""" + now_timestamp = 100 + duration = 15 + timestamp = now_timestamp + duration * 60 + datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) + + ctx = MockContext(channel=self.channel) + await self.cog.silence.callback(self.cog, ctx, duration) + + self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) + datetime_mock.now.assert_called_once_with(tz=timezone.utc) # Ensure it's using an aware dt. + + async def test_cached_indefinite_time(self): + """A value of -1 was cached for a permanent silence.""" + ctx = MockContext(channel=self.channel) + await self.cog.silence.callback(self.cog, ctx, None) + self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) + + async def test_scheduled_task(self): + """An unsilence task was scheduled.""" + ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + + await self.cog.silence.callback(self.cog, ctx, 5) + + args = (300, ctx.channel.id, ctx.invoke.return_value) + self.cog.scheduler.schedule_later.assert_called_once_with(*args) + ctx.invoke.assert_called_once_with(self.cog.unsilence) + + async def test_permanent_not_scheduled(self): + """A task was not scheduled for a permanent silence.""" + ctx = MockContext(channel=self.channel) + await self.cog.silence.callback(self.cog, ctx, None) + self.cog.scheduler.schedule_later.assert_not_called() + + +@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) +class UnsilenceTests(unittest.IsolatedAsyncioTestCase): + """Tests for the unsilence command and its related helper methods.""" + + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot(get_channel=lambda _: MockTextChannel()) + self.cog = silence.Silence(self.bot) + self.cog._init_task = asyncio.Future() + self.cog._init_task.set_result(None) + + overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) + self.cog.previous_overwrites = overwrites_cache + + asyncio.run(self.cog._async_init()) # Populate instance attributes. + + self.cog.scheduler.__contains__.return_value = True + overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' + self.channel = MockTextChannel() + self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) + self.channel.overwrites_for.return_value = self.overwrite + + async def test_sent_correct_message(self): + """Appropriate failure/success message was sent by the command.""" + unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) + test_cases = ( + (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), + (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), + (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), + (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), + (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), + ) + for was_unsilenced, message, overwrite in test_cases: + ctx = MockContext() + with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): + with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): + ctx.channel.overwrites_for.return_value = overwrite + await self.cog.unsilence.callback(self.cog, ctx) + ctx.channel.send.assert_called_once_with(message) + + async def test_skipped_already_unsilenced(self): + """Permissions were not set and `False` was returned for an already unsilenced channel.""" + self.cog.scheduler.__contains__.return_value = False + self.cog.previous_overwrites.get.return_value = None channel = MockTextChannel() - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._silence(channel, False, None) - muted_channels.add.assert_called_once_with(channel) - async def test_unsilence_private_for_false(self): - """Permissions are not set and `False` is returned in an unsilenced channel.""" - channel = Mock() self.assertFalse(await self.cog._unsilence(channel)) channel.set_permissions.assert_not_called() - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_unsilenced_channel(self, _): - """Channel had `send_message` permissions restored""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - self.assertTrue(await self.cog._unsilence(channel)) - channel.set_permissions.assert_called_once() - self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_notifier(self, notifier): - """Channel was removed from `notifier` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - await self.cog._unsilence(channel) - notifier.remove_channel.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_muted_channel(self, _): - """Channel was removed from `muted_channels` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._unsilence(channel) - muted_channels.discard.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_preserves_permissions(self, _): - """Previous permissions were preserved when channel was unsilenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite(send_messages=False) - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._unsilence(channel) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - @mock.patch("bot.exts.moderation.silence.asyncio") - @mock.patch.object(Silence, "_mod_alerts_channel", create=True) - def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): - """Task for sending an alert was created with present `muted_channels`.""" - with mock.patch.object(self.cog, "muted_channels"): - self.cog.cog_unload() - alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") - asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - - @mock.patch("bot.exts.moderation.silence.asyncio") - def test_cog_unload_skips_task_start(self, asyncio_mock): - """No task created with no channels.""" - self.cog.cog_unload() - asyncio_mock.create_task.assert_not_called() + async def test_restored_overwrites(self): + """Channel's `send_message` and `add_reactions` overwrites were restored.""" + await self.cog._unsilence(self.channel) + self.channel.set_permissions.assert_awaited_once_with( + self.cog._verified_role, + overwrite=self.overwrite, + ) - @mock.patch("discord.ext.commands.has_any_role") - @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) - async def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - role_check.return_value.predicate = mock.AsyncMock() - await self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(*(1, 2, 3)) - role_check.return_value.predicate.assert_awaited_once_with(self.ctx) + # Recall that these values are determined by the fixture. + self.assertTrue(self.overwrite.send_messages) + self.assertFalse(self.overwrite.add_reactions) + + async def test_cache_miss_used_default_overwrites(self): + """Both overwrites were set to None due previous values not being found in the cache.""" + self.cog.previous_overwrites.get.return_value = None + + await self.cog._unsilence(self.channel) + self.channel.set_permissions.assert_awaited_once_with( + self.cog._verified_role, + overwrite=self.overwrite, + ) + + self.assertIsNone(self.overwrite.send_messages) + self.assertIsNone(self.overwrite.add_reactions) + + async def test_cache_miss_sent_mod_alert(self): + """A message was sent to the mod alerts channel.""" + self.cog.previous_overwrites.get.return_value = None + + await self.cog._unsilence(self.channel) + self.cog._mod_alerts_channel.send.assert_awaited_once() + + async def test_removed_notifier(self): + """Channel was removed from `notifier`.""" + await self.cog._unsilence(self.channel) + self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + + async def test_deleted_cached_overwrite(self): + """Channel was deleted from the overwrites cache.""" + await self.cog._unsilence(self.channel) + self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + + async def test_deleted_cached_time(self): + """Channel was deleted from the timestamp cache.""" + await self.cog._unsilence(self.channel) + self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + + async def test_cancelled_task(self): + """The scheduled unsilence task should be cancelled.""" + await self.cog._unsilence(self.channel) + self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + + async def test_preserved_other_overwrites(self): + """Channel's other unrelated overwrites were not changed, including cache misses.""" + for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): + with self.subTest(overwrite_json=overwrite_json): + self.cog.previous_overwrites.get.return_value = overwrite_json + + prev_overwrite_dict = dict(self.overwrite) + await self.cog._unsilence(self.channel) + new_overwrite_dict = dict(self.overwrite) + + # Remove these keys because they were modified by the unsilence. + del prev_overwrite_dict['send_messages'] + del prev_overwrite_dict['add_reactions'] + del new_overwrite_dict['send_messages'] + del new_overwrite_dict['add_reactions'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) diff --git a/tests/helpers.py b/tests/helpers.py index e47fdf28f..870f66197 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools import logging import unittest.mock from asyncio import AbstractEventLoop -from typing import Callable, Iterable, Optional +from typing import Iterable, Optional import discord from aiohttp import ClientSession @@ -14,6 +14,7 @@ from discord.ext.commands import Context from bot.api import APIClient from bot.async_stats import AsyncStatsClient from bot.bot import Bot +from tests._autospec import autospec # noqa: F401 other modules import it via this module for logger in logging.Logger.manager.loggerDict.values(): @@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> Callable: - """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" - # Caller's kwargs should take priority and overwrite the defaults. - kwargs = {'spec_set': True, 'autospec': True, **kwargs} - - # Import the target if it's a string. - # This is to support both object and string targets like patch.multiple. - if type(target) is str: - target = unittest.mock._importer(target) - - def decorator(func): - for attribute in attributes: - patcher = unittest.mock.patch.object(target, attribute, **kwargs) - func = patcher(func) - return func - return decorator - - class HashableMixin(discord.mixins.EqualityComparable): """ Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. |