diff options
| -rw-r--r-- | create_metricity_db.py | 6 | ||||
| -rw-r--r-- | docker-compose.yml | 2 | ||||
| -rw-r--r-- | metricity/database.py | 6 | 
3 files changed, 10 insertions, 4 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py index a792f8e..1c7acaa 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -11,7 +11,7 @@ def parse_db_url(db_url: str) -> SplitResult:      """Validate and split the given database url."""      db_url_parts = urlsplit(db_url)      if not all(( -        db_url_parts.hostname, +        db_url_parts.netloc,          db_url_parts.username,          db_url_parts.password,      )): @@ -23,7 +23,7 @@ async def create_db() -> None:      parts = parse_db_url(build_db_uri())      try:          await asyncpg.connect( -            host=parts.hostname, +            host=parts.netloc,              user=parts.username,              database=parts.path[1:],              password=parts.password, @@ -33,7 +33,7 @@ async def create_db() -> None:          sys_conn = await asyncpg.connect(              database="template1",              user=parts.username, -            host=parts.hostname, +            host=parts.netloc,              password=parts.password,          ) diff --git a/docker-compose.yml b/docker-compose.yml index 00693e5..b378a95 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -38,4 +38,4 @@ services:      env_file:        - .env      environment: -      DATABASE_URI: postgres://pysite:pysite@postgres/metricity +      DATABASE_URI: postgres+asyncpg://pysite:pysite@postgres/metricity diff --git a/metricity/database.py b/metricity/database.py index 4fdf465..347ce90 100644 --- a/metricity/database.py +++ b/metricity/database.py @@ -2,6 +2,7 @@  import logging  from datetime import UTC, datetime +from urllib.parse import urlsplit  from sqlalchemy.engine import Dialect  from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine @@ -14,6 +15,11 @@ log = logging.getLogger(__name__)  def build_db_uri() -> str:      """Build the database uri from the config."""      if DatabaseConfig.uri: +        parsed = urlsplit(DatabaseConfig.uri) +        if parsed.scheme != "postgresql+asyncpg": +            log.debug("The given db_url did not use the asyncpg driver. Updating the db_url to use asyncpg.") +            return parsed._replace(scheme="postgresql+asyncpg").geturl() +          return DatabaseConfig.uri      return (  |