
How do I write SQLAlchemy test fixtures for FastAPI applications

I am writing a FastAPI application that uses a SQLAlchemy database. I have copied the example from the FastAPI documentation, simplifying the database schema for concision's sake. The complete source is at the bottom of this post.

This works. I can run it with uvicorn sql_app.main:app and interact with the database via the Swagger docs. When it runs, it creates a test.db in the working directory.

Now I want to add a unit test, something like this:

from fastapi import status
from fastapi.testclient import TestClient
from pytest import fixture

from main import app


@fixture
def client() -> TestClient:
    return TestClient(app)


def test_fast_sql(client: TestClient):
    response = client.get("/users/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == []

With the source code below, this test uses the test.db in the working directory as its database. Instead, I want each unit test to get a new database that is deleted at the end of the test.
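As an illustration of that goal, independent of FastAPI, one option is to point a fresh engine at a file inside a temporary directory and let the directory be removed on exit. This is a minimal sketch using only SQLAlchemy and the standard library; the `temporary_session` helper is a hypothetical name, and the `User` model simply mirrors the question's models.py.

```python
import os
import tempfile
from contextlib import contextmanager
from typing import Iterator

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker

Base = declarative_base()


class User(Base):
    # Mirrors the question's models.User
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String)
    age = Column(Integer)


@contextmanager
def temporary_session() -> Iterator[Session]:
    """Yield a Session on a brand-new SQLite file that is deleted on exit (hypothetical helper)."""
    with tempfile.TemporaryDirectory() as directory:
        path = os.path.join(directory, "store.db")
        engine = create_engine(
            f"sqlite:///{path}", connect_args={"check_same_thread": False}
        )
        Base.metadata.create_all(bind=engine)  # fresh, empty schema for this caller
        session = sessionmaker(autoflush=False, bind=engine)()
        try:
            yield session
        finally:
            session.close()
            engine.dispose()  # close pooled connections so the directory can be removed
```

Each `with temporary_session() as db:` block starts with an empty `users` table, and the SQLite file disappears together with the temporary directory once the block exits.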

I could put the global database.engine and database.SessionLocal inside an object that is created at runtime, like so:

from pathlib import Path

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

import models


class UserDatabase:
    def __init__(self, directory: Path):
        directory.mkdir(exist_ok=True, parents=True)
        sqlalchemy_database_url = f"sqlite:///{directory}/store.db"
        # Per-instance engine and session factory instead of module-level globals
        self.engine = create_engine(
            sqlalchemy_database_url, connect_args={"check_same_thread": False}
        )
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        models.Base.metadata.create_all(bind=self.engine)

but I don't know how to make that work with main.get_db, since the Depends(get_db) logic ultimately assumes database.engine and database.SessionLocal are available globally.

I'm used to working with Flask, whose unit-testing facilities handle all this for you. I don't know how to write it myself. Can someone show me the minimal changes I'd have to make to generate a new database for each unit test in this framework?


The complete source of the simplified FastAPI/SQLAlchemy app is as follows.

database.py

from sqlalchemy import create_engine
# declarative_base lives in sqlalchemy.orm since SQLAlchemy 1.4 (the ext.declarative import is deprecated)
from sqlalchemy.orm import declarative_base, sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()

models.py

from sqlalchemy import Column, Integer, String

from database import Base


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String)
    age = Column(Integer)

schemas.py

from pydantic import BaseModel


class UserBase(BaseModel):
    name: str
    age: int


class UserCreate(UserBase):
    pass


class User(UserBase):
    id: int

    class Config:
        orm_mode = True

crud.py

from sqlalchemy.orm import Session

import schemas
import models


def get_user(db: Session, user_id: int):
    return db.query(models.User).filter(models.User.id == user_id).first()


def get_users(db: Session, skip: int = 0, limit: int = 100):
    return db.query(models.User).offset(skip).limit(limit).all()


def create_user(db: Session, user: schemas.UserCreate):
    db_user = models.User(name=user.name, age=user.age)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

main.py

from typing import List

from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session

import schemas
import models
import crud
from database import SessionLocal, engine

models.Base.metadata.create_all(bind=engine)

app = FastAPI()


# Dependency
def get_db():
    db = SessionLocal()  # create before try, so finally never sees an unbound name
    try:
        yield db
    finally:
        db.close()


@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    return crud.create_user(db=db, user=user)


@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.get_users(db, skip=skip, limit=limit)
    return users


@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user
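get_db above is a generator dependency: the code before `yield` runs before the endpoint, the yielded session is injected into it, and the `finally` block runs once the request has been handled. For sync generators FastAPI wraps the dependency in a context manager in much the same way as `contextlib.contextmanager`, so the lifecycle can be sketched without FastAPI at all. `FakeSession` and `handle_request` are hypothetical stand-ins for a SQLAlchemy session and for FastAPI's request handling.

```python
from contextlib import contextmanager
from typing import Iterator, List

# Records the order of events during one simulated request.
events: List[str] = []


class FakeSession:
    """Hypothetical stand-in for a SQLAlchemy Session."""

    def close(self) -> None:
        events.append("close")


def get_db() -> Iterator[FakeSession]:
    db = FakeSession()  # created before `try` so `finally` can always reference it
    events.append("open")
    try:
        yield db
    finally:
        db.close()


def handle_request() -> None:
    """Roughly what happens around an endpoint that declares Depends(get_db)."""
    with contextmanager(get_db)() as db:
        assert isinstance(db, FakeSession)
        events.append("endpoint")
```

Because a dependency override can be any dependency callable, a test replacement for get_db can likewise be a plain function returning a session or a generator that cleans up after itself.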

You need to override your get_db dependency in your tests; see the FastAPI documentation on testing dependencies with overrides (app.dependency_overrides).

Something like this for your fixtures:

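from fastapi.testclient import TestClient
from pytest import fixture
from sqlalchemy.orm import Session

from main import app, get_db


@fixture
def db_fixture() -> Session:
    raise NotImplementedError()  # Make this return your temporary session


@fixture
def client(db_fixture) -> TestClient:

    def _get_db_override():
        return db_fixture

    # Every Depends(get_db) in the app now receives db_fixture instead
    app.dependency_overrides[get_db] = _get_db_override
    return TestClient(app)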

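Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.

Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM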