aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2023-09-04 20:39:14 +0100
committerGravatar Joe Banks <[email protected]>2023-09-04 20:39:14 +0100
commit0c3e67161411b3b6eaa463d80b2b183af43da6a1 (patch)
tree02793cc299010ff2214572a35b93a47b031d951d
parentConsistent version specifying for new dependencies (diff)
Swap database creation script to asyncpg
-rw-r--r--create_metricity_db.py53
1 files changed, 29 insertions, 24 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py
index 215ba6e..a792f8e 100644
--- a/create_metricity_db.py
+++ b/create_metricity_db.py
@@ -1,10 +1,10 @@
"""Ensures the metricity db exists before running migrations."""
-import os
+import asyncio
from urllib.parse import SplitResult, urlsplit
-import psycopg2
-from psycopg2 import sql
-from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
+import asyncpg
+
+from metricity.database import build_db_uri
def parse_db_url(db_url: str) -> SplitResult:
@@ -18,25 +18,30 @@ def parse_db_url(db_url: str) -> SplitResult:
raise ValueError("The given db_url is not a valid PostgreSQL database URL.")
return db_url_parts
+async def create_db() -> None:
+ """Create the Metricity database if it does not exist."""
+ parts = parse_db_url(build_db_uri())
+ try:
+ await asyncpg.connect(
+ host=parts.hostname,
+ user=parts.username,
+ database=parts.path[1:],
+ password=parts.password,
+ )
+ except asyncpg.InvalidCatalogNameError:
+ print("Creating metricity database.") # noqa: T201
+ sys_conn = await asyncpg.connect(
+ database="template1",
+ user=parts.username,
+ host=parts.hostname,
+ password=parts.password,
+ )
+
+ await sys_conn.execute(
+ f'CREATE DATABASE "{parts.path[1:] or "metricity"}" OWNER "{parts.username}"',
+ )
+
+ await sys_conn.close()
if __name__ == "__main__":
- database_parts = parse_db_url(os.environ["DATABASE_URI"])
-
- conn = psycopg2.connect(
- host=database_parts.hostname,
- port=database_parts.port,
- user=database_parts.username,
- password=database_parts.password,
- )
-
- db_name = database_parts.path[1:] or "metricity"
-
- # Required to create a database in a .execute() call
- conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
- with conn.cursor() as cursor:
- cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,))
- exists = cursor.fetchone()
- if not exists:
- print("Creating metricity database.") # noqa: T201
- cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name)))
- conn.close()
+ asyncio.get_event_loop().run_until_complete(create_db())