diff options
Diffstat (limited to 'create_metricity_db.py')
-rw-r--r-- | create_metricity_db.py | 53 |
1 files changed, 29 insertions, 24 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py index 215ba6e..a792f8e 100644 --- a/create_metricity_db.py +++ b/create_metricity_db.py @@ -1,10 +1,10 @@ """Ensures the metricity db exists before running migrations.""" -import os +import asyncio from urllib.parse import SplitResult, urlsplit -import psycopg2 -from psycopg2 import sql -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +import asyncpg + +from metricity.database import build_db_uri def parse_db_url(db_url: str) -> SplitResult: @@ -18,25 +18,30 @@ def parse_db_url(db_url: str) -> SplitResult: raise ValueError("The given db_url is not a valid PostgreSQL database URL.") return db_url_parts +async def create_db() -> None: + """Create the Metricity database if it does not exist.""" + parts = parse_db_url(build_db_uri()) + try: + await asyncpg.connect( + host=parts.hostname, + user=parts.username, + database=parts.path[1:], + password=parts.password, + ) + except asyncpg.InvalidCatalogNameError: + print("Creating metricity database.") # noqa: T201 + sys_conn = await asyncpg.connect( + database="template1", + user=parts.username, + host=parts.hostname, + password=parts.password, + ) + + await sys_conn.execute( + f'CREATE DATABASE "{parts.path[1:] or "metricity"}" OWNER "{parts.username}"', + ) + + await sys_conn.close() if __name__ == "__main__": - database_parts = parse_db_url(os.environ["DATABASE_URI"]) - - conn = psycopg2.connect( - host=database_parts.hostname, - port=database_parts.port, - user=database_parts.username, - password=database_parts.password, - ) - - db_name = database_parts.path[1:] or "metricity" - - # Required to create a database in a .execute() call - conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - with conn.cursor() as cursor: - cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,)) - exists = cursor.fetchone() - if not exists: - print("Creating metricity database.") # noqa: T201 - cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name))) - conn.close() + asyncio.get_event_loop().run_until_complete(create_db()) |