aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Kieran Siek <[email protected]>2020-06-16 20:21:29 +0800
committerGravatar GitHub <[email protected]>2020-06-16 20:21:29 +0800
commit13ab8f1be24239cd9406547e16a90471e9b8652b (patch)
tree14af82dcd11e0cf2a434c71cd48a5395fbcd51ff
parentRevise inaccurate docstring in RedisCache (diff)
parentLog exception info for failed attachment uploads (diff)
Merge branch 'master' into help_channel_rediscache
-rw-r--r--.github/workflows/codeql-analysis.yml32
-rw-r--r--bot/cogs/filtering.py61
-rw-r--r--bot/cogs/help_channels.py2
-rw-r--r--bot/cogs/moderation/modlog.py4
-rw-r--r--bot/cogs/moderation/silence.py37
-rw-r--r--bot/cogs/site.py2
-rw-r--r--bot/cogs/stats.py3
-rw-r--r--bot/cogs/token_remover.py119
-rw-r--r--bot/cogs/utils.py4
-rw-r--r--bot/constants.py1
-rw-r--r--bot/converters.py5
-rw-r--r--bot/utils/__init__.py5
-rw-r--r--bot/utils/messages.py2
-rw-r--r--bot/utils/redis_cache.py11
-rw-r--r--config-default.yml3
-rw-r--r--docker-compose.yml9
-rw-r--r--tests/bot/cogs/moderation/test_silence.py18
-rw-r--r--tests/bot/cogs/test_token_remover.py367
-rw-r--r--tests/bot/test_converters.py113
-rw-r--r--tests/bot/utils/test_redis_cache.py10
-rw-r--r--tests/helpers.py20
21 files changed, 585 insertions, 243 deletions
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
new file mode 100644
index 000000000..8760b35ec
--- /dev/null
+++ b/.github/workflows/codeql-analysis.yml
@@ -0,0 +1,32 @@
+name: "Code scanning - action"
+
+on:
+ push:
+ pull_request:
+ schedule:
+ - cron: '0 12 * * *'
+
+jobs:
+ CodeQL-Build:
+
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v2
+ with:
+ fetch-depth: 2
+
+ - run: git checkout HEAD^2
+ if: ${{ github.event_name == 'pull_request' }}
+
+ - name: Initialize CodeQL
+ uses: github/codeql-action/init@v1
+ with:
+ languages: python
+
+ - name: Autobuild
+ uses: github/codeql-action/autobuild@v1
+
+ - name: Perform CodeQL Analysis
+ uses: github/codeql-action/analyze@v1
diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py
index 1d9fddb12..4ebc831e1 100644
--- a/bot/cogs/filtering.py
+++ b/bot/cogs/filtering.py
@@ -1,6 +1,8 @@
+import asyncio
import logging
import re
-from typing import Optional, Union
+from datetime import datetime, timedelta
+from typing import List, Optional, Union
import discord.errors
from dateutil.relativedelta import relativedelta
@@ -14,6 +16,7 @@ from bot.constants import (
Channels, Colours,
Filter, Icons, URLs
)
+from bot.utils.redis_cache import RedisCache
log = logging.getLogger(__name__)
@@ -40,6 +43,8 @@ TOKEN_WATCHLIST_PATTERNS = [
]
WATCHLIST_PATTERNS = WORD_WATCHLIST_PATTERNS + TOKEN_WATCHLIST_PATTERNS
+DAYS_BETWEEN_ALERTS = 3
+
def expand_spoilers(text: str) -> str:
"""Return a string containing all interpretations of a spoilered message."""
@@ -52,8 +57,12 @@ def expand_spoilers(text: str) -> str:
class Filtering(Cog):
"""Filtering out invites, blacklisting domains, and warning us of certain regular expressions."""
+ # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent
+ name_alerts = RedisCache()
+
def __init__(self, bot: Bot):
self.bot = bot
+ self.name_lock = asyncio.Lock()
staff_mistake_str = "If you believe this was a mistake, please let staff know!"
self.filters = {
@@ -112,6 +121,7 @@ class Filtering(Cog):
async def on_message(self, msg: Message) -> None:
"""Invoke message filter for new messages."""
await self._filter_message(msg)
+ await self.check_bad_words_in_name(msg.author)
@Cog.listener()
async def on_message_edit(self, before: Message, after: Message) -> None:
@@ -126,6 +136,55 @@ class Filtering(Cog):
delta = relativedelta(after.edited_at, before.edited_at).microseconds
await self._filter_message(after, delta)
+ @staticmethod
+ def get_name_matches(name: str) -> List[re.Match]:
+ """Check bad words from passed string (name). Return list of matches."""
+ matches = []
+ for pattern in WATCHLIST_PATTERNS:
+ if match := pattern.search(name):
+ matches.append(match)
+ return matches
+
+ async def check_send_alert(self, member: Member) -> bool:
+ """When there is less than 3 days after last alert, return `False`, otherwise `True`."""
+ if last_alert := await self.name_alerts.get(member.id):
+ last_alert = datetime.utcfromtimestamp(last_alert)
+ if datetime.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert:
+ log.trace(f"Last alert was too recent for {member}'s nickname.")
+ return False
+
+ return True
+
+ async def check_bad_words_in_name(self, member: Member) -> None:
+ """Send a mod alert every 3 days if a username still matches a watchlist pattern."""
+ # Use lock to avoid race conditions
+ async with self.name_lock:
+ # Check whether the users display name contains any words in our blacklist
+ matches = self.get_name_matches(member.display_name)
+
+ if not matches or not await self.check_send_alert(member):
+ return
+
+ log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).")
+
+ log_string = (
+ f"**User:** {member.mention} (`{member.id}`)\n"
+ f"**Display Name:** {member.display_name}\n"
+ f"**Bad Matches:** {', '.join(match.group() for match in matches)}"
+ )
+
+ await self.mod_log.send_log_message(
+ icon_url=Icons.token_removed,
+ colour=Colours.soft_red,
+ title="Username filtering alert",
+ text=log_string,
+ channel_id=Channels.mod_alerts,
+ thumbnail=member.avatar_url
+ )
+
+ # Update time when alert sent
+ await self.name_alerts.set(member.id, datetime.utcnow().timestamp())
+
async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None:
"""Filter the input message to see if it violates any of our rules, and then respond accordingly."""
# Should we filter this message?
diff --git a/bot/cogs/help_channels.py b/bot/cogs/help_channels.py
index 4c464a7d2..187adfe51 100644
--- a/bot/cogs/help_channels.py
+++ b/bot/cogs/help_channels.py
@@ -22,7 +22,7 @@ log = logging.getLogger(__name__)
ASKING_GUIDE_URL = "https://pythondiscord.com/pages/asking-good-questions/"
MAX_CHANNELS_PER_CATEGORY = 50
-EXCLUDED_CHANNELS = (constants.Channels.how_to_get_help,)
+EXCLUDED_CHANNELS = (constants.Channels.how_to_get_help, constants.Channels.cooldown)
HELP_CHANNEL_TOPIC = """
This is a Python help channel. You can claim your own help channel in the Python Help: Available category.
diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py
index 9d28030d9..41472c64c 100644
--- a/bot/cogs/moderation/modlog.py
+++ b/bot/cogs/moderation/modlog.py
@@ -555,6 +555,10 @@ class ModLog(Cog, name="ModLog"):
channel = message.channel
author = message.author
+ # Ignore DMs.
+ if not message.guild:
+ return
+
if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist:
return
diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py
index 25febfa51..c8ab6443b 100644
--- a/bot/cogs/moderation/silence.py
+++ b/bot/cogs/moderation/silence.py
@@ -1,7 +1,7 @@
import asyncio
import logging
from contextlib import suppress
-from typing import Optional
+from typing import NamedTuple, Optional
from discord import TextChannel
from discord.ext import commands, tasks
@@ -11,10 +11,18 @@ from bot.bot import Bot
from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles
from bot.converters import HushDurationConverter
from bot.utils.checks import with_role_check
+from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
+class TaskData(NamedTuple):
+ """Data for a scheduled task."""
+
+ delay: int
+ ctx: Context
+
+
class SilenceNotifier(tasks.Loop):
"""Loop notifier for posting notices to `alert_channel` containing added channels."""
@@ -53,15 +61,25 @@ class SilenceNotifier(tasks.Loop):
await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}")
-class Silence(commands.Cog):
+class Silence(Scheduler, commands.Cog):
"""Commands for stopping channel messages for `verified` role in a channel."""
def __init__(self, bot: Bot):
+ super().__init__()
self.bot = bot
self.muted_channels = set()
self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars())
self._get_instance_vars_event = asyncio.Event()
+ async def _scheduled_task(self, task: TaskData) -> None:
+ """Calls `self.unsilence` on expired silenced channel to unsilence it."""
+ await asyncio.sleep(task.delay)
+ log.info("Unsilencing channel after set delay.")
+
+ # Because `self.unsilence` explicitly cancels this scheduled task, it is shielded
+ # to avoid prematurely cancelling itself
+ await asyncio.shield(task.ctx.invoke(self.unsilence))
+
async def _get_instance_vars(self) -> None:
"""Get instance variables after they're available to get from the guild."""
await self.bot.wait_until_guild_available()
@@ -90,9 +108,13 @@ class Silence(commands.Cog):
return
await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).")
- await asyncio.sleep(duration*60)
- log.info("Unsilencing channel after set delay.")
- await ctx.invoke(self.unsilence)
+
+ task_data = TaskData(
+ delay=duration*60,
+ ctx=ctx
+ )
+
+ self.schedule_task(ctx.channel.id, task_data)
@commands.command(aliases=("unhush",))
async def unsilence(self, ctx: Context) -> None:
@@ -103,7 +125,9 @@ class Silence(commands.Cog):
"""
await self._get_instance_vars_event.wait()
log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.")
- if await self._unsilence(ctx.channel):
+ if not await self._unsilence(ctx.channel):
+ await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.")
+ else:
await ctx.send(f"{Emojis.check_mark} unsilenced current channel.")
async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool:
@@ -140,6 +164,7 @@ class Silence(commands.Cog):
if current_overwrite.send_messages is False:
await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None))
log.info(f"Unsilenced channel #{channel} ({channel.id}).")
+ self.cancel_task(channel.id)
self.notifier.remove_channel(channel)
self.muted_channels.discard(channel)
return True
diff --git a/bot/cogs/site.py b/bot/cogs/site.py
index e61cd5003..ac29daa1d 100644
--- a/bot/cogs/site.py
+++ b/bot/cogs/site.py
@@ -33,7 +33,7 @@ class Site(Cog):
embed.colour = Colour.blurple()
embed.description = (
f"[Our official website]({url}) is an open-source community project "
- "created with Python and Flask. It contains information about the server "
+ "created with Python and Django. It contains information about the server "
"itself, lets you sign up for upcoming events, has its own wiki, contains "
"a list of valuable learning resources, and much more."
)
diff --git a/bot/cogs/stats.py b/bot/cogs/stats.py
index 4ebb6423c..d42f55466 100644
--- a/bot/cogs/stats.py
+++ b/bot/cogs/stats.py
@@ -36,7 +36,8 @@ class Stats(Cog):
if message.guild.id != Guild.id:
return
- if message.channel.category.id == Categories.modmail:
+ cat = getattr(message.channel, "category", None)
+ if cat is not None and cat.id == Categories.modmail:
if message.channel.id != Channels.incidents:
# Do not report modmail channels to stats, there are too many
# of them for interesting statistics to be drawn out of this.
diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py
index 6721f0e02..d55e079e9 100644
--- a/bot/cogs/token_remover.py
+++ b/bot/cogs/token_remover.py
@@ -2,20 +2,22 @@ import base64
import binascii
import logging
import re
-import struct
import typing as t
-from datetime import datetime
from discord import Colour, Message
from discord.ext.commands import Cog
-from discord.utils import snowflake_time
+from bot import utils
from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import Channels, Colours, Event, Icons
log = logging.getLogger(__name__)
+LOG_MESSAGE = (
+ "Censored a seemingly valid token sent by {author} (`{author_id}`) in {channel}, "
+ "token was `{user_id}.{timestamp}.{hmac}`"
+)
DELETION_MESSAGE_TEMPLATE = (
"Hey {mention}! I noticed you posted a seemingly valid Discord API "
"token in your message and have removed your message. "
@@ -25,15 +27,22 @@ DELETION_MESSAGE_TEMPLATE = (
"Feel free to re-post it with the token removed. "
"If you believe this was a mistake, please let us know!"
)
-DISCORD_EPOCH_TIMESTAMP = datetime(2017, 1, 1)
+DISCORD_EPOCH = 1_420_070_400
TOKEN_EPOCH = 1_293_840_000
-TOKEN_RE = re.compile(
- r"[^\s\.()\"']+" # Matches token part 1: The user ID string, encoded as base64
- r"\." # Matches a literal dot between the token parts
- r"[^\s\.()\"']+" # Matches token part 2: The creation timestamp, as an integer
- r"\." # Matches a literal dot between the token parts
- r"[^\s\.()\"']+" # Matches token part 3: The HMAC, unused by us, but check that it isn't empty
-)
+
+# Three parts delimited by dots: user ID, creation timestamp, HMAC.
+# The HMAC isn't parsed further, but it's in the regex to ensure it at least exists in the string.
+# Each part only matches base64 URL-safe characters.
+# Padding has never been observed, but the padding character '=' is matched just in case.
+TOKEN_RE = re.compile(r"([\w\-=]+)\.([\w\-=]+)\.([\w\-=]+)", re.ASCII)
+
+
+class Token(t.NamedTuple):
+ """A Discord Bot token."""
+
+ user_id: str
+ timestamp: str
+ hmac: str
class TokenRemover(Cog):
@@ -65,64 +74,58 @@ class TokenRemover(Cog):
See: https://discordapp.com/developers/docs/reference#snowflakes
"""
- found_token = self.find_token_in_message(after)
- if found_token:
- await self.take_action(after, found_token)
+ await self.on_message(after)
- async def take_action(self, msg: Message, found_token: str) -> None:
- """Remove the `msg` containing a token an send a mod_log message."""
- user_id, creation_timestamp, hmac = found_token.split('.')
+ async def take_action(self, msg: Message, found_token: Token) -> None:
+ """Remove the `msg` containing the `found_token` and send a mod log message."""
self.mod_log.ignore(Event.message_delete, msg.id)
await msg.delete()
await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention))
- message = (
- "Censored a seemingly valid token sent by "
- f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was "
- f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`"
- )
- log.debug(message)
+ log_message = self.format_log_message(msg, found_token)
+ log.debug(log_message)
# Send pretty mod log embed to mod-alerts
await self.mod_log.send_log_message(
icon_url=Icons.token_removed,
colour=Colour(Colours.soft_red),
title="Token removed!",
- text=message,
+ text=log_message,
thumbnail=msg.author.avatar_url_as(static_format="png"),
channel_id=Channels.mod_alerts,
)
self.bot.stats.incr("tokens.removed_tokens")
+ @staticmethod
+ def format_log_message(msg: Message, token: Token) -> str:
+ """Return the log message to send for `token` being censored in `msg`."""
+ return LOG_MESSAGE.format(
+ author=msg.author,
+ author_id=msg.author.id,
+ channel=msg.channel.mention,
+ user_id=token.user_id,
+ timestamp=token.timestamp,
+ hmac='x' * len(token.hmac),
+ )
+
@classmethod
- def find_token_in_message(cls, msg: Message) -> t.Optional[str]:
+ def find_token_in_message(cls, msg: Message) -> t.Optional[Token]:
"""Return a seemingly valid token found in `msg` or `None` if no token is found."""
if msg.author.bot:
return
- # Use findall rather than search to guard against method calls prematurely returning the
+ # Use finditer rather than search to guard against method calls prematurely returning the
# token check (e.g. `message.channel.send` also matches our token pattern)
- maybe_matches = TOKEN_RE.findall(msg.content)
- for substr in maybe_matches:
- if cls.is_maybe_token(substr):
+ for match in TOKEN_RE.finditer(msg.content):
+ token = Token(*match.groups())
+ if cls.is_valid_user_id(token.user_id) and cls.is_valid_timestamp(token.timestamp):
# Short-circuit on first match
- return substr
+ return token
# No matching substring
return
- @classmethod
- def is_maybe_token(cls, test_str: str) -> bool:
- """Check the provided string to see if it is a seemingly valid token."""
- try:
- user_id, creation_timestamp, hmac = test_str.split('.')
- except ValueError:
- return False
-
- if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp):
- return True
-
@staticmethod
def is_valid_user_id(b64_content: str) -> bool:
"""
@@ -130,29 +133,41 @@ class TokenRemover(Cog):
See: https://discordapp.com/developers/docs/reference#snowflakes
"""
- b64_content += '=' * (-len(b64_content) % 4)
+ b64_content = utils.pad_base64(b64_content)
try:
- content: bytes = base64.b64decode(b64_content)
- return content.decode('utf-8').isnumeric()
- except (binascii.Error, UnicodeDecodeError):
+ decoded_bytes = base64.urlsafe_b64decode(b64_content)
+ string = decoded_bytes.decode('utf-8')
+
+ # isdigit on its own would match a lot of other Unicode characters, hence the isascii.
+ return string.isascii() and string.isdigit()
+ except (binascii.Error, ValueError):
return False
@staticmethod
def is_valid_timestamp(b64_content: str) -> bool:
"""
- Check potential token to see if it contains a valid timestamp.
+ Return True if `b64_content` decodes to a valid timestamp.
- See: https://discordapp.com/developers/docs/reference#snowflakes
+ If the timestamp is greater than the Discord epoch, it's probably valid.
+ See: https://i.imgur.com/7WdehGn.png
"""
- b64_content += '=' * (-len(b64_content) % 4)
+ b64_content = utils.pad_base64(b64_content)
try:
- content = base64.urlsafe_b64decode(b64_content)
- snowflake = struct.unpack('i', content)[0]
- except (binascii.Error, struct.error):
+ decoded_bytes = base64.urlsafe_b64decode(b64_content)
+ timestamp = int.from_bytes(decoded_bytes, byteorder="big")
+ except (binascii.Error, ValueError) as e:
+ log.debug(f"Failed to decode token timestamp '{b64_content}': {e}")
+ return False
+
+ # Seems like newer tokens don't need the epoch added, but add anyway since an upper bound
+ # is not checked.
+ if timestamp + TOKEN_EPOCH >= DISCORD_EPOCH:
+ return True
+ else:
+ log.debug(f"Invalid token timestamp '{b64_content}': smaller than Discord epoch")
return False
- return snowflake_time(snowflake + TOKEN_EPOCH) < DISCORD_EPOCH_TIMESTAMP
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py
index 73b4a1c0a..697bf60ce 100644
--- a/bot/cogs/utils.py
+++ b/bot/cogs/utils.py
@@ -6,7 +6,7 @@ from email.parser import HeaderParser
from io import StringIO
from typing import Tuple, Union
-from discord import Colour, Embed
+from discord import Colour, Embed, utils
from discord.ext.commands import BadArgument, Cog, Context, command
from bot.bot import Bot
@@ -145,7 +145,7 @@ class Utils(Cog):
u_code = f"\\U{digit:>08}"
url = f"https://www.compart.com/en/unicode/U+{digit:>04}"
name = f"[{unicodedata.name(char, '')}]({url})"
- info = f"`{u_code.ljust(10)}`: {name} - {char}"
+ info = f"`{u_code.ljust(10)}`: {name} - {utils.escape_markdown(char)}"
return info, u_code
charlist, rawlist = zip(*(get_info(c) for c in characters))
diff --git a/bot/constants.py b/bot/constants.py
index b31a9c99e..470221369 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -389,6 +389,7 @@ class Channels(metaclass=YAMLGetter):
attachment_log: int
big_brother_logs: int
bot_commands: int
+ cooldown: int
defcon: int
dev_contrib: int
dev_core: int
diff --git a/bot/converters.py b/bot/converters.py
index 72c46fdf0..4deb59f87 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -217,7 +217,10 @@ class Duration(Converter):
delta = relativedelta(**duration_dict)
now = datetime.utcnow()
- return now + delta
+ try:
+ return now + delta
+ except ValueError:
+ raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
class ISODateTime(Converter):
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index c5a12d5e3..5a6e1811b 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -11,3 +11,8 @@ class CogABCMeta(CogMeta, ABCMeta):
"""Metaclass for ABCs meant to be implemented as Cogs."""
pass
+
+
+def pad_base64(data: str) -> str:
+ """Return base64 `data` with padding characters to ensure its length is a multiple of 4."""
+ return data + "=" * (-len(data) % 4)
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index de8e186f3..23519a514 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -97,7 +97,7 @@ async def send_attachments(
if link_large and e.status == 413:
large.append(attachment)
else:
- log.warning(f"{failure_msg} with status {e.status}.")
+ log.warning(f"{failure_msg} with status {e.status}.", exc_info=e)
if link_large and large:
desc = "\n".join(f"[{attachment.filename}]({attachment.url})" for attachment in large)
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
index f342bbb62..58cfe1df5 100644
--- a/bot/utils/redis_cache.py
+++ b/bot/utils/redis_cache.py
@@ -101,16 +101,7 @@ class RedisCache:
def _set_namespace(self, namespace: str) -> None:
"""Try to set the namespace, but do not permit collisions."""
- # We need a unique namespace, to prevent collisions. This loop
- # will try appending underscores to the end of the namespace until
- # it finds one that is unique.
- #
- # For example, if `john` and `john_` are both taken, the namespace will
- # be `john__` at the end of this loop.
- while namespace in self._namespaces:
- namespace += "_"
-
- log.trace(f"RedisCache setting namespace to {self._namespace}")
+ log.trace(f"RedisCache setting namespace to {namespace}")
self._namespaces.append(namespace)
self._namespace = namespace
diff --git a/config-default.yml b/config-default.yml
index 2c85f5ef3..aff5fb2e1 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -142,6 +142,7 @@ guild:
# Python Help: Available
how_to_get_help: 704250143020417084
+ cooldown: 720603994149486673
# Logs
attachment_log: &ATTACH_LOG 649243850006855680
@@ -297,6 +298,8 @@ filter:
- 613425648685547541 # Discord Developers
- 185590609631903755 # Blender Hub
- 420324994703163402 # /r/FlutterDev
+ - 488751051629920277 # Python Atlanta
+ - 143867839282020352 # C#
domain_blacklist:
- pornhub.com
diff --git a/docker-compose.yml b/docker-compose.yml
index 9884e35f0..cff7d33d6 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -17,6 +17,14 @@ services:
ports:
- "127.0.0.1:6379:6379"
+ snekbox:
+ image: pythondiscord/snekbox:latest
+ init: true
+ ipc: none
+ ports:
+ - "127.0.0.1:8060:8060"
+ privileged: true
+
web:
image: pythondiscord/site:latest
command: ["run", "--debug"]
@@ -47,6 +55,7 @@ services:
depends_on:
- web
- redis
+ - snekbox
environment:
BOT_TOKEN: ${BOT_TOKEN}
BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531
diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py
index 3fd149f04..ab3d0742a 100644
--- a/tests/bot/cogs/moderation/test_silence.py
+++ b/tests/bot/cogs/moderation/test_silence.py
@@ -127,10 +127,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
self.ctx.reset_mock()
async def test_unsilence_sent_correct_discord_message(self):
- """Proper reply after a successful unsilence."""
- with mock.patch.object(self.cog, "_unsilence", return_value=True):
- await self.cog.unsilence.callback(self.cog, self.ctx)
- self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.")
+ """Check if proper message was sent when unsilencing channel."""
+ test_cases = (
+ (True, f"{Emojis.check_mark} unsilenced current channel."),
+ (False, f"{Emojis.cross_mark} current channel was not silenced.")
+ )
+ for _unsilence_patch_return, result_message in test_cases:
+ with self.subTest(
+ starting_silenced_state=_unsilence_patch_return,
+ result_message=result_message
+ ):
+ with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return):
+ await self.cog.unsilence.callback(self.cog, self.ctx)
+ self.ctx.send.assert_called_once_with(result_message)
+ self.ctx.reset_mock()
async def test_silence_private_for_false(self):
"""Permissions are not set and `False` is returned in an already silenced channel."""
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index 33d1ec170..a10124d2d 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -1,56 +1,89 @@
-import asyncio
-import logging
import unittest
-from unittest.mock import AsyncMock, MagicMock
+from re import Match
+from unittest import mock
+from unittest.mock import MagicMock
from discord import Colour
-from bot.cogs.token_remover import (
- DELETION_MESSAGE_TEMPLATE,
- TokenRemover,
- setup as setup_cog,
-)
-from bot.constants import Channels, Colours, Event, Icons
-from tests.helpers import MockBot, MockMessage
+from bot import constants
+from bot.cogs import token_remover
+from bot.cogs.moderation import ModLog
+from bot.cogs.token_remover import Token, TokenRemover
+from tests.helpers import MockBot, MockMessage, autospec
-class TokenRemoverTests(unittest.TestCase):
+class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
"""Tests the `TokenRemover` cog."""
def setUp(self):
"""Adds the cog, a bot, and a message to the instance for usage in tests."""
self.bot = MockBot()
- self.bot.get_cog.return_value = MagicMock()
- self.bot.get_cog.return_value.send_log_message = AsyncMock()
self.cog = TokenRemover(bot=self.bot)
- self.msg = MockMessage(id=555, content='')
- self.msg.author.__str__ = MagicMock()
- self.msg.author.__str__.return_value = 'lemon'
- self.msg.author.bot = False
- self.msg.author.avatar_url_as.return_value = 'picture-lemon.png'
- self.msg.author.id = 42
- self.msg.author.mention = '@lemon'
+ self.msg = MockMessage(id=555, content="hello world")
self.msg.channel.mention = "#lemonade-stand"
+ self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
+ self.msg.author.avatar_url_as.return_value = "picture-lemon.png"
- def test_is_valid_user_id_is_true_for_numeric_content(self):
- """A string decoding to numeric characters is a valid user ID."""
- # MTIz = base64(123)
- self.assertTrue(TokenRemover.is_valid_user_id('MTIz'))
+ def test_is_valid_user_id_valid(self):
+ """Should consider user IDs valid if they decode entirely to ASCII digits."""
+ ids = (
+ "NDcyMjY1OTQzMDYyNDEzMzMy",
+ "NDc1MDczNjI5Mzk5NTQ3OTA0",
+ "NDY3MjIzMjMwNjUwNzc3NjQx",
+ )
+
+ for user_id in ids:
+ with self.subTest(user_id=user_id):
+ result = TokenRemover.is_valid_user_id(user_id)
+ self.assertTrue(result)
- def test_is_valid_user_id_is_false_for_alphabetic_content(self):
- """A string decoding to alphabetic characters is not a valid user ID."""
- # YWJj = base64(abc)
- self.assertFalse(TokenRemover.is_valid_user_id('YWJj'))
+ def test_is_valid_user_id_invalid(self):
+ """Should consider non-digit and non-ASCII IDs invalid."""
+ ids = (
+ ("SGVsbG8gd29ybGQ", "non-digit ASCII"),
+ ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"),
+ ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"),
+ ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"),
+ ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
- def test_is_valid_timestamp_is_true_for_valid_timestamps(self):
- """A string decoding to a valid timestamp should be recognized as such."""
- self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A'))
+ for user_id, msg in ids:
+ with self.subTest(msg=msg):
+ result = TokenRemover.is_valid_user_id(user_id)
+ self.assertFalse(result)
- def test_is_valid_timestamp_is_false_for_invalid_values(self):
- """A string not decoding to a valid timestamp should not be recognized as such."""
- # MTIz = base64(123)
- self.assertFalse(TokenRemover.is_valid_timestamp('MTIz'))
+ def test_is_valid_timestamp_valid(self):
+ """Should consider timestamps valid if they're greater than the Discord epoch."""
+ timestamps = (
+ "XsyRkw",
+ "Xrim9Q",
+ "XsyR-w",
+ "XsySD_",
+ "Dn9r_A",
+ )
+
+ for timestamp in timestamps:
+ with self.subTest(timestamp=timestamp):
+ result = TokenRemover.is_valid_timestamp(timestamp)
+ self.assertTrue(result)
+
+ def test_is_valid_timestamp_invalid(self):
+ """Should consider timestamps invalid if they're before Discord epoch or can't be parsed."""
+ timestamps = (
+ ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"),
+ ("ew", "123"),
+ ("AoIKgA", "42076800"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
+
+ for timestamp, msg in timestamps:
+ with self.subTest(msg=msg):
+ result = TokenRemover.is_valid_timestamp(timestamp)
+ self.assertFalse(result)
def test_mod_log_property(self):
"""The `mod_log` property should ask the bot to return the `ModLog` cog."""
@@ -58,74 +91,206 @@ class TokenRemoverTests(unittest.TestCase):
self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value)
self.bot.get_cog.assert_called_once_with('ModLog')
- def test_ignores_bot_messages(self):
- """When the message event handler is called with a bot message, nothing is done."""
+ async def test_on_message_edit_uses_on_message(self):
+ """The edit listener should delegate handling of the message to the normal listener."""
+ self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True)
+
+ await self.cog.on_message_edit(MockMessage(), self.msg)
+ self.cog.on_message.assert_awaited_once_with(self.msg)
+
+ @autospec(TokenRemover, "find_token_in_message", "take_action")
+ async def test_on_message_takes_action(self, find_token_in_message, take_action):
+ """Should take action if a valid token is found when a message is sent."""
+ cog = TokenRemover(self.bot)
+ found_token = "foobar"
+ find_token_in_message.return_value = found_token
+
+ await cog.on_message(self.msg)
+
+ find_token_in_message.assert_called_once_with(self.msg)
+ take_action.assert_awaited_once_with(cog, self.msg, found_token)
+
+ @autospec(TokenRemover, "find_token_in_message", "take_action")
+ async def test_on_message_skips_missing_token(self, find_token_in_message, take_action):
+ """Shouldn't take action if a valid token isn't found when a message is sent."""
+ cog = TokenRemover(self.bot)
+ find_token_in_message.return_value = False
+
+ await cog.on_message(self.msg)
+
+ find_token_in_message.assert_called_once_with(self.msg)
+ take_action.assert_not_awaited()
+
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_ignores_bot_messages(self, token_re):
+ """The token finder should ignore messages authored by bots."""
self.msg.author.bot = True
- coroutine = self.cog.on_message(self.msg)
- self.assertIsNone(asyncio.run(coroutine))
-
- def test_ignores_messages_without_tokens(self):
- """Messages without anything looking like a token are ignored."""
- for content in ('', 'lemon wins'):
- with self.subTest(content=content):
- self.msg.content = content
- coroutine = self.cog.on_message(self.msg)
- self.assertIsNone(asyncio.run(coroutine))
-
- def test_ignores_messages_with_invalid_tokens(self):
- """Messages with values that are invalid tokens are ignored."""
- for content in ('foo.bar.baz', 'x.y.'):
- with self.subTest(content=content):
- self.msg.content = content
- coroutine = self.cog.on_message(self.msg)
- self.assertIsNone(asyncio.run(coroutine))
-
- def test_censors_valid_tokens(self):
- """Valid tokens are censored."""
- cases = (
- # (content, censored_token)
- ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'),
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+ token_re.finditer.assert_not_called()
+
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_no_matches(self, token_re):
+ """None should be returned if the regex matches no tokens in a message."""
+ token_re.finditer.return_value = ()
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+ token_re.finditer.assert_called_once_with(self.msg.content)
+
+ @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp")
+ @autospec("bot.cogs.token_remover", "Token")
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp):
+ """The first match with a valid user ID and timestamp should be returned as a `Token`."""
+ matches = [
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ ]
+ tokens = [
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ ]
+
+ token_re.finditer.return_value = matches
+ token_cls.side_effect = tokens
+ is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid.
+ is_valid_timestamp.return_value = True
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertEqual(tokens[1], return_value)
+ token_re.finditer.assert_called_once_with(self.msg.content)
+
+ @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp")
+ @autospec("bot.cogs.token_remover", "Token")
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp):
+ """None should be returned if no matches have valid user IDs or timestamps."""
+ token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
+ token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
+ is_valid_id.return_value = False
+ is_valid_timestamp.return_value = False
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+ token_re.finditer.assert_called_once_with(self.msg.content)
+
+ def test_regex_invalid_tokens(self):
+ """Messages without anything looking like a token are not matched."""
+ tokens = (
+ "",
+ "lemon wins",
+ "..",
+ "x.y",
+ "x.y.",
+ ".y.z",
+ ".y.",
+ "..z",
+ "x..z",
+ " . . ",
+ "\n.\n.\n",
+ "hellö.world.bye",
+ "base64.nötbåse64.morebase64",
+ "19jd3J.dfkm3d.€víł§tüff",
+ )
+
+ for token in tokens:
+ with self.subTest(token=token):
+ results = token_remover.TOKEN_RE.findall(token)
+ self.assertEqual(len(results), 0)
+
+ def test_regex_valid_tokens(self):
+ """Messages that look like tokens should be matched."""
+ # Don't worry, these tokens have been invalidated.
+ tokens = (
+ "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8",
+ "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds",
+ "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",
)
- for content, censored_token in cases:
- with self.subTest(content=content, censored_token=censored_token):
- self.msg.content = content
- coroutine = self.cog.on_message(self.msg)
- with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm:
- self.assertIsNone(asyncio.run(coroutine)) # no return value
-
- [line] = cm.output
- log_message = (
- "Censored a seemingly valid token sent by "
- "lemon (`42`) in #lemonade-stand, "
- f"token was `{censored_token}`"
- )
- self.assertIn(log_message, line)
-
- self.msg.delete.assert_called_once_with()
- self.msg.channel.send.assert_called_once_with(
- DELETION_MESSAGE_TEMPLATE.format(mention='@lemon')
- )
- self.bot.get_cog.assert_called_with('ModLog')
- self.msg.author.avatar_url_as.assert_called_once_with(static_format='png')
-
- mod_log = self.bot.get_cog.return_value
- mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id)
- mod_log.send_log_message.assert_called_once_with(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Token removed!",
- text=log_message,
- thumbnail='picture-lemon.png',
- channel_id=Channels.mod_alerts
- )
-
-
-class TokenRemoverSetupTests(unittest.TestCase):
- """Tests setup of the `TokenRemover` cog."""
-
- def test_setup(self):
- """Setup of the extension should call add_cog."""
+ for token in tokens:
+ with self.subTest(token=token):
+ results = token_remover.TOKEN_RE.fullmatch(token)
+ self.assertIsNotNone(results, f"{token} was not matched by the regex")
+
+ def test_regex_matches_multiple_valid(self):
+ """Should support multiple matches in the middle of a string."""
+ token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8"
+ token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc"
+ message = f"garbage {token_1} hello {token_2} world"
+
+ results = token_remover.TOKEN_RE.finditer(message)
+ results = [match[0] for match in results]
+ self.assertCountEqual((token_1, token_2), results)
+
+ @autospec("bot.cogs.token_remover", "LOG_MESSAGE")
+ def test_format_log_message(self, log_message):
+ """Should correctly format the log message with info from the message and token."""
+ token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ log_message.format.return_value = "Howdy"
+
+ return_value = TokenRemover.format_log_message(self.msg, token)
+
+ self.assertEqual(return_value, log_message.format.return_value)
+ log_message.format.assert_called_once_with(
+ author=self.msg.author,
+ author_id=self.msg.author.id,
+ channel=self.msg.channel.mention,
+ user_id=token.user_id,
+ timestamp=token.timestamp,
+ hmac="x" * len(token.hmac),
+ )
+
+ @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
+ @autospec("bot.cogs.token_remover", "log")
+ @autospec(TokenRemover, "format_log_message")
+ async def test_take_action(self, format_log_message, logger, mod_log_property):
+ """Should delete the message and send a mod log."""
+ cog = TokenRemover(self.bot)
+ mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True)
+ token = mock.create_autospec(Token, spec_set=True, instance=True)
+ log_msg = "testing123"
+
+ mod_log_property.return_value = mod_log
+ format_log_message.return_value = log_msg
+
+ await cog.take_action(self.msg, token)
+
+ self.msg.delete.assert_called_once_with()
+ self.msg.channel.send.assert_called_once_with(
+ token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention)
+ )
+
+ format_log_message.assert_called_once_with(self.msg, token)
+ logger.debug.assert_called_with(log_msg)
+ self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens")
+
+ mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id)
+ mod_log.send_log_message.assert_called_once_with(
+ icon_url=constants.Icons.token_removed,
+ colour=Colour(constants.Colours.soft_red),
+ title="Token removed!",
+ text=log_msg,
+ thumbnail=self.msg.author.avatar_url_as.return_value,
+ channel_id=constants.Channels.mod_alerts
+ )
+
+
+class TokenRemoverExtensionTests(unittest.TestCase):
+ """Tests for the token_remover extension."""
+
+ @autospec("bot.cogs.token_remover", "TokenRemover")
+ def test_extension_setup(self, cog):
+ """The TokenRemover cog should be added."""
bot = MockBot()
- setup_cog(bot)
+ token_remover.setup(bot)
+
+ cog.assert_called_once_with(bot)
bot.add_cog.assert_called_once()
+ self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover))
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index ca8cb6825..c42111f3f 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -1,5 +1,5 @@
-import asyncio
import datetime
+import re
import unittest
from unittest.mock import MagicMock, patch
@@ -16,7 +16,7 @@ from bot.converters import (
)
-class ConverterTests(unittest.TestCase):
+class ConverterTests(unittest.IsolatedAsyncioTestCase):
"""Tests our custom argument converters."""
@classmethod
@@ -26,7 +26,7 @@ class ConverterTests(unittest.TestCase):
cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
- def test_tag_content_converter_for_valid(self):
+ async def test_tag_content_converter_for_valid(self):
"""TagContentConverter should return correct values for valid input."""
test_values = (
('hello', 'hello'),
@@ -35,10 +35,10 @@ class ConverterTests(unittest.TestCase):
for content, expected_conversion in test_values:
with self.subTest(content=content, expected_conversion=expected_conversion):
- conversion = asyncio.run(TagContentConverter.convert(self.context, content))
+ conversion = await TagContentConverter.convert(self.context, content)
self.assertEqual(conversion, expected_conversion)
- def test_tag_content_converter_for_invalid(self):
+ async def test_tag_content_converter_for_invalid(self):
"""TagContentConverter should raise the proper exception for invalid input."""
test_values = (
('', "Tag contents should not be empty, or filled with whitespace."),
@@ -47,10 +47,10 @@ class ConverterTests(unittest.TestCase):
for value, exception_message in test_values:
with self.subTest(tag_content=value, exception_message=exception_message):
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(TagContentConverter.convert(self.context, value))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await TagContentConverter.convert(self.context, value)
- def test_tag_name_converter_for_valid(self):
+ async def test_tag_name_converter_for_valid(self):
"""TagNameConverter should return the correct values for valid tag names."""
test_values = (
('tracebacks', 'tracebacks'),
@@ -60,10 +60,10 @@ class ConverterTests(unittest.TestCase):
for name, expected_conversion in test_values:
with self.subTest(name=name, expected_conversion=expected_conversion):
- conversion = asyncio.run(TagNameConverter.convert(self.context, name))
+ conversion = await TagNameConverter.convert(self.context, name)
self.assertEqual(conversion, expected_conversion)
- def test_tag_name_converter_for_invalid(self):
+ async def test_tag_name_converter_for_invalid(self):
"""TagNameConverter should raise the correct exception for invalid tag names."""
test_values = (
('👋', "Don't be ridiculous, you can't use that character!"),
@@ -75,29 +75,29 @@ class ConverterTests(unittest.TestCase):
for invalid_name, exception_message in test_values:
with self.subTest(invalid_name=invalid_name, exception_message=exception_message):
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(TagNameConverter.convert(self.context, invalid_name))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await TagNameConverter.convert(self.context, invalid_name)
- def test_valid_python_identifier_for_valid(self):
+ async def test_valid_python_identifier_for_valid(self):
"""ValidPythonIdentifier returns valid identifiers unchanged."""
test_values = ('foo', 'lemon')
for name in test_values:
with self.subTest(identifier=name):
- conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ conversion = await ValidPythonIdentifier.convert(self.context, name)
self.assertEqual(name, conversion)
- def test_valid_python_identifier_for_invalid(self):
+ async def test_valid_python_identifier_for_invalid(self):
"""ValidPythonIdentifier raises the proper exception for invalid identifiers."""
test_values = ('nested.stuff', '#####')
for name in test_values:
with self.subTest(identifier=name):
exception_message = f'`{name}` is not a valid Python identifier'
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await ValidPythonIdentifier.convert(self.context, name)
- def test_duration_converter_for_valid(self):
+ async def test_duration_converter_for_valid(self):
"""Duration returns the correct `datetime` for valid duration strings."""
test_values = (
# Simple duration strings
@@ -159,35 +159,35 @@ class ConverterTests(unittest.TestCase):
mock_datetime.utcnow.return_value = self.fixed_utc_now
with self.subTest(duration=duration, duration_dict=duration_dict):
- converted_datetime = asyncio.run(converter.convert(self.context, duration))
+ converted_datetime = await converter.convert(self.context, duration)
self.assertEqual(converted_datetime, expected_datetime)
- def test_duration_converter_for_invalid(self):
+ async def test_duration_converter_for_invalid(self):
"""Duration raises the right exception for invalid duration strings."""
test_values = (
# Units in wrong order
- ('1d1w'),
- ('1s1y'),
+ '1d1w',
+ '1s1y',
# Duplicated units
- ('1 year 2 years'),
- ('1 M 10 minutes'),
+ '1 year 2 years',
+ '1 M 10 minutes',
# Unknown substrings
- ('1MVes'),
- ('1y3breads'),
+ '1MVes',
+ '1y3breads',
# Missing amount
- ('ym'),
+ 'ym',
# Incorrect whitespace
- (" 1y"),
- ("1S "),
- ("1y 1m"),
+ " 1y",
+ "1S ",
+ "1y 1m",
# Garbage
- ('Guido van Rossum'),
- ('lemon lemon lemon lemon lemon lemon lemon'),
+ 'Guido van Rossum',
+ 'lemon lemon lemon lemon lemon lemon lemon',
)
converter = Duration()
@@ -195,10 +195,21 @@ class ConverterTests(unittest.TestCase):
for invalid_duration in test_values:
with self.subTest(invalid_duration=invalid_duration):
exception_message = f'`{invalid_duration}` is not a valid duration string.'
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(converter.convert(self.context, invalid_duration))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, invalid_duration)
- def test_isodatetime_converter_for_valid(self):
+ @patch("bot.converters.datetime")
+ async def test_duration_converter_out_of_range(self, mock_datetime):
+ """Duration converter should raise BadArgument if datetime raises a ValueError."""
+ mock_datetime.__add__.side_effect = ValueError
+ mock_datetime.utcnow.return_value = mock_datetime
+
+ duration = f"{datetime.MAXYEAR}y"
+ exception_message = f"`{duration}` results in a datetime outside the supported range."
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await Duration().convert(self.context, duration)
+
+ async def test_isodatetime_converter_for_valid(self):
"""ISODateTime converter returns correct datetime for valid datetime string."""
test_values = (
# `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
@@ -243,37 +254,37 @@ class ConverterTests(unittest.TestCase):
for datetime_string, expected_dt in test_values:
with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt):
- converted_dt = asyncio.run(converter.convert(self.context, datetime_string))
+ converted_dt = await converter.convert(self.context, datetime_string)
self.assertIsNone(converted_dt.tzinfo)
self.assertEqual(converted_dt, expected_dt)
- def test_isodatetime_converter_for_invalid(self):
+ async def test_isodatetime_converter_for_invalid(self):
"""ISODateTime converter raises the correct exception for invalid datetime strings."""
test_values = (
# Make sure it doesn't interfere with the Duration converter
- ('1Y'),
- ('1d'),
- ('1H'),
+ '1Y',
+ '1d',
+ '1H',
# Check if it fails when only providing the optional time part
- ('10:10:10'),
- ('10:00'),
+ '10:10:10',
+ '10:00',
# Invalid date format
- ('19-01-01'),
+ '19-01-01',
# Other non-valid strings
- ('fisk the tag master'),
+ 'fisk the tag master',
)
converter = ISODateTime()
for datetime_string in test_values:
with self.subTest(datetime_string=datetime_string):
exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string"
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(converter.convert(self.context, datetime_string))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, datetime_string)
- def test_hush_duration_converter_for_valid(self):
+ async def test_hush_duration_converter_for_valid(self):
"""HushDurationConverter returns correct value for minutes duration or `"forever"` strings."""
test_values = (
("0", 0),
@@ -286,10 +297,10 @@ class ConverterTests(unittest.TestCase):
converter = HushDurationConverter()
for minutes_string, expected_minutes in test_values:
with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes):
- converted = asyncio.run(converter.convert(self.context, minutes_string))
+ converted = await converter.convert(self.context, minutes_string)
self.assertEqual(expected_minutes, converted)
- def test_hush_duration_converter_for_invalid(self):
+ async def test_hush_duration_converter_for_invalid(self):
"""HushDurationConverter raises correct exception for invalid minutes duration strings."""
test_values = (
("16", "Duration must be at most 15 minutes."),
@@ -299,5 +310,5 @@ class ConverterTests(unittest.TestCase):
converter = HushDurationConverter()
for invalid_minutes_string, exception_message in test_values:
with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message):
- with self.assertRaisesRegex(BadArgument, exception_message):
- asyncio.run(converter.convert(self.context, invalid_minutes_string))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, invalid_minutes_string)
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index 62c411681..a2f0fe55d 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -44,16 +44,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(RuntimeError):
await bad_cache.set("test", "me_up_deadman")
- def test_namespace_collision(self):
- """Test that we prevent colliding namespaces."""
- bob_cache_1 = RedisCache()
- bob_cache_1._set_namespace("BobRoss")
- self.assertEqual(bob_cache_1._namespace, "BobRoss")
-
- bob_cache_2 = RedisCache()
- bob_cache_2._set_namespace("BobRoss")
- self.assertEqual(bob_cache_2._namespace, "BobRoss_")
-
async def test_set_get_item(self):
"""Test that users can set and get items from the RedisDict."""
test_cases = (
diff --git a/tests/helpers.py b/tests/helpers.py
index faa839370..facc4e1af 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -5,7 +5,7 @@ import itertools
import logging
import unittest.mock
from asyncio import AbstractEventLoop
-from typing import Iterable, Optional
+from typing import Callable, Iterable, Optional
import discord
from aiohttp import ClientSession
@@ -26,6 +26,24 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
+def autospec(target, *attributes: str, **kwargs) -> Callable:
+ """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True."""
+ # Caller's kwargs should take priority and overwrite the defaults.
+ kwargs = {'spec_set': True, 'autospec': True, **kwargs}
+
+ # Import the target if it's a string.
+ # This is to support both object and string targets like patch.multiple.
+ if type(target) is str:
+ target = unittest.mock._importer(target)
+
+ def decorator(func):
+ for attribute in attributes:
+ patcher = unittest.mock.patch.object(target, attribute, **kwargs)
+ func = patcher(func)
+ return func
+ return decorator
+
+
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.