aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2023-09-04 20:51:32 +0100
committerGravatar Joe Banks <[email protected]>2023-09-04 20:51:32 +0100
commit31b5c3e11139434c0de4678d35dba5aa0adeb4d0 (patch)
treefa543863087f00cddbf673a1988b0f75365dfe4c
parentSwap database creation script to asyncpg (diff)
Default postgresql driver to asyncpg
-rw-r--r--create_metricity_db.py6
-rw-r--r--docker-compose.yml2
-rw-r--r--metricity/database.py6
3 files changed, 10 insertions, 4 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py
index a792f8e..1c7acaa 100644
--- a/create_metricity_db.py
+++ b/create_metricity_db.py
@@ -11,7 +11,7 @@ def parse_db_url(db_url: str) -> SplitResult:
"""Validate and split the given database url."""
db_url_parts = urlsplit(db_url)
if not all((
- db_url_parts.hostname,
+ db_url_parts.netloc,
db_url_parts.username,
db_url_parts.password,
)):
@@ -23,7 +23,7 @@ async def create_db() -> None:
parts = parse_db_url(build_db_uri())
try:
await asyncpg.connect(
- host=parts.hostname,
+ host=parts.netloc,
user=parts.username,
database=parts.path[1:],
password=parts.password,
@@ -33,7 +33,7 @@ async def create_db() -> None:
sys_conn = await asyncpg.connect(
database="template1",
user=parts.username,
- host=parts.hostname,
+ host=parts.netloc,
password=parts.password,
)
diff --git a/docker-compose.yml b/docker-compose.yml
index 00693e5..b378a95 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -38,4 +38,4 @@ services:
env_file:
- .env
environment:
- DATABASE_URI: postgres://pysite:pysite@postgres/metricity
+ DATABASE_URI: postgres+asyncpg://pysite:pysite@postgres/metricity
diff --git a/metricity/database.py b/metricity/database.py
index 4fdf465..347ce90 100644
--- a/metricity/database.py
+++ b/metricity/database.py
@@ -2,6 +2,7 @@
import logging
from datetime import UTC, datetime
+from urllib.parse import urlsplit
from sqlalchemy.engine import Dialect
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
@@ -14,6 +15,11 @@ log = logging.getLogger(__name__)
def build_db_uri() -> str:
"""Build the database uri from the config."""
if DatabaseConfig.uri:
+ parsed = urlsplit(DatabaseConfig.uri)
+ if parsed.scheme != "postgresql+asyncpg":
+ log.debug("The given db_url did not use the asyncpg driver. Updating the db_url to use asyncpg.")
+ return parsed._replace(scheme="postgresql+asyncpg").geturl()
+
return DatabaseConfig.uri
return (