aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--create_metricity_db.py16
-rw-r--r--metricity/__init__.py1
-rw-r--r--metricity/__main__.py2
-rw-r--r--metricity/bot.py6
-rw-r--r--metricity/config.py60
-rw-r--r--metricity/database.py8
-rw-r--r--metricity/exts/error_handler.py9
-rw-r--r--metricity/exts/event_listeners/_utils.py4
-rw-r--r--metricity/exts/event_listeners/guild_listeners.py26
-rw-r--r--metricity/exts/event_listeners/member_listeners.py25
-rw-r--r--metricity/exts/status.py4
-rw-r--r--metricity/models.py18
12 files changed, 84 insertions, 95 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py
index edec2db..215ba6e 100644
--- a/create_metricity_db.py
+++ b/create_metricity_db.py
@@ -8,16 +8,14 @@ from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
def parse_db_url(db_url: str) -> SplitResult:
- """Validate and split the given databse url."""
+ """Validate and split the given database url."""
db_url_parts = urlsplit(db_url)
if not all((
db_url_parts.hostname,
db_url_parts.username,
- db_url_parts.password
+ db_url_parts.password,
)):
- raise ValueError(
- "The given db_url is not a valid PostgreSQL database URL."
- )
+ raise ValueError("The given db_url is not a valid PostgreSQL database URL.")
return db_url_parts
@@ -28,7 +26,7 @@ if __name__ == "__main__":
host=database_parts.hostname,
port=database_parts.port,
user=database_parts.username,
- password=database_parts.password
+ password=database_parts.password,
)
db_name = database_parts.path[1:] or "metricity"
@@ -39,8 +37,6 @@ if __name__ == "__main__":
cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,))
exists = cursor.fetchone()
if not exists:
- print("Creating metricity database.")
- cursor.execute(
- sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name))
- )
+ print("Creating metricity database.") # noqa: T201
+ cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name)))
conn.close()
diff --git a/metricity/__init__.py b/metricity/__init__.py
index 12e0b4a..df7d211 100644
--- a/metricity/__init__.py
+++ b/metricity/__init__.py
@@ -5,7 +5,6 @@ import logging
import os
from typing import TYPE_CHECKING
-
import coloredlogs
from pydis_core.utils import apply_monkey_patches
diff --git a/metricity/__main__.py b/metricity/__main__.py
index 71fe8a7..caf8ef6 100644
--- a/metricity/__main__.py
+++ b/metricity/__main__.py
@@ -25,7 +25,7 @@ async def main() -> None:
presences=False,
messages=True,
reactions=False,
- typing=False
+ typing=False,
)
async with aiohttp.ClientSession() as session:
diff --git a/metricity/bot.py b/metricity/bot.py
index 17aec9c..0c0fa38 100644
--- a/metricity/bot.py
+++ b/metricity/bot.py
@@ -23,10 +23,10 @@ class Bot(BotBase):
async def setup_hook(self) -> None:
"""Connect to db and load cogs."""
await super().setup_hook()
- log.info(f"Metricity is online, logged in as {self.user}")
+ log.info("Metricity is online, logged in as %s", self.user)
await connect()
await self.load_extensions(exts)
- async def on_error(self, event: str, *args, **kwargs) -> None:
+ async def on_error(self, event: str, *_args, **_kwargs) -> None:
"""Log errors raised in event listeners rather than printing them to stderr."""
- log.exception(f"Unhandled exception in {event}.")
+ log.exception("Unhandled exception in %s", event)
diff --git a/metricity/config.py b/metricity/config.py
index a77b252..197287e 100644
--- a/metricity/config.py
+++ b/metricity/config.py
@@ -2,7 +2,7 @@
import logging
from os import environ
from pathlib import Path
-from typing import Any, Optional
+from typing import Any
import toml
from deepmerge import Merger
@@ -28,7 +28,7 @@ def get_section(section: str) -> dict[str, Any]:
if not Path("config-default.toml").exists():
raise MetricityConfigurationError("config-default.toml is missing")
- with open("config-default.toml", "r") as default_config_file:
+ with Path.open("config-default.toml") as default_config_file:
default_config = toml.load(default_config_file)
# Load user configuration
@@ -36,25 +36,23 @@ def get_section(section: str) -> dict[str, Any]:
user_config_location = Path(environ.get("CONFIG_LOCATION", "./config.toml"))
if user_config_location.exists():
- with open(user_config_location, "r") as user_config_file:
+ with Path.open(user_config_location) as user_config_file:
user_config = toml.load(user_config_file)
# Merge the configuration
merger = Merger(
[
- (dict, "merge")
+ (dict, "merge"),
],
["override"],
- ["override"]
+ ["override"],
)
conf = merger.merge(default_config, user_config)
# Check whether we are missing the requested section
if not conf.get(section):
- raise MetricityConfigurationError(
- f"Config is missing section '{section}'"
- )
+ raise MetricityConfigurationError(f"Config is missing section '{section}'")
return conf[section]
@@ -66,34 +64,30 @@ class ConfigSection(type):
cls: type,
name: str,
bases: tuple[type],
- dictionary: dict[str, Any]
+ dictionary: dict[str, Any],
) -> type:
"""Use the section attr in the subclass to fill in the values from the TOML."""
config = get_section(dictionary["section"])
- log.info(f"Loading configuration section {dictionary['section']}")
+ log.info("Loading configuration section %s", dictionary["section"])
for key, value in config.items():
- if isinstance(value, dict):
- if env_var := value.get("env"):
- if env_value := environ.get(env_var):
- config[key] = env_value
- else:
- if not value.get("optional"):
- raise MetricityConfigurationError(
- f"Required config option '{key}' in"
- f" '{dictionary['section']}' is missing, either set"
- f" the environment variable {env_var} or override "
- "it in your config.toml file"
- )
- else:
- config[key] = None
+ if isinstance(value, dict) and (env_var := value.get("env")):
+ if env_value := environ.get(env_var):
+ config[key] = env_value
+ elif not value.get("optional"):
+ raise MetricityConfigurationError(
+ f"Required config option '{key}' in"
+ f" '{dictionary['section']}' is missing, either set"
+ f" the environment variable {env_var} or override "
+ "it in your config.toml file",
+ )
+ else:
+ config[key] = None
dictionary.update(config)
- config_section = super().__new__(cls, name, bases, dictionary)
-
- return config_section
+ return super().__new__(cls, name, bases, dictionary)
class PythonConfig(metaclass=ConfigSection):
@@ -125,10 +119,10 @@ class DatabaseConfig(metaclass=ConfigSection):
section = "database"
- uri: Optional[str]
+ uri: str | None
- host: Optional[str]
- port: Optional[int]
- database: Optional[str]
- username: Optional[str]
- password: Optional[str]
+ host: str | None
+ port: int | None
+ database: str | None
+ username: str | None
+ password: str | None
diff --git a/metricity/database.py b/metricity/database.py
index a4d953e..717d3f2 100644
--- a/metricity/database.py
+++ b/metricity/database.py
@@ -45,17 +45,15 @@ class TZDateTime(TypeDecorator):
impl = DateTime
cache_ok = True
- def process_bind_param(self, value: datetime, dialect: Dialect) -> datetime:
+ def process_bind_param(self, value: datetime, _dialect: Dialect) -> datetime:
"""Convert the value to aware before saving to db."""
if value is not None:
if not value.tzinfo:
raise TypeError("tzinfo is required")
- value = value.astimezone(timezone.utc).replace(
- tzinfo=None
- )
+ value = value.astimezone(timezone.utc).replace(tzinfo=None)
return value
- def process_result_value(self, value: datetime, dialect: Dialect) -> datetime:
+ def process_result_value(self, value: datetime, _dialect: Dialect) -> datetime:
"""Convert the value to aware before passing back to user-land."""
if value is not None:
value = value.replace(tzinfo=timezone.utc)
diff --git a/metricity/exts/error_handler.py b/metricity/exts/error_handler.py
index b3c8fb7..1330eda 100644
--- a/metricity/exts/error_handler.py
+++ b/metricity/exts/error_handler.py
@@ -24,7 +24,7 @@ class ErrorHandler(commands.Cog):
return discord.Embed(
title=title,
colour=discord.Colour.red(),
- description=body
+ description=body,
)
@commands.Cog.listener()
@@ -32,8 +32,11 @@ class ErrorHandler(commands.Cog):
"""Provide generic command error handling."""
if isinstance(e, SUPPRESSED_ERRORS):
log.debug(
- f"Command {ctx.invoked_with} invoked by {ctx.message.author} with error "
- f"{e.__class__.__name__}: {e}"
+ "Command %s invoked by %s with error %s: %s",
+ ctx.invoked_with,
+ ctx.message.author,
+ e.__class__.__name__,
+ e,
)
diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py
index f0bbe39..6b2aacf 100644
--- a/metricity/exts/event_listeners/_utils.py
+++ b/metricity/exts/event_listeners/_utils.py
@@ -16,7 +16,7 @@ async def insert_thread(thread: discord.Thread) -> None:
)
-async def sync_message(message: discord.Message, from_thread: bool) -> None:
+async def sync_message(message: discord.Message, *, from_thread: bool) -> None:
"""Sync the given message with the database."""
if await models.Message.get(str(message.id)):
return
@@ -25,7 +25,7 @@ async def sync_message(message: discord.Message, from_thread: bool) -> None:
"id": str(message.id),
"channel_id": str(message.channel.id),
"author_id": str(message.author.id),
- "created_at": message.created_at
+ "created_at": message.created_at,
}
if from_thread:
diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py
index 4f14021..c7c074f 100644
--- a/metricity/exts/event_listeners/guild_listeners.py
+++ b/metricity/exts/event_listeners/guild_listeners.py
@@ -33,9 +33,8 @@ class GuildListeners(commands.Cog):
log.info("Beginning user synchronisation process")
await models.User.update.values(in_guild=False).gino.status()
- users = []
- for user in guild.members:
- users.append({
+ users = [
+ {
"id": str(user.id),
"name": user.name,
"avatar_hash": getattr(user.avatar, "key", None),
@@ -46,15 +45,17 @@ class GuildListeners(commands.Cog):
"bot": user.bot,
"in_guild": True,
"public_flags": dict(user.public_flags),
- "pending": user.pending
- })
+ "pending": user.pending,
+ }
+ for user in guild.members
+ ]
- log.info(f"Performing bulk upsert of {len(users)} rows")
+ log.info("Performing bulk upsert of %d rows", len(users))
user_chunks = discord.utils.as_chunks(users, 500)
for chunk in user_chunks:
- log.info(f"Upserting chunk of {len(chunk)}")
+ log.info("Upserting chunk of %d", len(chunk))
await models.User.bulk_upsert(chunk)
log.info("User upsert complete")
@@ -85,15 +86,14 @@ class GuildListeners(commands.Cog):
log.info("Category synchronisation process complete, synchronising channels")
for channel in guild.channels:
- if channel.category:
- if channel.category.id in BotConfig.ignore_categories:
- continue
+ if channel.category and channel.category.id in BotConfig.ignore_categories:
+ continue
if not isinstance(channel, discord.CategoryChannel):
category_id = str(channel.category.id) if channel.category else None
# Cast to bool so is_staff is False if channel.category is None
is_staff = channel.id in BotConfig.staff_channels or bool(
- channel.category and channel.category.id in BotConfig.staff_categories
+ channel.category and channel.category.id in BotConfig.staff_categories,
)
if db_chan := await models.Channel.get(str(channel.id)):
await db_chan.update(
@@ -145,7 +145,7 @@ class GuildListeners(commands.Cog):
async def on_guild_channel_update(
self,
_before: discord.abc.GuildChannel,
- channel: discord.abc.GuildChannel
+ channel: discord.abc.GuildChannel,
) -> None:
"""Sync the channels when one is updated."""
if channel.guild.id != BotConfig.guild_id:
@@ -172,7 +172,7 @@ class GuildListeners(commands.Cog):
@commands.Cog.listener()
async def on_guild_available(self, guild: discord.Guild) -> None:
"""Synchronize the user table with the Discord users."""
- log.info(f"Received guild available for {guild.id}")
+ log.info("Received guild available for %d", guild.id)
if guild.id != BotConfig.guild_id:
log.info("Guild was not the configured guild, discarding event")
diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py
index f3074ce..ddf5954 100644
--- a/metricity/exts/event_listeners/member_listeners.py
+++ b/metricity/exts/event_listeners/member_listeners.py
@@ -1,5 +1,7 @@
"""An ext to listen for member events and syncs them to the database."""
+import contextlib
+
import discord
from asyncpg.exceptions import UniqueViolationError
from discord.ext import commands
@@ -25,7 +27,7 @@ class MemberListeners(commands.Cog):
if db_user := await User.get(str(member.id)):
await db_user.update(
- in_guild=False
+ in_guild=False,
).apply()
@commands.Cog.listener()
@@ -47,10 +49,10 @@ class MemberListeners(commands.Cog):
is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
public_flags=dict(member.public_flags),
pending=member.pending,
- in_guild=True
+ in_guild=True,
).apply()
else:
- try:
+ with contextlib.suppress(UniqueViolationError):
await User.create(
id=str(member.id),
name=member.name,
@@ -61,13 +63,11 @@ class MemberListeners(commands.Cog):
is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
public_flags=dict(member.public_flags),
pending=member.pending,
- in_guild=True
+ in_guild=True,
)
- except UniqueViolationError:
- pass
@commands.Cog.listener()
- async def on_member_update(self, before: discord.Member, member: discord.Member) -> None:
+ async def on_member_update(self, _before: discord.Member, member: discord.Member) -> None:
"""When a member updates their profile, update the DB record."""
await self.bot.sync_process_complete.wait()
@@ -78,7 +78,7 @@ class MemberListeners(commands.Cog):
if not member.joined_at:
return
- roles = set([role.id for role in member.roles])
+ roles = {role.id for role in member.roles}
if db_user := await User.get(str(member.id)):
if (
@@ -99,10 +99,10 @@ class MemberListeners(commands.Cog):
is_staff=BotConfig.staff_role_id in roles,
public_flags=dict(member.public_flags),
in_guild=True,
- pending=member.pending
+ pending=member.pending,
).apply()
else:
- try:
+ with contextlib.suppress(UniqueViolationError):
await User.create(
id=str(member.id),
name=member.name,
@@ -113,10 +113,9 @@ class MemberListeners(commands.Cog):
is_staff=BotConfig.staff_role_id in roles,
public_flags=dict(member.public_flags),
in_guild=True,
- pending=member.pending
+ pending=member.pending,
)
- except UniqueViolationError:
- pass
+
async def setup(bot: Bot) -> None:
diff --git a/metricity/exts/status.py b/metricity/exts/status.py
index c69e0c1..607c1e3 100644
--- a/metricity/exts/status.py
+++ b/metricity/exts/status.py
@@ -9,7 +9,7 @@ from metricity.config import BotConfig
DESCRIPTIONS = (
"Command processing time",
"Last event received",
- "Discord API latency"
+ "Discord API latency",
)
ROUND_LATENCY = 3
INTRO_MESSAGE = "Hello, I'm {name}. I insert all your data into a GDPR-compliant database."
@@ -49,7 +49,7 @@ class Status(commands.Cog):
description=INTRO_MESSAGE.format(name=ctx.guild.me.display_name),
)
- for desc, latency in zip(DESCRIPTIONS, (bot_ping, last_event, discord_ping)):
+ for desc, latency in zip(DESCRIPTIONS, (bot_ping, last_event, discord_ping), strict=True):
embed.add_field(name=desc, value=latency, inline=False)
await ctx.send(embed=embed)
diff --git a/metricity/models.py b/metricity/models.py
index 4f136de..700464e 100644
--- a/metricity/models.py
+++ b/metricity/models.py
@@ -1,7 +1,7 @@
"""Database models used by Metricity for statistic collection."""
from datetime import datetime, timezone
-from typing import Any, Dict, List
+from typing import Any
from sqlalchemy.dialects.postgresql import insert
@@ -27,7 +27,7 @@ class Channel(db.Model):
category_id = db.Column(
db.String,
db.ForeignKey("categories.id", ondelete="CASCADE"),
- nullable=True
+ nullable=True,
)
is_staff = db.Column(db.Boolean, nullable=False)
@@ -41,7 +41,7 @@ class Thread(db.Model):
parent_channel_id = db.Column(
db.String,
db.ForeignKey("channels.id", ondelete="CASCADE"),
- nullable=False
+ nullable=False,
)
created_at = db.Column(TZDateTime(), default=datetime.now(timezone.utc))
name = db.Column(db.String, nullable=False)
@@ -69,7 +69,7 @@ class User(db.Model):
pending = db.Column(db.Boolean, default=False)
@classmethod
- def bulk_upsert(cls: type, users: List[Dict[str, Any]]) -> Any:
+ def bulk_upsert(cls: type, users: list[dict[str, Any]]) -> Any: # noqa: ANN401
"""Perform a bulk insert/update of the database to sync the user table."""
qs = insert(cls.__table__).values(users)
@@ -82,12 +82,12 @@ class User(db.Model):
"bot",
"in_guild",
"public_flags",
- "pending"
+ "pending",
]
return qs.on_conflict_do_update(
index_elements=[cls.id],
- set_={k: getattr(qs.excluded, k) for k in update_cols}
+ set_={k: getattr(qs.excluded, k) for k in update_cols},
).returning(cls.__table__).gino.all()
@@ -100,17 +100,17 @@ class Message(db.Model):
channel_id = db.Column(
db.String,
db.ForeignKey("channels.id", ondelete="CASCADE"),
- index=True
+ index=True,
)
thread_id = db.Column(
db.String,
db.ForeignKey("threads.id", ondelete="CASCADE"),
- index=True
+ index=True,
)
author_id = db.Column(
db.String,
db.ForeignKey("users.id", ondelete="CASCADE"),
- index=True
+ index=True,
)
created_at = db.Column(TZDateTime(), default=datetime.now(timezone.utc))
is_deleted = db.Column(db.Boolean, default=False)