Fastapi: [QUESTION] Elegant way to exclude middleware in unit tests?

Created on 1 Oct 2019 · 23 Comments · Source: tiangolo/fastapi

Description

Is it possible to disable or override middleware in unit tests?

Additional context

I have written some middleware to obtain a database pool connection and release it when the request is done. I'm using asyncpg directly for this as follows:

@app.middleware("http")
async def database_middleware(request: Request, call_next):
    """Attach a pooled database connection to ``request.state.db`` for one request.

    A connection is acquired from ``database_pool`` and probed with ``SELECT 1``.
    If the probe raises ``ConnectionDoesNotExistError``, the stale connection is
    released back to the pool and a fresh one is acquired. Whatever connection is
    held when the request finishes is released in the ``finally`` clause.

    Connection failures (``PostgresConnectionError``/``OSError``) and query errors
    (``SyntaxOrAccessError``) are logged and turned into plain-text 500 responses.
    """
    conn = None
    try:
        conn = await database_pool.acquire()

        try:
            await conn.fetch("SELECT 1")
        except ConnectionDoesNotExistError:
            # Give the dead connection back so its pool slot is not lost for good.
            # ``conn`` is cleared first so that, if re-acquiring fails, the
            # ``finally`` clause does not release the same connection twice.
            stale_conn, conn = conn, None
            await database_pool.release(stale_conn)
            conn = await database_pool.acquire()

        request.state.db = conn
        return await call_next(request)
    except (PostgresConnectionError, OSError) as e:
        logger.error("Unable to connect to the database: %s", e)
        return Response(
            "Unable to connect to the database.", status_code=HTTP_500_INTERNAL_SERVER_ERROR
        )
    except SyntaxOrAccessError as e:
        logger.error("Unable to execute query: %s", e)
        return Response(
            "Unable to execute the required query to obtain data from the database.",
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
        )
    finally:
        if conn is not None:
            await database_pool.release(conn)

I then have a little get_db dependency injection which is similar to the docs:

def get_db(request: Request):
    """Return the connection that the database middleware stored on ``request.state``."""
    state = request.state
    return state.db

I'd like to mock my database in my unit tests so I can verify that I'm building the correct SQL given the appropriate params.

I've setup a pytest fixture to mock the database as follows:

@pytest.fixture(scope="function")
def db_mock(request):
    """Replace the ``get_db`` dependency with a ``MagicMock`` for a single test.

    The override is removed again by a finalizer registered on ``request``.
    """
    fake_db = mock.MagicMock()
    app.dependency_overrides[get_db] = lambda: fake_db
    request.addfinalizer(lambda: app.dependency_overrides.pop(get_db))
    return fake_db

This works perfectly when the middleware is disabled. However, when it's enabled, naturally all endpoints fail with a 500 error as the database fails to connect.

After looking a little deeper into starlette codebase and the way middleware works, I couldn't seem to find an elegant way to disable or override middleware in my unit tests.

Any help is greatly appreciated. Also if you see anything wrong with the approach I'm taking, please do let me know.

Huge thanks!
Fotis

question

Most helpful comment

In case anyone is interested in using my implementation, here are the related unit tests:

tests/utils.py

from unittest.mock import MagicMock


class AsyncMagicMock(MagicMock):
    """A ``MagicMock`` whose calls return coroutines that must be awaited.

    The call is only recorded once the returned coroutine is awaited; awaiting it
    yields ``return_value`` (or raises ``side_effect``) exactly like ``MagicMock``.
    """

    async def __call__(self, *call_args, **call_kwargs):  # pylint: disable=useless-super-delegation
        # Defer to MagicMock inside the coroutine so call recording, return_value
        # and side_effect all keep their usual semantics.
        return super().__call__(*call_args, **call_kwargs)

tests/conftest.py

import pytest
from asyncpg.pool import PoolConnectionProxy
from fastapi import Depends, FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from myapp import database
from myapp.database import get_db


@pytest.fixture
def app():
    """Build a minimal FastAPI app wired to the real ``myapp.database`` helpers.

    The pool lifecycle handlers and the connection-releasing middleware are
    registered, and a single ``GET /`` route depends on ``get_db`` so a test can
    drive the whole acquire / query / release path against a mocked pool.
    """
    app = FastAPI()
    # Create the pool on startup and close it on shutdown.
    app.add_event_handler("startup", database.create_pool)
    app.add_event_handler("shutdown", database.close_pool)
    # Releases whatever connection get_db stored on request.state after each request.
    app.add_middleware(BaseHTTPMiddleware, dispatch=database.middleware)

    # Route used by the tests; runs one query through the injected connection.
    # (A # comment rather than a docstring: FastAPI would publish a docstring in OpenAPI.)
    @app.get("/")
    async def root(db: PoolConnectionProxy = Depends(get_db)):
        return await db.fetch("SELECT * FROM data")

    return app

tests/test_database.py

from unittest import mock

from asyncpg import ConnectionDoesNotExistError, SyntaxOrAccessError
from starlette.testclient import TestClient

from .utils import AsyncMagicMock


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_successful_query(create_pool_mock, app):
    """Happy path: one connection is acquired, probed, queried and released."""
    # Connection mock handed out by the pool that the patched create_pool returns.
    db_mock = create_pool_mock.return_value.acquire.return_value
    db_mock.fetch.return_value = []

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 200

    # NOTE(review): these connection settings presumably come from the test
    # configuration returned by get_settings() — confirm.
    assert create_pool_mock.call_args == mock.call(
        host="localhost",
        port=5432,
        user="user",
        password="secret123",
        database="myapp",
        min_size=0,
        max_size=20,
        max_inactive_connection_lifetime=0,
    )
    assert create_pool_mock.return_value.acquire.call_count == 1
    # Liveness probe issued by get_db before handing out the connection.
    assert db_mock.execute.call_args == mock.call("SELECT 1")
    assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
    assert create_pool_mock.return_value.release.call_count == 1


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_reestablish_connection(create_pool_mock, app):
    """A stale pooled connection (probe fails) triggers a second acquire."""
    db_mock = create_pool_mock.return_value.acquire.return_value
    # Make the SELECT 1 probe report that the connection no longer exists.
    db_mock.execute.side_effect = ConnectionDoesNotExistError
    db_mock.fetch.return_value = []

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 200

    assert create_pool_mock.called
    assert create_pool_mock.return_value.acquire.call_count == 2
    assert db_mock.execute.call_args == mock.call("SELECT 1")
    assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
    # NOTE(review): two acquires but only one release means the stale connection
    # is never handed back to the pool; confirm that is really the intended contract.
    assert create_pool_mock.return_value.release.call_count == 1


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_failed_connection(create_pool_mock, app):
    """A refused connection yields a JSON 500 and nothing is released."""
    # acquire() fails outright, so no connection exists to query or release.
    create_pool_mock.return_value.acquire.side_effect = ConnectionRefusedError

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 500
        assert response.headers["content-type"] == "application/json"
        assert response.json() == {"detail": "Unable to connect to the database."}

    assert create_pool_mock.called
    assert create_pool_mock.return_value.acquire.call_count == 1
    assert not create_pool_mock.return_value.acquire.return_value.fetch.called
    assert not create_pool_mock.return_value.release.called


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_failed_query(create_pool_mock, app):
    """A query error yields a JSON 500 but the connection is still released."""
    db_mock = create_pool_mock.return_value.acquire.return_value
    # The endpoint's query fails; the middleware turns this into a 500 response.
    db_mock.fetch.side_effect = SyntaxOrAccessError

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 500
        assert response.headers["content-type"] == "application/json"
        assert response.json() == {
            "detail": "Unable to execute the required query to obtain data from the database."
        }

    assert create_pool_mock.called
    assert create_pool_mock.return_value.acquire.call_count == 1
    assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
    assert create_pool_mock.return_value.release.call_count == 1

All 23 comments

I'd try to use asynctest CoroutineMock instead of MagicMock
edit: ha, sorry, your mock is on a def

I'd try to use asynctest CoroutineMock instead of MagicMock
edit: ha, sorry, your mock is on a def

Oh interesting idea, but yeah, pytest injection doesn't use async at this stage. I'm injecting this fixture as follows:

def test_events(client, db_mock, current_user_mock):
    """POST /v1/events/ should look up events by id and return the (empty) result."""
    # The endpoint awaits db.fetch(), so hand it an already-resolved future.
    db_mock.fetch.return_value = Future()
    db_mock.fetch.return_value.set_result([])

    response = client.post("/v1/events/", json={"id": ["bla"]})
    assert response.status_code == 200
    assert response.json() == []
    # Adjacent string literals are concatenated with no separator, so every
    # fragment except the last needs its own trailing space.
    assert db_mock.fetch.call_args == mock.call(
        "SELECT * "
        "FROM events "
        "WHERE id = ANY ($1) "
        "ORDER BY event_timestamp DESC "
        "LIMIT $2",
        ["bla"], 20
    )

Of course, I'm totally open to better ways of accomplishing this! 😄 I'm definitely a bit new to async in Python and am coming from Falcon (which is not async at all).

Actually @euri10, you're right, I can use the asynctest library (which I hadn't heard of) for slightly more elegant results:

@pytest.fixture(scope="function")
def db_mock(request):
    """Replace the ``get_db`` dependency with an ``asynctest.Mock`` for one test.

    The override is removed again by a finalizer registered on ``request``.
    """
    fake_db = asynctest.Mock()
    app.dependency_overrides[get_db] = lambda: fake_db
    request.addfinalizer(lambda: app.dependency_overrides.pop(get_db))
    return fake_db

def test_events(client, db_mock, current_user_mock):
    """POST /v1/events/ should look up events by id and return the (empty) result."""
    # CoroutineMock makes the endpoint's ``await db.fetch(...)`` resolve to [].
    db_mock.fetch = asynctest.CoroutineMock(return_value=[])

    response = client.post("/v1/events/", json={"id": ["bla"]})
    assert response.status_code == 200
    assert response.json() == []
    # Verify the exact SQL and bind parameters the endpoint built.
    assert db_mock.fetch.call_args == mock.call(
        "SELECT * "
        "FROM events "
        "WHERE id = ANY ($1) "
        "ORDER BY event_timestamp DESC "
        "LIMIT $2",
        ["bla"], 20
    )

The reason this works is that the endpoint itself uses await, so it's all good 😄

You can even use pytest-asyncio and have:

@pytest.mark.asyncio
async def test_events(client, db_mock, current_user_mock):

So far, the only idea I can come up with is writing a create_app function which selectively adds everything:

def create_app(environment: "str | None" = None):
    """Build and return the FastAPI application.

    Args:
        environment: Deployment environment name. Defaults to the ``ENVIRONMENT``
            environment variable, read when the function is called (not when the
            module is imported) so tests can set it before creating the app.
            When it is ``"testing"``, the database startup/shutdown handlers and
            middleware are not registered.

    Returns:
        The configured ``FastAPI`` instance.
    """
    if environment is None:
        environment = os.getenv("ENVIRONMENT")

    app = FastAPI()

    # Don't register events and middleware relating to the database while testing.
    # Event handlers don't seem to run during testing anyway but we'll do this to be sure.
    if environment != "testing":
        app.add_event_handler("startup", startup)
        app.add_event_handler("shutdown", shutdown)
        # BaseHTTPMiddleware receives the coroutine through its ``dispatch`` keyword.
        app.add_middleware(BaseHTTPMiddleware, dispatch=database_middleware)

    api_router = APIRouter()
    api_router.include_router(events.router, prefix="/events", tags=["events"])
    # ...

    app.include_router(api_router, prefix="/v1")
    return app

More suggestions or ideas are welcome, though 😄

@fgimian that’s exactly what I do

Just to understand your setup a little better, and eventually come up with a "better" mock, @fgimian: where is your asyncpg.Pool declared (database_pool in your code, if I'm correct)? Could you share a minimal example of where you're at the moment?

If I understand your post correctly, the goal is to make sure a db_mock receives the same SQL query that the endpoint, once triggered, is supposed to generate, without having to go through the db middleware. Is that correct?

I think it's a very interesting topic. I'm not using asyncpg directly; I handle the DB connection in startup and shutdown lifespan events using the databases package, and I'd be very interested in benchmarking the two approaches with wrk.

I'll definitely try out the databases package! I honestly thought it was a local import when I saw the code and clearly didn't read that part of the guide well enough.

My current approach is not perfect, but goes something like this:

# main.py
logger = logging.getLogger(__name__)
# Assigned by the startup handler; stays None until the pool has been created.
database_pool = None
app = FastAPI()


@app.on_event("startup")
async def startup():
    """Create the global asyncpg pool from the ``config`` settings.

    No connection is opened here (``min_size=0``); connections are established on
    demand and are not closed for inactivity (``max_inactive_connection_lifetime=0``).
    """
    # TODO: Is there a better way to create the pool without using a global?
    global database_pool  # pylint: disable=global-statement
    pool_options = dict(
        host=config.DATABASE_HOST,
        port=config.DATABASE_PORT,
        user=config.DATABASE_USERNAME,
        password=config.DATABASE_PASSWORD,
        database=config.DATABASE_DBNAME,
        min_size=0,  # don't create any connections upon start-up
        max_size=config.DATABASE_MAX_CONNECTIONS,
        max_inactive_connection_lifetime=0,  # never close connections after they're established
    )
    database_pool = await asyncpg.create_pool(**pool_options)


@app.on_event("shutdown")
async def shutdown():
    """Close the global connection pool created by ``startup``, if there is one."""
    # database_pool is still None when startup never assigned it (e.g. it was not
    # run or failed before the assignment); skip instead of raising AttributeError.
    if database_pool is not None:
        await database_pool.close()


@app.middleware("http")
async def database_middleware(request: Request, call_next):
    """Attach a pooled database connection to ``request.state.db`` for one request.

    Requests whose User-Agent is ``testclient`` bypass this middleware entirely.
    Otherwise a connection is acquired and probed with ``SELECT 1``; a stale one
    (``ConnectionDoesNotExistError``) is released back to the pool and replaced.
    Whatever connection is held at the end of the request is released.

    Connection failures and query errors are logged and turned into plain-text
    500 responses.
    """
    # TODO: Improve the mechanism used to determine whether we are in a unit test or not
    # NOTE(review): any client can send "User-Agent: testclient" and skip this
    # middleware, in which case it never sets request.state.db. Not registering the
    # middleware in tests (e.g. via an app factory) would be safer than this check.
    # ``.get`` avoids a KeyError (and an unhandled error) when the header is absent.
    if request.headers.get("user-agent") == "testclient":
        return await call_next(request)

    conn = None
    try:
        conn = await database_pool.acquire()

        try:
            await conn.fetch("SELECT 1")
        except ConnectionDoesNotExistError:
            # Give the dead connection back so its pool slot is not lost for good.
            # ``conn`` is cleared first so that, if re-acquiring fails, the
            # ``finally`` clause does not release the same connection twice.
            stale_conn, conn = conn, None
            await database_pool.release(stale_conn)
            conn = await database_pool.acquire()

        request.state.db = conn
        return await call_next(request)
    except (PostgresConnectionError, OSError) as e:
        logger.error("Unable to connect to the database: %s", e)
        return Response(
            "Unable to connect to the database.", status_code=HTTP_500_INTERNAL_SERVER_ERROR
        )
    except SyntaxOrAccessError as e:
        logger.error("Unable to execute query: %s", e)
        return Response(
            "Unable to execute the required query to obtain data from the database.",
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
        )
    finally:
        if conn is not None:
            await database_pool.release(conn)


# database.py
def get_db(request: Request):
    """Return the connection that the database middleware stored on ``request.state``."""
    state = request.state
    return state.db

Hope this helps. As per the TODOs, there are still some things to improve if possible.

The important thing with the database is that I don't want a connection to be made when the app starts up because that would mean the app could fail to start if the database was down at that moment. My preference is that the initial connections are established by the first few web requests and then are kept open forever. I'm typically using a pool size of around 20 per process.

Interestingly, it seems that uvicorn doesn't allow starting an app from a create_app type function. I'll test out doing this via Gunicorn today to see if that allows it.

Cheers
Fotis

@fgimian There are many ways to make sure your app waits until the database/other necessary resources are ready. Some examples:

  • Use a script like wait-for-it to wait to launch your server until the database is up (you could check using your "SELECT 1" test)
  • Wrap the "SELECT 1" call using tenacity's retry function (which will cause it to be auto-retried on failure and return once it passes) and put a call to it in your startup event.


Click to expand (adapts your startup function to use tenacity)

import logging

from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
wait_seconds = 3
# Keep retrying for roughly five minutes in total, whatever the wait interval is
# (the previous 60 * 5 attempts at 3 s apart amounted to about 15 minutes).
max_tries = 5 * 60 // wait_seconds


@retry(
    stop=stop_after_attempt(max_tries),
    wait=wait_fixed(wait_seconds),
    before=before_log(logger, logging.INFO),
    after=after_log(logger, logging.WARN),
)
async def wait_for_db(database_pool) -> None:
    """Succeed once the database answers a trivial query.

    Acquires a connection, runs ``SELECT 1`` and always releases the connection.
    Any exception propagates, so the ``retry`` decorator can schedule another
    attempt until its stop condition is reached.
    """
    conn = await database_pool.acquire()
    try:
        await conn.fetch("SELECT 1")
    finally:
        # Without this, every attempt (including each retry) would leave one pool
        # connection checked out and never returned.
        await database_pool.release(conn)


@app.on_event("startup")
async def startup():
    """Create the global asyncpg pool, then wait until the database answers.

    The pool itself opens no connection (``min_size=0``); ``wait_for_db`` then
    retries ``SELECT 1`` so the application only finishes starting once the
    database is reachable.
    """
    # TODO: Is there a better way to create the pool without using a global?
    global database_pool  # pylint: disable=global-statement
    pool_options = dict(
        host=config.DATABASE_HOST,
        port=config.DATABASE_PORT,
        user=config.DATABASE_USERNAME,
        password=config.DATABASE_PASSWORD,
        database=config.DATABASE_DBNAME,
        min_size=0,  # don't create any connections upon start-up
        max_size=config.DATABASE_MAX_CONNECTIONS,
        max_inactive_connection_lifetime=0,  # never close connections after they're established
    )
    database_pool = await asyncpg.create_pool(**pool_options)
    await wait_for_db(database_pool)

This way you wouldn't need to waste a db round trip at the start of every request.

Also, +1 for databases.

@dmontagu Thank you so much, this is an awesome idea. The use of SELECT 1 upon each request is actually still useful and is an approach that SQLAlchemy uses too. See Disconnect Handling - Pessimistic for further information.

The reason it is needed is that database connections in a pool can go stale after initially established (due to network interruptions or db restarts). So you can't be sure that acquire will actually get you a working connection if it is returning a pre-established connection in the pool.

We had this issue in production too and the idea was to add a small SELECT 1 and accept a small amount of overhead to ensure that we never get a stale connection back.

@fgimian Those are some good points; thanks for sharing the sqlalchemy docs page, I wasn't aware of that and it was good reading.

In light of this consideration, I think I personally would be inclined to make use of an approach that just periodically polls the connection pool for stale connections, rather than going full "pessimistic", but I guess for most applications the overhead would probably be insignificant.

I'd like to share my solution to this problem. Firstly, I was never fond of using middleware for opening database connections. This implies that every single endpoint will open a database connection whether it is needed or not (e.g. in our case, the auth endpoint doesn't use the database as it talks to LDAP).

I looked into the databases library which does look great, but offers no advantages I really need over asyncpg. So the following solution still uses asyncpg.

The solution goes something like this:

  • Inject a dependency when the database is needed in an endpoint
  • This injected function will obtain a connection from the database pool and hand it to the endpoint while also saving the connection in request.state
  • A small middleware function will check whether or not a connection has been set in request.state after an endpoint completes processing and will close the connection if needed

So here's my database.py:

import logging

import asyncpg
from asyncpg import ConnectionDoesNotExistError, PostgresConnectionError, SyntaxOrAccessError
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

from .settings import get_settings


logger = logging.getLogger(__name__)
# Module-level asyncpg pool; create_pool() assigns it, and it is None until then.
pool = None


async def create_pool() -> None:
    """Create the module-level asyncpg pool from the application settings.

    No connection is opened here (``min_size=0``); connections are established on
    demand and are not closed for inactivity (``max_inactive_connection_lifetime=0``).
    """
    global pool  # pylint: disable=global-statement
    db_settings = get_settings().database
    pool = await asyncpg.create_pool(
        host=db_settings.host,
        port=db_settings.port,
        user=db_settings.username,
        password=db_settings.password,
        database=db_settings.db_name,
        min_size=0,  # don't create any connections upon start-up
        max_size=db_settings.max_connections,
        max_inactive_connection_lifetime=0,  # never close connections after they're established
    )


async def close_pool() -> None:
    """Close the module-level pool, if ``create_pool`` has created one."""
    # ``pool`` stays None when create_pool never ran (or failed before assigning
    # it); skip instead of raising AttributeError on shutdown.
    if pool is not None:
        await pool.close()


async def get_db(request: Request) -> asyncpg.connection.Connection:
    """Obtain a database connection from the pool for the current request.

    The connection is checked with a trivial ``SELECT 1``; if it turns out to be
    stale (``ConnectionDoesNotExistError``), it is released back to the pool and a
    fresh one is acquired. The returned connection is stored on
    ``request.state.db`` so that ``middleware`` can release it after the response.

    Raises:
        HTTPException: 500 if a connection could not be acquired or verified.
    """
    conn = None
    try:
        conn = await pool.acquire()

        # Test that the connection is still active by running a trivial query
        # (https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic)
        try:
            await conn.execute("SELECT 1")
        except ConnectionDoesNotExistError:
            # Return the dead connection so its pool slot is not lost for good;
            # clear ``conn`` first so a failed re-acquire is not released below.
            stale_conn, conn = conn, None
            await pool.release(stale_conn)
            conn = await pool.acquire()

        request.state.db = conn
        return conn
    except (PostgresConnectionError, OSError) as e:
        # request.state.db has not been set yet, so the middleware would never
        # release a connection acquired above; do it here instead of leaking it.
        if conn is not None:
            await pool.release(conn)
        logger.error("Unable to connect to the database: %s", e)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail="Unable to connect to the database."
        ) from e


async def middleware(request: Request, call_next):
    """Run the request, then release any connection ``get_db`` attached to it.

    A ``SyntaxOrAccessError`` escaping the endpoint is logged and converted into
    a JSON 500 response; the connection is released in every case.
    """
    try:
        response = await call_next(request)
    except SyntaxOrAccessError as e:
        logger.error("Unable to execute query: %s", e)
        response = JSONResponse(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "detail": "Unable to execute the required query to obtain data from the database."
            },
        )
    finally:
        # request.state.db only exists when get_db actually handed out a connection.
        if hasattr(request.state, "db"):
            await pool.release(request.state.db)
    return response

I'm still not totally fond of using a global, but it seems to be the only way at present. I attempted to wrap this in a class and call an async method when injecting the dependency, but FastAPI didn't seem to know how to deal with that (it didn't execute the injection as a coroutine when it was a class method).

The great part here is that now, a connection is not established while running unit tests as I have a mock version of get_db:

@pytest.fixture
def client():
    """Provide a ``TestClient`` for ``app``, without entering it as a context manager."""
    # NOTE(review): presumably not entering the client keeps the startup/shutdown
    # pool handlers from running in unit tests — confirm.
    test_client = TestClient(app)
    return test_client


@pytest.fixture(scope="function")
def db_mock(request):
    """Replace the ``get_db`` dependency with an ``asynctest.Mock`` for one test.

    Because the override replaces ``get_db`` entirely, no pool connection is
    acquired; the override is removed by a finalizer registered on ``request``.
    """
    fake_db = asynctest.Mock()
    app.dependency_overrides[get_db] = lambda: fake_db
    request.addfinalizer(lambda: app.dependency_overrides.pop(get_db))
    return fake_db

I will now attempt to write unit tests for my database module. 😄

Of course, I'm totally open to critique on my solution and welcome more ideas.

Cheers
Fotis

I played a little bit with it and I like it: no need for a DB anymore!
The only thing I find that could become tedious down the road (though
maybe it's just the way I wrote it) is that you have to write something
like mocked_db.fetch = CoroutineMock(return_value=[]) (or execute, etc.,
depending on the method(s) used in the endpoint) in each and every test.

https://gitlab.com/euri10/rapidfastapitest/blob/master/582_mock_middleware.py#L105

So there might be a nicer fixture that provides those; I couldn't find a
nice way yet.

Le sam. 5 oct. 2019 à 8:56 AM, Fotis Gimian notifications@github.com a
écrit :

I'd like to share my solution to this problem. Firstly, I was never fond
of using middleware for opening database connections. This implies that
every single endpoint will open a database connection whether it is needed
or not (e.g. in our case, the auth endpoint doesn't use the database as it
talks to LDAP).

I looked into the databases library which does look great, but offers no
advantages I really need over asyncpg. So the following solution still
uses asyncpg.

The solution goes something like this:

  • Inject a dependency when the database is needed in an endpoint
  • This injected function will obtain a connection from the database
    pool and hand it to the endpoint while also saving the connection in
    request.state
  • A small middleware function will check whether or not a connection
    has been set in request.state after an endpoint completes processing
    and will close the connection if needed

So here's my database.py:

import logging

import asyncpg
from asyncpg import ConnectionDoesNotExistError, PostgresConnectionError, SyntaxOrAccessError
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

from . import config

pool = None

logger = logging.getLogger(__name__)

async def create_pool() -> None:

global pool

pool = await asyncpg.create_pool(

    host=config.DATABASE_HOST,

    port=config.DATABASE_PORT,

    user=config.DATABASE_USERNAME,

    password=config.DATABASE_PASSWORD,

    database=config.DATABASE_DBNAME,

    min_size=0,  # don't create any connections upon start-up

    max_size=config.DATABASE_MAX_CONNECTIONS,

    max_inactive_connection_lifetime=0,  # never close connections after they're established

)

async def close_pool() -> None:

await pool.close()

async def get_db(request: Request) -> asyncpg.connection.Connection:

"""Obtain a database connection from the pool."""

try:

    conn = await pool.acquire()



    # Test that the connection is still active by running a trivial query

    # (https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic)

    try:

        await conn.fetch("SELECT 1")

    except ConnectionDoesNotExistError:

        conn = await pool.acquire()



    request.state.db = conn

    return conn

except (PostgresConnectionError, OSError) as e:

    logger.error("Unable to connect to the database: %s", e)

    raise HTTPException(

        status_code=HTTP_500_INTERNAL_SERVER_ERROR,

        detail="Unable to connect to the database.",

    )

except SyntaxOrAccessError as e:

    logger.error("Unable to execute query: %s", e)

    raise HTTPException(

        status_code=HTTP_500_INTERNAL_SERVER_ERROR,

        detail="Unable to execute the required query to obtain data from the database.",

    )

async def middleware(request: Request, call_next):

"""Ensures that any open database connection is closed after each request."""

try:

    return await call_next(request)

finally:

    if hasattr(request.state, "db"):

        await pool.release(request.state.db)

I'm still not totally fond of using a global, but it seems to be the only
way at present. I attempted to wrap this in a class and call an async
method when injecting the dependency, but FastAPI didn't seem to know how
to deal with that (it didn't execute the injection as a coroutine when it
was a class method).

The great part here is that now, a connection is not established while
running unit tests as I have a mock version of get_db:

@pytest.fixture
def client():

return TestClient(app)

@pytest.fixture(scope="function")
def db_mock(request):

def fin():

    del app.dependency_overrides[get_db]



db = asynctest.Mock()

app.dependency_overrides[get_db] = lambda: db



request.addfinalizer(fin)

return db

I will now attempt to write unit tests for my database module. 😄

Of course, I'm totally open to critique on my solution and welcome more
ideas.

Cheers
Fotis

—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
https://github.com/tiangolo/fastapi/issues/582?email_source=notifications&email_token=AAINSPS3AMH6NODBFZTNNYTQNA3DJA5CNFSM4I4IN2NKYY3PNVWWK3TUL52HS4DFVREXG43VMVBW63LNMVXHJKTDN5WW2ZLOORPWSZGOEANL6WA#issuecomment-538623832,
or mute the thread
https://github.com/notifications/unsubscribe-auth/AAINSPTTBO4GBPI67LIHMM3QNA3DJANCNFSM4I4IN2NA
.

Hey @euri10, yep that's exactly what I'm doing.

def test_bla(client, db_mock, current_user_mock):
    db_mock.fetch = asynctest.CoroutineMock(
        return_value=[
            Record([("abc", "def")]),
            ...
        ]

The only little caveat was that asyncpg's Record class can't be instantiated. So I had to create a replica myself which looks like this:

from collections import OrderedDict


class Record(OrderedDict):
    """
    Test stand-in for the Record class returned by asyncpg, which cannot currently be
    created from Python code.

    Columns can be read by name (``record["id"]``), by position (``record[0]``,
    negative positions included) or by slice (``record[1:]``, returning a tuple of
    the selected values), and the repr mimics ``<Record col=value ...>``.
    """

    def __getitem__(self, key_or_index):
        # Integers and slices address columns positionally, like a tuple.
        # NOTE: bool is a subclass of int, so record[True] is record[1].
        if isinstance(key_or_index, (int, slice)):
            return tuple(self.values())[key_or_index]

        return super().__getitem__(key_or_index)

    def __repr__(self):
        items = " ".join(f"{k}={v!r}" for k, v in self.items())
        return f"<{type(self).__name__} {items}>"

This has worked perfectly for all my unit tests, I now have 100% coverage on all my endpoints. The challenge is writing unit tests for the database.py file.

I figured you can declare execute, fetch, fetchrow, etc. as mock coroutines
in the db mock fixture, but the expected record is still missing;
there might be a way.
I'll play a little more with the record mock you wrote, thanks!

Le sam. 5 oct. 2019 à 1:33 PM, Fotis Gimian notifications@github.com a
écrit :

Hey @euri10 https://github.com/euri10, yep that's exactly what I'm
doing.

def test_bla(client, db_mock, current_user_mock):
db_mock.fetch = asynctest.CoroutineMock(
return_value=[
Record([("abc", "def")]),
...
]

The only little caveat was that asyncpg's Record class can't be
instantiated. So I had to create a replica myself which looks like this:

from collections import OrderedDict

class Record(OrderedDict):
""" Provides a very similar Record class to that returned by asyncpg. A custom implementation is needed as it is currently impossible to create asyncpg Record objects from Python code. """

def __getitem__(self, key_or_index):
    if isinstance(key_or_index, int):
        return list(self.values())[key_or_index]

    return super().__getitem__(key_or_index)

def __repr__(self):
    return "<{class_name} {items}>".format(
        class_name=self.__class__.__name__,
        items=" ".join(f"{k}={v!r}" for k, v in self.items()),
    )

This has worked perfectly for all my unit tests, I now have 100% coverage
on all my endpoints. The challenge is writing unit tests for the
database.py file.

—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
https://github.com/tiangolo/fastapi/issues/582?email_source=notifications&email_token=AAINSPX7UNMY3LSWRBNWBELQNB3P5A5CNFSM4I4IN2NKYY3PNVWWK3TUL52HS4DFVREXG43VMVBW63LNMVXHJKTDN5WW2ZLOORPWSZGOEANQLTI#issuecomment-538641869,
or mute the thread
https://github.com/notifications/unsubscribe-auth/AAINSPXFPIEXVFJ7PP7XXVDQNB3P5ANCNFSM4I4IN2NA
.

No worries at all @euri10, I really appreciate this discussion and how helpful you and @dmontagu have been here. It's so nice to see such a friendly community around FastAPI! 😄

I'll likely write a blog post or two after I finish this API implementation with various learnings to share.

Cheers
Fotis

In case anyone is interested in using my implementation, here are the related unit tests:

tests/utils.py

from unittest.mock import MagicMock


class AsyncMagicMock(MagicMock):
    """A ``MagicMock`` whose calls return coroutines that must be awaited.

    The call is only recorded once the returned coroutine is awaited; awaiting it
    yields ``return_value`` (or raises ``side_effect``) exactly like ``MagicMock``.
    """

    async def __call__(self, *call_args, **call_kwargs):  # pylint: disable=useless-super-delegation
        # Defer to MagicMock inside the coroutine so call recording, return_value
        # and side_effect all keep their usual semantics.
        return super().__call__(*call_args, **call_kwargs)

tests/conftest.py

import pytest
from asyncpg.pool import PoolConnectionProxy
from fastapi import Depends, FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from myapp import database
from myapp.database import get_db


@pytest.fixture
def app():
    """Build a minimal FastAPI app wired to the real ``myapp.database`` helpers.

    The pool lifecycle handlers and the connection-releasing middleware are
    registered, and a single ``GET /`` route depends on ``get_db`` so a test can
    drive the whole acquire / query / release path against a mocked pool.
    """
    app = FastAPI()
    # Create the pool on startup and close it on shutdown.
    app.add_event_handler("startup", database.create_pool)
    app.add_event_handler("shutdown", database.close_pool)
    # Releases whatever connection get_db stored on request.state after each request.
    app.add_middleware(BaseHTTPMiddleware, dispatch=database.middleware)

    # Route used by the tests; runs one query through the injected connection.
    # (A # comment rather than a docstring: FastAPI would publish a docstring in OpenAPI.)
    @app.get("/")
    async def root(db: PoolConnectionProxy = Depends(get_db)):
        return await db.fetch("SELECT * FROM data")

    return app

tests/test_database.py

from unittest import mock

from asyncpg import ConnectionDoesNotExistError, SyntaxOrAccessError
from starlette.testclient import TestClient

from .utils import AsyncMagicMock


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_successful_query(create_pool_mock, app):
    """Happy path: one connection is acquired, probed, queried and released."""
    # Connection mock handed out by the pool that the patched create_pool returns.
    db_mock = create_pool_mock.return_value.acquire.return_value
    db_mock.fetch.return_value = []

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 200

    # NOTE(review): these connection settings presumably come from the test
    # configuration returned by get_settings() — confirm.
    assert create_pool_mock.call_args == mock.call(
        host="localhost",
        port=5432,
        user="user",
        password="secret123",
        database="myapp",
        min_size=0,
        max_size=20,
        max_inactive_connection_lifetime=0,
    )
    assert create_pool_mock.return_value.acquire.call_count == 1
    # Liveness probe issued by get_db before handing out the connection.
    assert db_mock.execute.call_args == mock.call("SELECT 1")
    assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
    assert create_pool_mock.return_value.release.call_count == 1


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_reestablish_connection(create_pool_mock, app):
    """A stale pooled connection (probe fails) triggers a second acquire."""
    db_mock = create_pool_mock.return_value.acquire.return_value
    # Make the SELECT 1 probe report that the connection no longer exists.
    db_mock.execute.side_effect = ConnectionDoesNotExistError
    db_mock.fetch.return_value = []

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 200

    assert create_pool_mock.called
    assert create_pool_mock.return_value.acquire.call_count == 2
    assert db_mock.execute.call_args == mock.call("SELECT 1")
    assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
    # NOTE(review): two acquires but only one release means the stale connection
    # is never handed back to the pool; confirm that is really the intended contract.
    assert create_pool_mock.return_value.release.call_count == 1


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_failed_connection(create_pool_mock, app):
    """A refused connection yields a JSON 500 and nothing is released."""
    # acquire() fails outright, so no connection exists to query or release.
    create_pool_mock.return_value.acquire.side_effect = ConnectionRefusedError

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 500
        assert response.headers["content-type"] == "application/json"
        assert response.json() == {"detail": "Unable to connect to the database."}

    assert create_pool_mock.called
    assert create_pool_mock.return_value.acquire.call_count == 1
    assert not create_pool_mock.return_value.acquire.return_value.fetch.called
    assert not create_pool_mock.return_value.release.called


@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_failed_query(create_pool_mock, app):
    """A query error yields a JSON 500 but the connection is still released."""
    db_mock = create_pool_mock.return_value.acquire.return_value
    # The endpoint's query fails; the middleware turns this into a 500 response.
    db_mock.fetch.side_effect = SyntaxOrAccessError

    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 500
        assert response.headers["content-type"] == "application/json"
        assert response.json() == {
            "detail": "Unable to execute the required query to obtain data from the database."
        }

    assert create_pool_mock.called
    assert create_pool_mock.return_value.acquire.call_count == 1
    assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
    assert create_pool_mock.return_value.release.call_count == 1

I'm happy to close this issue now if there are no further comments. Please let me know your preference. 😄

Kindest regards
Fotis

As there are no further follow-up comments, I'll close this ticket. Just wanted to thank everyone greatly for their input and help here! 😄

Cheers
Fotis

Just wanted to say a huge thanks to the FastAPI team for this new feature https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/ which completely solves this dilemma.

Absolutely beautiful! 😍

Thanks again
Fotis

Thanks for the help here everyone! :clap: :bow:

Thanks for reporting back and closing the issue @fgimian :+1: I'm glad you're liking FastAPI! :smile:

Absolutely beautiful! 😍

Hi, can you share the final recommended approach for global database access with yield?

With Python 3.8+ this is easily remedied like so:

from unittest.mock import patch, AsyncMock


@pytest.fixture(autouse=True)
def mocked_asyncpg():
    """Patch ``create_pool`` with an ``AsyncMock`` for every test.

    Yields ``return_value.acquire.return_value`` of the patched function, i.e.
    presumably the connection obtained by awaiting ``pool.acquire()`` — confirm
    against the code under test.
    """
    patcher = patch("circular_api.utils.database.create_pool", new_callable=AsyncMock)
    mocked_pool = patcher.start()
    try:
        yield mocked_pool.return_value.acquire.return_value
    finally:
        # Undo the patch even if the test body raised.
        patcher.stop()
Was this page helpful?
0 / 5 - 0 ratings