diff options
author | 2023-07-24 10:50:18 +0100 | |
---|---|---|
committer | 2023-07-24 10:50:18 +0100 | |
commit | 28c903b18cdfaf8165d45278791150bdd85a4c2c (patch) | |
tree | 5335ee679bd3864626f1aa08135120bdaa515940 | |
parent | Move linting to ruff (diff) |
Update code style for new ruff rules
-rw-r--r-- | create_metricity_db.py | 16 | ||||
-rw-r--r-- | metricity/__init__.py | 1 | ||||
-rw-r--r-- | metricity/__main__.py | 2 | ||||
-rw-r--r-- | metricity/bot.py | 6 | ||||
-rw-r--r-- | metricity/config.py | 60 | ||||
-rw-r--r-- | metricity/database.py | 8 | ||||
-rw-r--r-- | metricity/exts/error_handler.py | 9 | ||||
-rw-r--r-- | metricity/exts/event_listeners/_utils.py | 4 | ||||
-rw-r--r-- | metricity/exts/event_listeners/guild_listeners.py | 26 | ||||
-rw-r--r-- | metricity/exts/event_listeners/member_listeners.py | 25 | ||||
-rw-r--r-- | metricity/exts/status.py | 4 | ||||
-rw-r--r-- | metricity/models.py | 18 |
12 files changed, 84 insertions, 95 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py index edec2db..215ba6e 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -8,16 +8,14 @@ from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT def parse_db_url(db_url: str) -> SplitResult: - """Validate and split the given databse url.""" + """Validate and split the given database url.""" db_url_parts = urlsplit(db_url) if not all(( db_url_parts.hostname, db_url_parts.username, - db_url_parts.password + db_url_parts.password, )): - raise ValueError( - "The given db_url is not a valid PostgreSQL database URL." - ) + raise ValueError("The given db_url is not a valid PostgreSQL database URL.") return db_url_parts @@ -28,7 +26,7 @@ if __name__ == "__main__": host=database_parts.hostname, port=database_parts.port, user=database_parts.username, - password=database_parts.password + password=database_parts.password, ) db_name = database_parts.path[1:] or "metricity" @@ -39,8 +37,6 @@ if __name__ == "__main__": cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,)) exists = cursor.fetchone() if not exists: - print("Creating metricity database.") - cursor.execute( - sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name)) - ) + print("Creating metricity database.") # noqa: T201 + cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name))) conn.close() diff --git a/metricity/__init__.py b/metricity/__init__.py index 12e0b4a..df7d211 100644 --- a/metricity/__init__.py +++ b/metricity/__init__.py @@ -5,7 +5,6 @@ import logging import os from typing import TYPE_CHECKING - import coloredlogs from pydis_core.utils import apply_monkey_patches diff --git a/metricity/__main__.py b/metricity/__main__.py index 71fe8a7..caf8ef6 100644 --- a/metricity/__main__.py +++ b/metricity/__main__.py @@ -25,7 +25,7 @@ async def main() -> None: presences=False, messages=True, reactions=False, - typing=False + typing=False, ) async with aiohttp.ClientSession() as session: diff --git a/metricity/bot.py b/metricity/bot.py index 17aec9c..0c0fa38 100644 --- a/metricity/bot.py +++ b/metricity/bot.py @@ -23,10 +23,10 @@ class Bot(BotBase): async def setup_hook(self) -> None: """Connect to db and load cogs.""" await super().setup_hook() - log.info(f"Metricity is online, logged in as {self.user}") + log.info("Metricity is online, logged in as %s", self.user) await connect() await self.load_extensions(exts) - async def on_error(self, event: str, *args, **kwargs) -> None: + async def on_error(self, event: str, *_args, **_kwargs) -> None: """Log errors raised in event listeners rather than printing them to stderr.""" - log.exception(f"Unhandled exception in {event}.") + log.exception("Unhandled exception in %s", event) diff --git a/metricity/config.py b/metricity/config.py index a77b252..197287e 100644 --- a/metricity/config.py +++ b/metricity/config.py @@ -2,7 +2,7 @@ import logging from os import environ from pathlib import Path -from typing import Any, Optional +from typing import Any import toml from deepmerge import Merger @@ -28,7 +28,7 @@ def get_section(section: str) -> dict[str, Any]: if not Path("config-default.toml").exists(): raise MetricityConfigurationError("config-default.toml is missing") - with open("config-default.toml", "r") as default_config_file: + with Path.open("config-default.toml") as default_config_file: default_config = toml.load(default_config_file) # Load user configuration @@ -36,25 +36,23 @@ def get_section(section: str) -> dict[str, Any]: user_config_location = Path(environ.get("CONFIG_LOCATION", "./config.toml")) if user_config_location.exists(): - with open(user_config_location, "r") as user_config_file: + with Path.open(user_config_location) as user_config_file: user_config = toml.load(user_config_file) # Merge the configuration merger = Merger( [ - (dict, "merge") + (dict, "merge"), ], ["override"], - ["override"] + ["override"], ) conf = merger.merge(default_config, user_config) # Check whether we are missing the requested section if not conf.get(section): - raise MetricityConfigurationError( - f"Config is missing section '{section}'" - ) + raise MetricityConfigurationError(f"Config is missing section '{section}'") return conf[section] @@ -66,34 +64,30 @@ class ConfigSection(type): cls: type, name: str, bases: tuple[type], - dictionary: dict[str, Any] + dictionary: dict[str, Any], ) -> type: """Use the section attr in the subclass to fill in the values from the TOML.""" config = get_section(dictionary["section"]) - log.info(f"Loading configuration section {dictionary['section']}") + log.info("Loading configuration section %s", dictionary["section"]) for key, value in config.items(): - if isinstance(value, dict): - if env_var := value.get("env"): - if env_value := environ.get(env_var): - config[key] = env_value - else: - if not value.get("optional"): - raise MetricityConfigurationError( - f"Required config option '{key}' in" - f" '{dictionary['section']}' is missing, either set" - f" the environment variable {env_var} or override " - "it in your config.toml file" - ) - else: - config[key] = None + if isinstance(value, dict) and (env_var := value.get("env")): + if env_value := environ.get(env_var): + config[key] = env_value + elif not value.get("optional"): + raise MetricityConfigurationError( + f"Required config option '{key}' in" + f" '{dictionary['section']}' is missing, either set" + f" the environment variable {env_var} or override " + "it in your config.toml file", + ) + else: + config[key] = None dictionary.update(config) - config_section = super().__new__(cls, name, bases, dictionary) - - return config_section + return super().__new__(cls, name, bases, dictionary) class PythonConfig(metaclass=ConfigSection): @@ -125,10 +119,10 @@ class DatabaseConfig(metaclass=ConfigSection): section = "database" - uri: Optional[str] + uri: str | None - host: Optional[str] - port: Optional[int] - database: Optional[str] - username: Optional[str] - password: Optional[str] + host: str | None + port: int | None + database: str | None + username: str | None + password: str | None diff --git a/metricity/database.py b/metricity/database.py index a4d953e..717d3f2 100644 --- a/metricity/database.py +++ b/metricity/database.py @@ -45,17 +45,15 @@ class TZDateTime(TypeDecorator): impl = DateTime cache_ok = True - def process_bind_param(self, value: datetime, dialect: Dialect) -> datetime: + def process_bind_param(self, value: datetime, _dialect: Dialect) -> datetime: """Convert the value to aware before saving to db.""" if value is not None: if not value.tzinfo: raise TypeError("tzinfo is required") - value = value.astimezone(timezone.utc).replace( - tzinfo=None - ) + value = value.astimezone(timezone.utc).replace(tzinfo=None) return value - def process_result_value(self, value: datetime, dialect: Dialect) -> datetime: + def process_result_value(self, value: datetime, _dialect: Dialect) -> datetime: """Convert the value to aware before passing back to user-land.""" if value is not None: value = value.replace(tzinfo=timezone.utc) diff --git a/metricity/exts/error_handler.py b/metricity/exts/error_handler.py index b3c8fb7..1330eda 100644 --- a/metricity/exts/error_handler.py +++ b/metricity/exts/error_handler.py @@ -24,7 +24,7 @@ class ErrorHandler(commands.Cog): return discord.Embed( title=title, colour=discord.Colour.red(), - description=body + description=body, ) @commands.Cog.listener() @@ -32,8 +32,11 @@ class ErrorHandler(commands.Cog): """Provide generic command error handling.""" if isinstance(e, SUPPRESSED_ERRORS): log.debug( - f"Command {ctx.invoked_with} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" + "Command %s invoked by %s with error %s: %s", + ctx.invoked_with, + ctx.message.author, + e.__class__.__name__, + e, ) diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py index f0bbe39..6b2aacf 100644 --- a/metricity/exts/event_listeners/_utils.py +++ b/metricity/exts/event_listeners/_utils.py @@ -16,7 +16,7 @@ async def insert_thread(thread: discord.Thread) -> None: ) -async def sync_message(message: discord.Message, from_thread: bool) -> None: +async def sync_message(message: discord.Message, *, from_thread: bool) -> None: """Sync the given message with the database.""" if await models.Message.get(str(message.id)): return @@ -25,7 +25,7 @@ async def sync_message(message: discord.Message, from_thread: bool) -> None: "id": str(message.id), "channel_id": str(message.channel.id), "author_id": str(message.author.id), - "created_at": message.created_at + "created_at": message.created_at, } if from_thread: diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py index 4f14021..c7c074f 100644 --- a/metricity/exts/event_listeners/guild_listeners.py +++ b/metricity/exts/event_listeners/guild_listeners.py @@ -33,9 +33,8 @@ class GuildListeners(commands.Cog): log.info("Beginning user synchronisation process") await models.User.update.values(in_guild=False).gino.status() - users = [] - for user in guild.members: - users.append({ + users = [ + { "id": str(user.id), "name": user.name, "avatar_hash": getattr(user.avatar, "key", None), @@ -46,15 +45,17 @@ class GuildListeners(commands.Cog): "bot": user.bot, "in_guild": True, "public_flags": dict(user.public_flags), - "pending": user.pending - }) + "pending": user.pending, + } + for user in guild.members + ] - log.info(f"Performing bulk upsert of {len(users)} rows") + log.info("Performing bulk upsert of %d rows", len(users)) user_chunks = discord.utils.as_chunks(users, 500) for chunk in user_chunks: - log.info(f"Upserting chunk of {len(chunk)}") + log.info("Upserting chunk of %d", len(chunk)) await models.User.bulk_upsert(chunk) log.info("User upsert complete") @@ -85,15 +86,14 @@ class GuildListeners(commands.Cog): log.info("Category synchronisation process complete, synchronising channels") for channel in guild.channels: - if channel.category: - if channel.category.id in BotConfig.ignore_categories: - continue + if channel.category and channel.category.id in BotConfig.ignore_categories: + continue if not isinstance(channel, discord.CategoryChannel): category_id = str(channel.category.id) if channel.category else None # Cast to bool so is_staff is False if channel.category is None is_staff = channel.id in BotConfig.staff_channels or bool( - channel.category and channel.category.id in BotConfig.staff_categories + channel.category and channel.category.id in BotConfig.staff_categories, ) if db_chan := await models.Channel.get(str(channel.id)): await db_chan.update( @@ -145,7 +145,7 @@ class GuildListeners(commands.Cog): async def on_guild_channel_update( self, _before: discord.abc.GuildChannel, - channel: discord.abc.GuildChannel + channel: discord.abc.GuildChannel, ) -> None: """Sync the channels when one is updated.""" if channel.guild.id != BotConfig.guild_id: @@ -172,7 +172,7 @@ class GuildListeners(commands.Cog): @commands.Cog.listener() async def on_guild_available(self, guild: discord.Guild) -> None: """Synchronize the user table with the Discord users.""" - log.info(f"Received guild available for {guild.id}") + log.info("Received guild available for %d", guild.id) if guild.id != BotConfig.guild_id: log.info("Guild was not the configured guild, discarding event") diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py index f3074ce..ddf5954 100644 --- a/metricity/exts/event_listeners/member_listeners.py +++ b/metricity/exts/event_listeners/member_listeners.py @@ -1,5 +1,7 @@ """An ext to listen for member events and syncs them to the database.""" +import contextlib + import discord from asyncpg.exceptions import UniqueViolationError from discord.ext import commands @@ -25,7 +27,7 @@ class MemberListeners(commands.Cog): if db_user := await User.get(str(member.id)): await db_user.update( - in_guild=False + in_guild=False, ).apply() @commands.Cog.listener() @@ -47,10 +49,10 @@ class MemberListeners(commands.Cog): is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], public_flags=dict(member.public_flags), pending=member.pending, - in_guild=True + in_guild=True, ).apply() else: - try: + with contextlib.suppress(UniqueViolationError): await User.create( id=str(member.id), name=member.name, @@ -61,13 +63,11 @@ class MemberListeners(commands.Cog): is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], public_flags=dict(member.public_flags), pending=member.pending, - in_guild=True + in_guild=True, ) - except UniqueViolationError: - pass @commands.Cog.listener() - async def on_member_update(self, before: discord.Member, member: discord.Member) -> None: + async def on_member_update(self, _before: discord.Member, member: discord.Member) -> None: """When a member updates their profile, update the DB record.""" await self.bot.sync_process_complete.wait() @@ -78,7 +78,7 @@ class MemberListeners(commands.Cog): if not member.joined_at: return - roles = set([role.id for role in member.roles]) + roles = {role.id for role in member.roles} if db_user := await User.get(str(member.id)): if ( @@ -99,10 +99,10 @@ class MemberListeners(commands.Cog): is_staff=BotConfig.staff_role_id in roles, public_flags=dict(member.public_flags), in_guild=True, - pending=member.pending + pending=member.pending, ).apply() else: - try: + with contextlib.suppress(UniqueViolationError): await User.create( id=str(member.id), name=member.name, @@ -113,10 +113,9 @@ class MemberListeners(commands.Cog): is_staff=BotConfig.staff_role_id in roles, public_flags=dict(member.public_flags), in_guild=True, - pending=member.pending + pending=member.pending, ) - except UniqueViolationError: - pass + async def setup(bot: Bot) -> None: diff --git a/metricity/exts/status.py b/metricity/exts/status.py index c69e0c1..607c1e3 100644 --- a/metricity/exts/status.py +++ b/metricity/exts/status.py @@ -9,7 +9,7 @@ from metricity.config import BotConfig DESCRIPTIONS = ( "Command processing time", "Last event received", - "Discord API latency" + "Discord API latency", ) ROUND_LATENCY = 3 INTRO_MESSAGE = "Hello, I'm {name}. I insert all your data into a GDPR-compliant database." @@ -49,7 +49,7 @@ class Status(commands.Cog): description=INTRO_MESSAGE.format(name=ctx.guild.me.display_name), ) - for desc, latency in zip(DESCRIPTIONS, (bot_ping, last_event, discord_ping)): + for desc, latency in zip(DESCRIPTIONS, (bot_ping, last_event, discord_ping), strict=True): embed.add_field(name=desc, value=latency, inline=False) await ctx.send(embed=embed) diff --git a/metricity/models.py b/metricity/models.py index 4f136de..700464e 100644 --- a/metricity/models.py +++ b/metricity/models.py @@ -1,7 +1,7 @@ """Database models used by Metricity for statistic collection.""" from datetime import datetime, timezone -from typing import Any, Dict, List +from typing import Any from sqlalchemy.dialects.postgresql import insert @@ -27,7 +27,7 @@ class Channel(db.Model): category_id = db.Column( db.String, db.ForeignKey("categories.id", ondelete="CASCADE"), - nullable=True + nullable=True, ) is_staff = db.Column(db.Boolean, nullable=False) @@ -41,7 +41,7 @@ class Thread(db.Model): parent_channel_id = db.Column( db.String, db.ForeignKey("channels.id", ondelete="CASCADE"), - nullable=False + nullable=False, ) created_at = db.Column(TZDateTime(), default=datetime.now(timezone.utc)) name = db.Column(db.String, nullable=False) @@ -69,7 +69,7 @@ class User(db.Model): pending = db.Column(db.Boolean, default=False) @classmethod - def bulk_upsert(cls: type, users: List[Dict[str, Any]]) -> Any: + def bulk_upsert(cls: type, users: list[dict[str, Any]]) -> Any: # noqa: ANN401 """Perform a bulk insert/update of the database to sync the user table.""" qs = insert(cls.__table__).values(users) @@ -82,12 +82,12 @@ class User(db.Model): "bot", "in_guild", "public_flags", - "pending" + "pending", ] return qs.on_conflict_do_update( index_elements=[cls.id], - set_={k: getattr(qs.excluded, k) for k in update_cols} + set_={k: getattr(qs.excluded, k) for k in update_cols}, ).returning(cls.__table__).gino.all() @@ -100,17 +100,17 @@ class Message(db.Model): channel_id = db.Column( db.String, db.ForeignKey("channels.id", ondelete="CASCADE"), - index=True + index=True, ) thread_id = db.Column( db.String, db.ForeignKey("threads.id", ondelete="CASCADE"), - index=True + index=True, ) author_id = db.Column( db.String, db.ForeignKey("users.id", ondelete="CASCADE"), - index=True + index=True, ) created_at = db.Column(TZDateTime(), default=datetime.now(timezone.utc)) is_deleted = db.Column(db.Boolean, default=False) |