aboutsummaryrefslogtreecommitdiffstats
path: root/thallium-backend/tests/conftest.py
blob: dbc9ec4a850bae8fb3b3f6e373289b73af913a9c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections.abc import AsyncGenerator, Callable, Generator

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

from src.app import fastapi_app
from src.orm import Base
from src.settings import CONFIG, Connections, _get_db_session


@pytest.fixture(scope="session")
async def test_database_engine() -> AsyncGenerator[AsyncEngine]:
    """Create a fresh, empty test database and yield an engine connected to it.

    The test database is named after the application database with a ``_test``
    suffix. Once per test session it is dropped (if present) and re-created
    through the application's own engine. The yielded engine is disposed when
    the session ends. The database itself is left in place, so its contents can
    be inspected after a run.
    """
    # Derive the test URL structurally. Appending "_test" to the raw string would
    # corrupt URLs that carry a query string (e.g. "...?ssl=require").
    app_url = make_url(CONFIG.database_url.get_secret_value())
    test_db_url = app_url.set(database=f"{app_url.database}_test")
    test_db_engine = create_async_engine(test_db_url, isolation_level="REPEATABLE READ", echo=False)

    # CREATE/DROP DATABASE generally cannot run inside a transaction block
    # (e.g. on PostgreSQL), so use an AUTOCOMMIT variant of the main app's engine.
    main_engine = Connections.DB_ENGINE.execution_options(isolation_level="AUTOCOMMIT")
    # Quote the identifier in case the database name requires it (e.g. contains "-").
    db_name = test_db_engine.dialect.identifier_preparer.quote(test_db_url.database)
    async with main_engine.connect() as conn:
        await conn.execute(text(f"DROP DATABASE IF EXISTS {db_name}"))
        await conn.execute(text(f"CREATE DATABASE {db_name}"))

    try:
        yield test_db_engine
    finally:
        # Release the pooled connections held against the test database.
        await test_db_engine.dispose()


@pytest.fixture()
async def db_session(test_database_engine: AsyncEngine) -> AsyncGenerator[AsyncSession]:
    """Yield an async session on a freshly rebuilt schema in the test database.

    All ORM tables are dropped and re-created at the start of every test, so
    each test begins with empty tables. The session is bound to the same
    connection (and outer transaction) used for the schema rebuild. That
    transaction, opened with ``engine.begin()``, is closed out when the fixture
    is torn down.
    """
    async with test_database_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)
        # ``expire_on_commit=False`` keeps loaded objects usable after a commit.
        # The ``async with`` closes the session on exit, so no explicit
        # ``close()`` is needed.
        async with AsyncSession(bind=conn, expire_on_commit=False) as session:
            yield session


@pytest.fixture()
def override_db_session(db_session: AsyncSession) -> Callable[[], AsyncGenerator[AsyncSession]]:
    """Return a FastAPI dependency that yields this test's ``db_session``.

    The returned async generator function is installed in place of the app's
    ``_get_db_session`` dependency. Request handlers then use the same session
    as the test itself.
    """

    async def _override_db_session() -> AsyncGenerator[AsyncSession]:
        yield db_session

    return _override_db_session


@pytest.fixture()
def app(override_db_session: Callable) -> Generator[FastAPI]:
    """Yield the FastAPI app with its DB-session dependency pointed at the test session.

    ``fastapi_app`` is a shared module-level object. The override is therefore
    removed on teardown, so it cannot leak into later tests, where it would hand
    out a session that has already been closed.
    """
    fastapi_app.dependency_overrides[_get_db_session] = override_db_session
    try:
        yield fastapi_app
    finally:
        fastapi_app.dependency_overrides.pop(_get_db_session, None)


@pytest.fixture()
async def http_client(app: FastAPI) -> AsyncGenerator[AsyncClient]:
    """Yield an HTTPX client that sends requests to ``app`` in-process.

    Requests are routed through ``ASGITransport`` instead of the network. The
    older ``AsyncClient(app=...)`` shortcut was deprecated and then removed in
    httpx 0.28.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(
        transport=transport,
        base_url="http://testserver",
        follow_redirects=True,
    ) as ac:
        yield ac