From 0c3e67161411b3b6eaa463d80b2b183af43da6a1 Mon Sep 17 00:00:00 2001 From: Joe Banks Date: Mon, 4 Sep 2023 20:39:14 +0100 Subject: Swap database creation script to asyncpg --- create_metricity_db.py | 53 +++++++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 24 deletions(-) (limited to 'create_metricity_db.py') diff --git a/create_metricity_db.py b/create_metricity_db.py index 215ba6e..a792f8e 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -1,10 +1,10 @@ """Ensures the metricity db exists before running migrations.""" -import os +import asyncio from urllib.parse import SplitResult, urlsplit -import psycopg2 -from psycopg2 import sql -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +import asyncpg + +from metricity.database import build_db_uri def parse_db_url(db_url: str) -> SplitResult: @@ -18,25 +18,30 @@ def parse_db_url(db_url: str) -> SplitResult: raise ValueError("The given db_url is not a valid PostgreSQL database URL.") return db_url_parts +async def create_db() -> None: + """Create the Metricity database if it does not exist.""" + parts = parse_db_url(build_db_uri()) + try: + await asyncpg.connect( + host=parts.hostname, + user=parts.username, + database=parts.path[1:], + password=parts.password, + ) + except asyncpg.InvalidCatalogNameError: + print("Creating metricity database.") # noqa: T201 + sys_conn = await asyncpg.connect( + database="template1", + user=parts.username, + host=parts.hostname, + password=parts.password, + ) + + await sys_conn.execute( + f'CREATE DATABASE "{parts.path[1:] or "metricity"}" OWNER "{parts.username}"', + ) + + await sys_conn.close() if __name__ == "__main__": - database_parts = parse_db_url(os.environ["DATABASE_URI"]) - - conn = psycopg2.connect( - host=database_parts.hostname, - port=database_parts.port, - user=database_parts.username, - password=database_parts.password, - ) - - db_name = database_parts.path[1:] or "metricity" - - # Required to create a database in a .execute() call - conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - with conn.cursor() as cursor: - cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,)) - exists = cursor.fetchone() - if not exists: - print("Creating metricity database.") # noqa: T201 - cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name))) - conn.close() + asyncio.get_event_loop().run_until_complete(create_db()) -- cgit v1.2.3 From 31b5c3e11139434c0de4678d35dba5aa0adeb4d0 Mon Sep 17 00:00:00 2001 From: Joe Banks Date: Mon, 4 Sep 2023 20:51:32 +0100 Subject: Default postgresql driver to asyncpg --- create_metricity_db.py | 6 +++--- docker-compose.yml | 2 +- metricity/database.py | 6 ++++++ 3 files changed, 10 insertions(+), 4 deletions(-) (limited to 'create_metricity_db.py') diff --git a/create_metricity_db.py b/create_metricity_db.py index a792f8e..1c7acaa 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -11,7 +11,7 @@ def parse_db_url(db_url: str) -> SplitResult: """Validate and split the given database url.""" db_url_parts = urlsplit(db_url) if not all(( - db_url_parts.hostname, + db_url_parts.netloc, db_url_parts.username, db_url_parts.password, )): @@ -23,7 +23,7 @@ async def create_db() -> None: parts = parse_db_url(build_db_uri()) try: await asyncpg.connect( - host=parts.hostname, + host=parts.netloc, user=parts.username, database=parts.path[1:], password=parts.password, @@ -33,7 +33,7 @@ async def create_db() -> None: sys_conn = await asyncpg.connect( database="template1", user=parts.username, - host=parts.hostname, + host=parts.netloc, password=parts.password, ) diff --git a/docker-compose.yml b/docker-compose.yml index 00693e5..b378a95 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -38,4 +38,4 @@ services: env_file: - .env environment: - DATABASE_URI: postgres://pysite:pysite@postgres/metricity + DATABASE_URI: postgres+asyncpg://pysite:pysite@postgres/metricity diff --git a/metricity/database.py b/metricity/database.py index 4fdf465..347ce90 100644 --- a/metricity/database.py +++ b/metricity/database.py @@ -2,6 +2,7 @@ import logging from datetime import UTC, datetime +from urllib.parse import urlsplit from sqlalchemy.engine import Dialect from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine @@ -14,6 +15,11 @@ log = logging.getLogger(__name__) def build_db_uri() -> str: """Build the database uri from the config.""" if DatabaseConfig.uri: + parsed = urlsplit(DatabaseConfig.uri) + if parsed.scheme != "postgresql+asyncpg": + log.debug("The given db_url did not use the asyncpg driver. Updating the db_url to use asyncpg.") + return parsed._replace(scheme="postgresql+asyncpg").geturl() + return DatabaseConfig.uri return ( -- cgit v1.2.3 From 667d3d6b332c7ba98832d6e7a961cd9bb7fefdb7 Mon Sep 17 00:00:00 2001 From: Joe Banks Date: Mon, 4 Sep 2023 20:57:56 +0100 Subject: Don't use netloc as a hostname, use hostname + port separately --- create_metricity_db.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'create_metricity_db.py') diff --git a/create_metricity_db.py b/create_metricity_db.py index 1c7acaa..2dc3a75 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -23,7 +23,8 @@ async def create_db() -> None: parts = parse_db_url(build_db_uri()) try: await asyncpg.connect( - host=parts.netloc, + host=parts.hostname, + port=parts.port, user=parts.username, database=parts.path[1:], password=parts.password, @@ -33,7 +34,8 @@ async def create_db() -> None: sys_conn = await asyncpg.connect( database="template1", user=parts.username, - host=parts.netloc, + host=parts.hostname, + port=parts.port, password=parts.password, ) -- cgit v1.2.3