[英]How to use Arrow type in FastAPI response schema?
我想在 FastAPI 響應中使用 Arrow 類型,因為我已經在 SQLAlchemy 模型中使用它(感謝 sqlalchemy_utils)。
我准備了一個帶有最小 FastAPI 應用程序的小型獨立示例。我希望這個應用程序從數據庫中返回 product1 的數據。
不幸的是,下面的代碼給出了異常:
Exception has occurred: FastAPIError
Invalid args for response field! Hint: check that <class 'arrow.arrow.Arrow'> is a valid pydantic field type
import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType
# Create the FastAPI application and a SQLAlchemy engine pointing at the local SQLite file db.db.
app = FastAPI()
engine = sqlalchemy.create_engine('sqlite:///db.db')
# Declarative base class that the ORM model below inherits from.
Base = declarative_base()
class Product(Base):
    # ORM model mapped to the "product" table.
    __tablename__ = "product"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Product name; NULL is allowed.
    name = Column(Text, nullable=True)
    # Timezone-aware Arrow timestamp; the server-side default is func.now().
    created_at = Column(
        ArrowType(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
# Emit CREATE TABLE for the models registered on Base.
Base.metadata.create_all(engine)
# One module-level session bound to the engine; the endpoint reuses it.
Session = sessionmaker(bind=engine)
session = Session()
# Insert three sample products each time this module is executed.
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")
session.add_all([product1, product2, product3])
session.commit()
class ProductResponse(BaseModel):
    # Response schema for a single product. `created_at` is annotated with
    # the Arrow type that the ORM model uses for the same column.
    id: int
    name: str
    created_at: Arrow

    class Config:
        # Accept field types pydantic has no built-in support for (Arrow),
        # and read values from ORM object attributes.
        arbitrary_types_allowed = True
        orm_mode = True
@app.get('/', response_model=ProductResponse)
async def return_product():
    # Look up the product whose primary key is 1 and hand the ORM object
    # back; it is serialized through ProductResponse (response_model).
    first_product_query = session.query(Product).filter(Product.id == 1)
    return first_product_query.first()
if __name__ == "__main__":
    # Start the uvicorn server only when the file is run as a script.
    uvicorn.run(
        app,
        host="localhost",
        port=8000,
    )
requirements.txt:
sqlalchemy==1.4.23
sqlalchemy_utils==0.37.8
arrow==1.1.1
fastapi==0.68.1
uvicorn==0.15.0
那些 FastAPI 問題中已經討論了這個錯誤:
一種可能的解決方法是添加此代碼( source ):
from pydantic import BaseConfig
# Turn on arbitrary_types_allowed on pydantic's BaseConfig class itself rather than on a single model's Config.
BaseConfig.arbitrary_types_allowed = True
將它放在 @app.get('/'... 之上就足夠了,但它甚至可以放在 app = FastAPI() 之前。
此解決方案的問題是 GET 端點的輸出將是:
// 20210826001330
// http://localhost:8000/
{
"id": 1,
"name": "ice cream",
"created_at": {
"_datetime": "2021-08-25T21:38:01+00:00"
}
}
而不是想要的:
// 20210826001330
// http://localhost:8000/
{
"id": 1,
"name": "ice cream",
"created_at": "2021-08-25T21:38:01+00:00"
}
添加一個帶有 @validator 裝飾器的自定義函數,該函數返回對象所需的 _datetime:
class ProductResponse(BaseModel):
    # Response schema for a single product. The `validator` decorator used
    # below must be imported from pydantic.
    id: int
    name: str
    created_at: Arrow

    class Config:
        arbitrary_types_allowed = True
        orm_mode = True

    @validator("created_at")
    def format_datetime(cls, value):
        # Swap the Arrow object for the value held in its `_datetime`
        # attribute so the field is emitted as a plain timestamp string.
        # NOTE(review): `_datetime` is a private Arrow attribute; confirm
        # whether the public `datetime` property can be used instead.
        return value._datetime
在本地測試,似乎工作:
$ curl -s localhost:8000 | jq
{
"id": 1,
"name": "ice cream",
"created_at": "2021-12-02T08:25:10+00:00"
}
解決方案是 monkeypatch pydantic 的 ENCODERS_BY_TYPE,讓它知道如何轉換 Arrow 對象,使其能夠被序列化為 JSON:
from arrow import Arrow
from pydantic.json import ENCODERS_BY_TYPE
# Register str as the encoder for Arrow objects in pydantic's type-to-encoder table.
ENCODERS_BY_TYPE |= {Arrow: str}
設置 BaseConfig.arbitrary_types_allowed = True 也是必要的。
結果:
// 20220514022717
// http://localhost:8000/
{
"id": 1,
"name": "ice cream",
"created_at": "2022-05-14T00:20:11+00:00"
}
完整代碼:
import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType
from pydantic.json import ENCODERS_BY_TYPE
# Monkeypatch pydantic's encoder table so Arrow values are encoded with str().
ENCODERS_BY_TYPE |= {Arrow: str}
from pydantic import BaseConfig
# Enable arbitrary field types on pydantic's BaseConfig class; this runs before the app and models are defined.
BaseConfig.arbitrary_types_allowed = True
# Application object, SQLite engine (file db.db) and the declarative base for ORM models.
app = FastAPI()
engine = sqlalchemy.create_engine('sqlite:///db.db')
Base = declarative_base()
class Product(Base):
    # ORM model for rows of the "product" table.
    __tablename__ = "product"

    # Integer primary key with autoincrement enabled.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Nullable text column holding the product name.
    name = Column(Text, nullable=True)
    # Non-null, timezone-aware Arrow column whose server default is func.now().
    created_at = Column(
        ArrowType(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
# Create the tables for every model registered on Base.
Base.metadata.create_all(engine)
# Build a session factory bound to the engine and open one shared session.
Session = sessionmaker(bind=engine)
session = Session()
# Add three sample rows and commit them; this runs on every execution of the module.
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")
session.add_all([product1, product2, product3])
session.commit()
class ProductResponse(BaseModel):
    # Schema used as the endpoint's response_model; `created_at` keeps the
    # Arrow type of the ORM column.
    id: int
    name: str
    created_at: Arrow

    class Config:
        # Allow the non-pydantic Arrow type and populate fields from ORM
        # object attributes.
        arbitrary_types_allowed = True
        orm_mode = True
@app.get('/', response_model=ProductResponse)
async def return_product():
    # Query the shared session for the product with id 1 and return the
    # first match; ProductResponse (response_model) shapes the output.
    matching_products = session.query(Product).filter(Product.id == 1)
    return matching_products.first()
if __name__ == "__main__":
    # Serve the app with uvicorn when this file is executed directly.
    uvicorn.run(
        app,
        host="localhost",
        port=8000,
    )
這是一個代碼示例,您不需要 class Config,並且可以通過創建帶有驗證器的自定義子類來處理任何類型:
from psycopg2.extras import DateTimeTZRange as DateTimeTZRangeBase
from sqlalchemy.dialects.postgresql import TSTZRANGE
from sqlmodel import (
Column,
Field,
Identity,
SQLModel,
)
from pydantic.json import ENCODERS_BY_TYPE
# Register str as the encoder for psycopg2's DateTimeTZRange in pydantic's encoder table.
ENCODERS_BY_TYPE |= {DateTimeTZRangeBase: str}
class DateTimeTZRange(DateTimeTZRangeBase):
    """psycopg2 ``DateTimeTZRange`` that can be used directly as a pydantic field type.

    The class supplies its own validator (``__get_validators__``) and schema
    hook (``__modify_schema__``), so the model using it needs no ``Config``.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables pydantic should run for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Coerce ``v`` into a range object.

        Args:
            v: Either a string of the form ``"<bracket>lower, upper<bracket>"``
                (for example ``"[2022,01,01, 2022,02,02)"``) or an existing
                ``DateTimeTZRangeBase`` instance.

        Returns:
            A new instance built from the string, or ``v`` unchanged when it
            is already a range object.

        Raises:
            ValueError: If ``v`` is a string without a ``", "`` separator.
            TypeError: If ``v`` is neither a string nor a range object.
        """
        if isinstance(v, str):
            parts = v.split(", ")
            if len(parts) < 2:
                # Indexing parts[1] here would raise IndexError, which is not
                # one of the exception types validators are expected to
                # raise; use ValueError so the input is rejected cleanly.
                raise ValueError(
                    "Range string must look like '[lower, upper)'"
                )
            # Drop the opening bracket from the first part and the closing
            # bracket from the second; the bound texts are passed on as-is.
            lower = parts[0][1:].strip()
            upper = parts[1][:-1].strip()
            # First and last characters of the input form the bounds
            # specifier, e.g. "[)".
            bounds = v[:1] + v[-1:]
            return cls(lower, upper, bounds)
        elif isinstance(v, DateTimeTZRangeBase):
            return v
        raise TypeError("Type must be string or DateTimeTZRange")

    @classmethod
    def __modify_schema__(cls, field_schema):
        """Describe the field as a string, with an example, in the generated schema."""
        field_schema.update(type="string", example="[2022,01,01, 2022,02,02)")
class EventBase(SQLModel):
    # Shared fields for the "event" table.
    __tablename__ = "event"

    # Non-null TSTZRANGE column typed with the custom DateTimeTZRange class.
    timestamp_range: DateTimeTZRange = Field(
        sa_column=Column(TSTZRANGE(), nullable=False),
    )
class Event(EventBase, table=True):
    # Table model: EventBase plus an identity primary key. The Python-side
    # default is None; the column is declared with Identity(always=True).
    id: int | None = Field(
        primary_key=True,
        nullable=False,
        default=None,
        sa_column_args=(Identity(always=True),),
    )
鏈接到 Github 問題: https://github.com/tiangolo/sqlmodel/issues/235#issuecomment-1162063590
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.