diff options
| author | 2023-09-04 21:26:17 +0100 | |
|---|---|---|
| committer | 2023-09-04 21:26:17 +0100 | |
| commit | 48459ff3139f0fdf4c272fbbbf40148000fe28bd (patch) | |
| tree | 04848901819ede9c851bd07fea46bd709d91f7af /create_metricity_db.py | |
| parent | Merge pull request #73 from python-discord/jb3/deps-and-toolchain-updates (diff) | |
| parent | Subject Alembic files to the wrath of the linter (diff) | |
Merge pull request #74 from python-discord/jb3/sqlalchemy-2
SQLAlchemy 2
Diffstat (limited to 'create_metricity_db.py')
| -rw-r--r-- | create_metricity_db.py | 57 | 
1 files changed, 32 insertions, 25 deletions
| diff --git a/create_metricity_db.py b/create_metricity_db.py index 215ba6e..2dc3a75 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -1,42 +1,49 @@  """Ensures the metricity db exists before running migrations.""" -import os +import asyncio  from urllib.parse import SplitResult, urlsplit -import psycopg2 -from psycopg2 import sql -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +import asyncpg + +from metricity.database import build_db_uri  def parse_db_url(db_url: str) -> SplitResult:      """Validate and split the given database url."""      db_url_parts = urlsplit(db_url)      if not all(( -        db_url_parts.hostname, +        db_url_parts.netloc,          db_url_parts.username,          db_url_parts.password,      )):          raise ValueError("The given db_url is not a valid PostgreSQL database URL.")      return db_url_parts +async def create_db() -> None: +    """Create the Metricity database if it does not exist.""" +    parts = parse_db_url(build_db_uri()) +    try: +        await asyncpg.connect( +            host=parts.hostname, +            port=parts.port, +            user=parts.username, +            database=parts.path[1:], +            password=parts.password, +        ) +    except asyncpg.InvalidCatalogNameError: +        print("Creating metricity database.") # noqa: T201 +        sys_conn = await asyncpg.connect( +            database="template1", +            user=parts.username, +            host=parts.hostname, +            port=parts.port, +            password=parts.password, +        ) + +        await sys_conn.execute( +            f'CREATE DATABASE "{parts.path[1:] or "metricity"}" OWNER "{parts.username}"', +        ) + +        await sys_conn.close()  if __name__ == "__main__": -    database_parts = parse_db_url(os.environ["DATABASE_URI"]) - -    conn = psycopg2.connect( -        host=database_parts.hostname, -        port=database_parts.port, -        user=database_parts.username, -        password=database_parts.password, -    ) - -    db_name = database_parts.path[1:] or "metricity" - -    # Required to create a database in a .execute() call -    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) -    with conn.cursor() as cursor: -        cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,)) -        exists = cursor.fetchone() -        if not exists: -            print("Creating metricity database.")  # noqa: T201 -            cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name))) -    conn.close() +    asyncio.get_event_loop().run_until_complete(create_db()) | 
