FastAPI database dependency setup for connection pooling

Consider the following FastAPI setup:

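from typing import Callable

from fastapi import FastAPI
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# `application`, `settings` and `AppSettings` are defined elsewhere in the project.

def create_start_app_handler(
    app: FastAPI,
    settings: AppSettings,
) -> Callable:
    async def start_app() -> None:
        await connect_to_db(app, settings)
    return start_app

async def connect_to_db(app: FastAPI, settings: AppSettings) -> None:
    db_url = settings.DATABASE_URL
    engine = create_engine(db_url, pool_size=settings.POOL_SIZE, max_overflow=settings.MAX_OVERFLOW)

    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = SessionLocal()  # a single Session, created once at startup

    def close_db():
        db.close()
        engine.dispose()

    app.state.db = db
    app.state.close_db = close_db

# Registered after the handler factory is defined, so the name exists at call time.
application.add_event_handler(
    "startup",
    create_start_app_handler(application, settings),
)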

close_db is used to close the database connection on app shutdown. I have the following dependencies defined:

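from typing import Callable, Generator, Type

from fastapi import Depends, Request
from sqlalchemy.orm import Session

# BaseRepository is the project's own repository base class.

def _get_db(request: Request) -> Generator:
    # Yields the same Session object (stored on app.state at startup) to every request.
    yield request.app.state.db

def get_repository(
    repo_type: Type[BaseRepository],
) -> Callable[[Session], BaseRepository]:
    def _get_repo(
        sess: Session = Depends(_get_db),
    ) -> BaseRepository:
        return repo_type(sess)

    return _get_repo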

Would this still allow me to take advantage of connection pooling?

Also, this feels a little hacky and I could use some feedback if there's anything in particular that I should not be doing.

To be blunt: it seems overly complicated for something that is pretty well documented in the docs.

In your case, you create only one instance of SessionLocal() and share it across all your requests (because you store it in app.state). In other words: no, this will not take advantage of connection pooling; it will use only one connection. On top of that, a SQLAlchemy Session is not meant to be shared between concurrent requests.
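To make the difference concrete, here is a minimal sketch that counts how many pooled connections are checked out with one shared Session versus one Session per request. It assumes SQLAlchemy 1.4 or 2.x and uses a throwaway SQLite file with an explicit QueuePool as a stand-in for settings.DATABASE_URL; the two helper functions are hypothetical names for illustration.

```python
import os
import tempfile

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool

# Assumption: a temporary SQLite file stands in for settings.DATABASE_URL.
# QueuePool is set explicitly so the pool behaves like a server database's pool.
db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
engine = create_engine(
    f"sqlite:///{db_path}", poolclass=QueuePool, pool_size=5, max_overflow=0
)
SessionLocal = sessionmaker(autoflush=False, bind=engine)


def checked_out_with_shared_session(n_requests: int) -> int:
    # The question's pattern: one Session created up front and reused by every request.
    shared = SessionLocal()
    for _ in range(n_requests):
        shared.execute(text("SELECT 1"))
    in_use = engine.pool.checkedout()
    shared.close()
    return in_use


def checked_out_with_session_per_request(n_requests: int) -> int:
    # Simulates n requests in flight at once, each with its own Session.
    sessions = [SessionLocal() for _ in range(n_requests)]
    for s in sessions:
        s.execute(text("SELECT 1"))
    in_use = engine.pool.checkedout()
    for s in sessions:
        s.close()  # each close() hands its connection back to the pool
    return in_use


print(checked_out_with_shared_session(3))       # 1
print(checked_out_with_session_per_request(3))  # 3
print(engine.pool.checkedout())                 # 0 once every session is closed
```

A shared Session holds at most one connection no matter how many requests use it, whereas per-request Sessions can each borrow a connection from the pool and return it on close().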

A better approach is to create a new session per request and yield it, either via middleware or via a dependency. That way, the session is closed, and its connection returned to the pool, once the incoming request has been fully handled. For example:

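from fastapi import Depends, FastAPI
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

# Created once: the engine owns the connection pool, SessionLocal is only a factory.
engine = create_engine(DATABASE_URL)  # DATABASE_URL: your connection string
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

app = FastAPI()

def get_db():
    db = SessionLocal()  # a fresh Session for each request
    try:
        yield db
    finally:
        db.close()  # returns the connection to the pool

@app.get("/")
def root(db: Session = Depends(get_db)):
    return "hello world"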
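The lifecycle of get_db can also be sketched outside FastAPI. FastAPI runs a yield dependency up to its yield before the endpoint and finishes it after the request has been handled; the sketch below approximates that with contextlib.contextmanager (a simplification, not FastAPI's actual internals) and reads the pool's checked-out count during and after a simulated request. The SQLite file and the handle_request helper are assumptions for illustration.

```python
import os
import tempfile
from contextlib import contextmanager
from typing import Iterator

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool

# Assumption: a temporary SQLite file stands in for the real database URL.
db_path = os.path.join(tempfile.mkdtemp(), "app.db")
engine = create_engine(
    f"sqlite:///{db_path}", poolclass=QueuePool, pool_size=5, max_overflow=10
)
SessionLocal = sessionmaker(autoflush=False, bind=engine)


def get_db() -> Iterator[Session]:
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def handle_request() -> tuple:
    # Hypothetical stand-in for one request: enter the dependency before the
    # "endpoint" body runs, exit it once the request is done.
    with contextmanager(get_db)() as db:
        db.execute(text("SELECT 1"))
        during = engine.pool.checkedout()
    after = engine.pool.checkedout()
    return during, after


print(handle_request())  # (1, 0)
```

The connection is borrowed only for the duration of the request and is back in the pool afterwards, which is what lets concurrent requests share a bounded set of connections.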

I'm not sure how you ended up with this setup, but I would recommend refactoring it quite a bit.
