aboutsummaryrefslogtreecommitdiffstats
path: root/create_metricity_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'create_metricity_db.py')
-rw-r--r--create_metricity_db.py57
1 files changed, 32 insertions, 25 deletions
diff --git a/create_metricity_db.py b/create_metricity_db.py
index 215ba6e..2dc3a75 100644
--- a/create_metricity_db.py
+++ b/create_metricity_db.py
@@ -1,42 +1,49 @@
"""Ensures the metricity db exists before running migrations."""
-import os
+import asyncio
from urllib.parse import SplitResult, urlsplit
-import psycopg2
-from psycopg2 import sql
-from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
+import asyncpg
+
+from metricity.database import build_db_uri
def parse_db_url(db_url: str) -> SplitResult:
"""Validate and split the given database url."""
db_url_parts = urlsplit(db_url)
if not all((
- db_url_parts.hostname,
+ db_url_parts.netloc,
db_url_parts.username,
db_url_parts.password,
)):
raise ValueError("The given db_url is not a valid PostgreSQL database URL.")
return db_url_parts
+async def create_db() -> None:
+ """Create the Metricity database if it does not exist."""
+ parts = parse_db_url(build_db_uri())
+ try:
+ await asyncpg.connect(
+ host=parts.hostname,
+ port=parts.port,
+ user=parts.username,
+ database=parts.path[1:],
+ password=parts.password,
+ )
+ except asyncpg.InvalidCatalogNameError:
+ print("Creating metricity database.") # noqa: T201
+ sys_conn = await asyncpg.connect(
+ database="template1",
+ user=parts.username,
+ host=parts.hostname,
+ port=parts.port,
+ password=parts.password,
+ )
+
+ await sys_conn.execute(
+ f'CREATE DATABASE "{parts.path[1:] or "metricity"}" OWNER "{parts.username}"',
+ )
+
+ await sys_conn.close()
if __name__ == "__main__":
- database_parts = parse_db_url(os.environ["DATABASE_URI"])
-
- conn = psycopg2.connect(
- host=database_parts.hostname,
- port=database_parts.port,
- user=database_parts.username,
- password=database_parts.password,
- )
-
- db_name = database_parts.path[1:] or "metricity"
-
- # Required to create a database in a .execute() call
- conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
- with conn.cursor() as cursor:
- cursor.execute("SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (db_name,))
- exists = cursor.fetchone()
- if not exists:
- print("Creating metricity database.") # noqa: T201
- cursor.execute(sql.SQL("CREATE DATABASE {dbname}").format(dbname=sql.Identifier(db_name)))
- conn.close()
+ asyncio.get_event_loop().run_until_complete(create_db())