
How to return the result of BackgroundTasks as a WebSocket answer in FastAPI?

I have the following code:

from fastapi import FastAPI, WebSocket, BackgroundTasks
import uvicorn
import time

app = FastAPI()


def run_model():
    ...
    ## code of the model
    answer = [1, 2, 3]
    ...
    results = {"message": "the model has been executed successfully!!", "results": answer}
    return results


@app.post("/execute-model")
async def ping(background_tasks: BackgroundTasks):
    background_tasks.add_task(run_model)
    return {"message": "the model is executing"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        ## Here I want the results of run_model
        await websocket.send_text("1")  # placeholder; send_text expects a str

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8001)

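I need to make a POST request to /execute-model. This endpoint runs the run_model function as a background task. When run_model() finishes, I need to send the answer back to the front end. I want to use WebSockets for this, but I don't know how. Please help.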
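The core of the problem is handing a value computed by a synchronous function back to async code. A minimal standard-library sketch, assuming one job per request and an in-memory `jobs` dict (the names `jobs`, `start_job` and `wait_for_result` are hypothetical, not part of FastAPI): the POST side schedules `run_model` on the default thread pool with `loop.run_in_executor`, and the WebSocket side simply awaits the stored future.

```python
import asyncio
import uuid
from typing import Dict

# Hypothetical in-memory job registry: job_id -> future holding run_model's result.
jobs: Dict[str, "asyncio.Future[dict]"] = {}


def run_model() -> dict:
    # Stand-in for the blocking model code from the question.
    answer = [1, 2, 3]
    return {"message": "the model has been executed successfully!!", "results": answer}


def start_job() -> str:
    """What an async POST handler could do: schedule the model, return an id at once."""
    loop = asyncio.get_running_loop()
    job_id = uuid.uuid4().hex
    # run_in_executor runs the sync function in the default thread pool and
    # returns an asyncio future bound to the running event loop.
    jobs[job_id] = loop.run_in_executor(None, run_model)
    return job_id


async def wait_for_result(job_id: str) -> dict:
    """What a WebSocket handler could await before calling websocket.send_json(...)."""
    result = await jobs[job_id]
    del jobs[job_id]  # free the slot once the result has been delivered
    return result


async def main() -> dict:
    job_id = start_job()
    return await wait_for_result(job_id)


if __name__ == "__main__":
    print(asyncio.run(main()))
```

In an app, `start_job()` would be called inside the `async def` POST endpoint, the id returned to the client, and the client would connect to a path such as `/ws/{job_id}` (hypothetical). Note that this sidesteps `BackgroundTasks` entirely, and a result for a client that never connects stays in `jobs` until the process exits.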

I have something similar. Here is how I did it (not saying it's the best solution, or even a good one, but it has worked so far):

Route endpoint:

# client makes a POST request and gets the saved model immediately, while a background task is started to process the image
@app.post("/analyse", response_model=schemas.ImageAnalysis, tags=["Image Analysis"])
async def create_image_analysis(
    img: schemas.ImageAnalysisCreate,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    saved = crud.create_analysis(db=db, img=img)
    background_tasks.add_task(analyse_image, db=db, img=img)

    # model includes a ws_token (some random string) that the client can connect to right away
    return saved
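The answer leaves the token itself open ("some random string"). A minimal sketch, assuming the token is created when the analysis row is saved: `secrets.token_urlsafe` from the standard library gives an unguessable, URL-safe string that can go straight into the `/ws/{ws_token}` path. The helper name `new_ws_token` is hypothetical.

```python
import secrets


def new_ws_token(nbytes: int = 16) -> str:
    """Return a random, URL-safe token for the /ws/{ws_token} path.

    token_urlsafe base64url-encodes `nbytes` random bytes without padding,
    so 16 bytes give a 22-character string over [A-Za-z0-9_-].
    """
    return secrets.token_urlsafe(nbytes)


if __name__ == "__main__":
    print(new_ws_token())
```

Unguessability matters here because, with the WebSocket endpoint as written, anyone who knows a token can subscribe to that analysis's updates.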

WebSocket endpoint:

@app.websocket("/ws/{ws_token}")
async def websocket_endpoint(websocket: WebSocket, ws_token: str):
    # add the websocket to the connections dict (by ws_token)
    await socket_connections.connect(websocket, ws_token=ws_token)
    try:
        while True:
            print(socket_connections)
            # incoming messages are ignored, but awaiting receive keeps the handler alive
            # and is where WebSocketDisconnect is raised when the client leaves
            await websocket.receive_text()

    except WebSocketDisconnect:
        socket_connections.disconnect(websocket, ws_token=ws_token)

analyse_image function:

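# notice: the function is not async. Sync background tasks run in a thread pool, which is what lets
# send_message start its own event loop there (that would fail inside the server's running loop)
def analyse_image(db: Session, img: ImageAnalysis):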

    print('analyse_image started')
    for index, round in enumerate(img.rounds):
        
        # some heavy workload etc

        # send update to user
        socket_connections.send_message({
                "status":EstimationStatus.RUNNING,
                "current_step":index+1,
                "total_steps":len(img.rounds)
            }, ws_token=img.ws_token)

    print("analysis finished")
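Because `analyse_image` is a plain `def`, it runs in a worker thread, while the WebSocket objects belong to the server's event loop. A minimal standard-library sketch of a thread-safe hand-off, assuming the loop is captured on the async side and passed to the worker (`analyse_rounds`, `relay` and `progress` are hypothetical names, and `loop.run_in_executor` stands in for the thread pool Starlette uses for sync background tasks): the worker posts each progress dict with `loop.call_soon_threadsafe(queue.put_nowait, ...)`, and an async consumer, which in the app would call `websocket.send_json`, drains the queue.

```python
import asyncio
from typing import List


def analyse_rounds(rounds: List[int], loop: asyncio.AbstractEventLoop,
                   progress: "asyncio.Queue[dict]") -> None:
    """Sync worker in the spirit of analyse_image: runs in a thread, reports each step."""
    try:
        for index, _round in enumerate(rounds):
            # ... some heavy workload per round would go here ...
            message = {
                "status": "RUNNING",  # stand-in for EstimationStatus.RUNNING
                "current_step": index + 1,
                "total_steps": len(rounds),
            }
            # asyncio.Queue is not thread-safe, so ask the loop's own thread to enqueue.
            loop.call_soon_threadsafe(progress.put_nowait, message)
    finally:
        # Always send a sentinel so the consumer stops, even if the workload raised.
        loop.call_soon_threadsafe(progress.put_nowait, {"status": "DONE"})


async def relay(rounds: List[int]) -> List[dict]:
    """Async side: in the app, each message would be passed to websocket.send_json."""
    loop = asyncio.get_running_loop()
    progress: "asyncio.Queue[dict]" = asyncio.Queue()
    # Run the sync worker in the default thread pool.
    worker = loop.run_in_executor(None, analyse_rounds, rounds, loop, progress)
    sent: List[dict] = []
    while True:
        message = await progress.get()
        sent.append(message)  # in the app: await websocket.send_json(message)
        if message["status"] == "DONE":
            break
    await worker  # re-raises any exception from the worker thread
    return sent


if __name__ == "__main__":
    for msg in asyncio.run(relay([10, 20, 30])):
        print(msg)
```

Callbacks scheduled with `call_soon_threadsafe` run in the order they were submitted, so the client sees the steps in order.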

Connection manager:

import asyncio
from typing import Dict, List
from fastapi import WebSocket

#notice: active_connections is changed to a dict (key= ws_token), so we know which user listens to which model
class ConnectionManager:

    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, ws_token: str):
        await websocket.accept()
        if ws_token in self.active_connections:
            self.active_connections.get(ws_token).append(websocket)
        else:
            self.active_connections.update({ws_token: [websocket]})


    def disconnect(self, websocket: WebSocket, ws_token: str):
        self.active_connections.get(ws_token).remove(websocket)
        if(len(self.active_connections.get(ws_token))==0):
            self.active_connections.pop(ws_token)

    # notice: changed from async to sync so it can be called from the sync background task (analyse_image)
    def send_message(self, data: dict, ws_token: str):
        sockets = self.active_connections.get(ws_token)
        if sockets:
            #notice: socket send is originally async. We run it synchronously here on a
            # fresh event loop in the worker thread, and close that loop when done
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                for socket in sockets:
                    loop.run_until_complete(socket.send_json(data))
            finally:
                loop.close()


socket_connections = ConnectionManager()
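Creating a new event loop per message drives sockets that belong to the server's loop from a different loop and thread, which Starlette's `WebSocket` is not designed for. A possible alternative, sketched under the assumption that `connect` is always awaited on the server's loop: remember that loop, and have the sync `send_message` submit each send to it with `asyncio.run_coroutine_threadsafe`. `ThreadSafeConnectionManager` and `FakeSocket` are hypothetical; the fake has async `accept`/`send_json` so the sketch runs without a server.

```python
import asyncio
from typing import Any, Dict, List, Optional


class ThreadSafeConnectionManager:
    """Hypothetical ConnectionManager variant whose send_message is safe to call
    from a worker thread: sends are executed on the loop that owns the sockets."""

    def __init__(self) -> None:
        self.active_connections: Dict[str, List[Any]] = {}
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    async def connect(self, websocket: Any, ws_token: str) -> None:
        # connect() is awaited on the server's loop, so remember that loop.
        self.loop = asyncio.get_running_loop()
        await websocket.accept()
        self.active_connections.setdefault(ws_token, []).append(websocket)

    def send_message(self, data: dict, ws_token: str) -> None:
        # Must be called from a worker thread: waiting on future.result()
        # from the loop's own thread would deadlock.
        if self.loop is None:
            return
        for socket in list(self.active_connections.get(ws_token, [])):
            future = asyncio.run_coroutine_threadsafe(socket.send_json(data), self.loop)
            future.result()  # wait until the loop has actually sent the message


class FakeSocket:
    """Hypothetical stand-in for fastapi.WebSocket that records what was sent."""

    def __init__(self) -> None:
        self.sent: List[dict] = []

    async def accept(self) -> None:
        pass

    async def send_json(self, data: dict) -> None:
        self.sent.append(data)


async def demo() -> List[dict]:
    manager = ThreadSafeConnectionManager()
    socket = FakeSocket()
    await manager.connect(socket, ws_token="abc")
    loop = asyncio.get_running_loop()
    # Simulate a sync background task calling send_message from a thread pool.
    await loop.run_in_executor(None, manager.send_message, {"current_step": 1}, "abc")
    return socket.sent


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

A side benefit is that no event loops are created or leaked per message. The trade-off is that `send_message` may only be called off the event-loop thread, which holds for a plain `def` background task like `analyse_image`.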


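Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you reproduce them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.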

 
Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM