
How do I import the MLflow tracking server WSGI application into Flask or FastAPI?

MLflow provides a very useful tracking server; however, the server does not provide authentication or role-based access control (RBAC), both of which I need.

I would like to add my own authentication and RBAC. I think one way to accomplish this is to import the MLflow WSGI application object and wrap it in middleware that performs authentication and authorization before passing requests through to the tracking server, essentially proxying requests through my custom middleware stack.
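The authorization half of such a middleware stack can be kept separate from authentication as a pure function that maps a role, HTTP method, and path to allow/deny. A minimal sketch, assuming a hypothetical three-role hierarchy and a hypothetical prefix-based rule table. MLflow's REST API lives under `/api/2.0/mlflow/`, but the bundled UI may call a parallel `/ajax-api/` prefix, so a real policy would need to cover both; note also that some read-only endpoints such as `runs/search` are POST requests.

```python
from dataclasses import dataclass

# Hypothetical role hierarchy: a higher rank includes every lower rank's rights.
ROLE_RANK = {"reader": 0, "editor": 1, "admin": 2}


@dataclass(frozen=True)
class Rule:
    method: str       # HTTP method, or "*" for any method
    path_prefix: str  # matched with str.startswith
    min_role: str     # least-privileged role allowed by this rule


# Hypothetical policy. The first matching rule wins, so specific rules go first.
RULES = [
    Rule("*", "/api/2.0/mlflow/experiments/delete", "admin"),
    Rule("POST", "/api/2.0/mlflow/runs/search", "reader"),  # read-only despite POST
    Rule("GET", "/", "reader"),
    Rule("POST", "/api/", "editor"),
]


def is_allowed(role: str, method: str, path: str) -> bool:
    """Return True if `role` may send `method` to `path` under RULES."""
    if role not in ROLE_RANK:
        return False  # unknown roles get nothing
    for rule in RULES:
        if rule.method in ("*", method) and path.startswith(rule.path_prefix):
            return ROLE_RANK[role] >= ROLE_RANK[rule.min_role]
    return False  # deny by default when no rule matches
```

An authentication middleware would resolve the caller's role, call `is_allowed(role, method, path)`, and return 403 when it is False; keeping the policy as data makes it easy to test without running MLflow.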

How do I go about doing this? I can see from the FastAPI documentation that FastAPI can mount another WSGI application and add custom middleware in front of it, but I'm not sure about a few things:

  1. Where do I find the MLflow tracking server WSGI app (where can it be imported from)?
  2. How do I pass the relevant arguments through to the MLflow tracking server? For example, the tracking server expects parameters that configure the backend storage layer, host, and port. If I just import the application object, how do I pass those parameters to it?
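The imported application object takes no constructor arguments, so configuration has to reach it another way; the answer below relies on environment variables that a parent process sets before the server worker imports the app. A stdlib-only sketch of that handoff, where a child Python process stands in for a gunicorn worker and `DEMO_BACKEND_STORE_URI` is a hypothetical variable name (not one MLflow reads):

```python
import os
import subprocess
import sys


def value_seen_by_child(name: str, overrides: dict) -> str:
    """Start a child Python process with extra environment variables and
    return the value of `name` as the child sees it ("<unset>" if absent).

    The child stands in for a server worker that imports the app and reads
    its configuration from os.environ.
    """
    env = {**os.environ, **overrides}  # inherit the parent env, then overlay
    code = f"import os; print(os.environ.get({name!r}, '<unset>'))"
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()
```

Host and port are different: they belong to the process that serves the app (gunicorn's `-b` bind address), not to the app object itself.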

Edit: it looks like the Flask application is defined here: https://github.com/mlflow/mlflow/blob/master/mlflow/server/__init__.py#L28
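Because the object at that link is an ordinary Flask application, the Flask-only route also works: Flask's documented pattern for middleware is to wrap `app.wsgi_app`, so no FastAPI layer is strictly required. A minimal sketch of HTTP Basic authentication as plain WSGI middleware, assuming a hypothetical in-memory `users` dict (use hashed passwords and a real user store in practice). Basic auth is a convenient choice because the MLflow client sends it when `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` are set.

```python
import base64
import hmac
from typing import Callable, Dict, Iterable, Optional

WSGIApp = Callable[[dict, Callable], Iterable[bytes]]


class BasicAuthMiddleware:
    """WSGI middleware that rejects requests lacking valid HTTP Basic credentials."""

    def __init__(self, app: WSGIApp, users: Dict[str, str], realm: str = "mlflow"):
        self.app = app        # the wrapped WSGI app, e.g. Flask's app.wsgi_app
        self.users = users    # hypothetical store: username -> plaintext password
        self.realm = realm

    def _authenticate(self, environ: dict) -> Optional[str]:
        """Return the username if the Authorization header is valid, else None."""
        header = environ.get("HTTP_AUTHORIZATION", "")
        scheme, _, encoded = header.partition(" ")
        if scheme.lower() != "basic" or not encoded:
            return None
        try:
            decoded = base64.b64decode(encoded).decode("utf-8")
        except ValueError:  # covers binascii.Error and UnicodeDecodeError
            return None
        username, _, password = decoded.partition(":")
        expected = self.users.get(username)
        # Constant-time comparison to avoid leaking password prefixes via timing.
        if expected is not None and hmac.compare_digest(password.encode(), expected.encode()):
            return username
        return None

    def __call__(self, environ: dict, start_response: Callable):
        user = self._authenticate(environ)
        if user is None:
            start_response("401 Unauthorized", [
                ("Content-Type", "text/plain"),
                ("WWW-Authenticate", f'Basic realm="{self.realm}"'),
            ])
            return [b"Unauthorized"]
        # REMOTE_USER is the standard CGI/WSGI key for the authenticated user,
        # so downstream code (e.g. an RBAC check) can read it.
        environ["REMOTE_USER"] = user
        return self.app(environ, start_response)
```

To protect MLflow with it, something like `mlflow_app.wsgi_app = BasicAuthMiddleware(mlflow_app.wsgi_app, users)` after `from mlflow.server import app as mlflow_app` should work, with `mlflow_app` then served by gunicorn using the same environment variables as in the answer below.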

This turned out to be very simple. Below is an example that uses FastAPI to import and mount the MLflow WSGI application, then passes the storage settings to it through environment variables.

import argparse
import os
import shlex
import subprocess

from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware

from mlflow.server import app as mlflow_app

# Mount the MLflow Flask (WSGI) app at the root of an ASGI FastAPI app.
# Middleware added to `app` runs before requests reach MLflow.
app = FastAPI()
app.mount("/", WSGIMiddleware(mlflow_app))

# Environment variables that mlflow.server reads for its configuration.
# These mirror private constants in mlflow/server/__init__.py and may
# change between MLflow versions.
BACKEND_STORE_URI_ENV_VAR = "_MLFLOW_SERVER_FILE_STORE"
ARTIFACT_ROOT_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_ROOT"
ARTIFACTS_DESTINATION_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_DESTINATION"
PROMETHEUS_EXPORTER_ENV_VAR = "prometheus_multiproc_dir"  # not used below
SERVE_ARTIFACTS_ENV_VAR = "_MLFLOW_SERVER_SERVE_ARTIFACTS"
ARTIFACTS_ONLY_ENV_VAR = "_MLFLOW_SERVER_ARTIFACTS_ONLY"  # not used below

def parse_args():
    a = argparse.ArgumentParser()
    a.add_argument("--host", type=str, default="0.0.0.0")
    a.add_argument("--port", type=int, default=5000)
    a.add_argument("--backend-store-uri", type=str, default="sqlite:///mlflow.db")
    a.add_argument("--serve-artifacts", action="store_true", default=False)
    a.add_argument("--artifacts-destination", type=str)
    a.add_argument("--default-artifact-root", type=str)
    a.add_argument("--gunicorn-opts", type=str, default="")
    a.add_argument("--n-workers", type=int, default=1)
    return a.parse_args()

def run_command(cmd, env, cwd=None):
    # Start from the current environment and overlay the MLflow settings.
    cmd_env = os.environ.copy()
    if env:
        cmd_env.update(env)
    child = subprocess.Popen(cmd, env=cmd_env, cwd=cwd)
    exit_code = child.wait()
    if exit_code != 0:
        raise RuntimeError(f"Non-zero exit code: {exit_code}")
    return exit_code

def run_server(args):
    # The imported app takes no constructor arguments; MLflow reads its
    # configuration from these environment variables, which the gunicorn
    # workers inherit when they import server:app.
    env_map = {}
    if args.backend_store_uri:
        env_map[BACKEND_STORE_URI_ENV_VAR] = args.backend_store_uri
    if args.serve_artifacts:
        env_map[SERVE_ARTIFACTS_ENV_VAR] = "true"
    if args.artifacts_destination:
        env_map[ARTIFACTS_DESTINATION_ENV_VAR] = args.artifacts_destination
    if args.default_artifact_root:
        env_map[ARTIFACT_ROOT_ENV_VAR] = args.default_artifact_root

    print(f"Envmap: {env_map}")

    # Extra gunicorn flags, e.g. --gunicorn-opts "--timeout 120".
    opts = shlex.split(args.gunicorn_opts) if args.gunicorn_opts else []

    # "server:app" assumes this file is saved as server.py.
    cmd = [
        "gunicorn",
        "-b", f"{args.host}:{args.port}",
        "-w", str(args.n_workers),
        "-k", "uvicorn.workers.UvicornWorker",
        *opts,
        "server:app",
    ]
    run_command(cmd, env_map)

def main():
    args = parse_args()
    run_server(args)

if __name__ == "__main__":
    main()

Save the file as server.py (the gunicorn command loads server:app) and run it like this:

python server.py --artifacts-destination s3://mlflow-mr --default-artifact-root s3://mlflow-mr --serve-artifacts

Then open http://localhost:5000 in your browser and you should see the tracking server running. Because the MLflow app is mounted inside a FastAPI app, you can now add custom FastAPI middleware in front of the tracking server.
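As one example of such middleware, here is a minimal sketch of a pure ASGI middleware that rejects requests without an accepted bearer token before they ever reach the mounted MLflow app. The static token list is a hypothetical stand-in for a real identity provider. Bearer tokens fit the MLflow client, which sends `Authorization: Bearer <token>` when `MLFLOW_TRACKING_TOKEN` is set.

```python
import hmac
from typing import Iterable


class BearerTokenMiddleware:
    """Pure ASGI middleware that rejects HTTP requests lacking an accepted bearer token."""

    def __init__(self, app, tokens: Iterable[str]):
        # Starlette/FastAPI construct middleware with the wrapped app as the
        # first argument (named `app`), followed by the add_middleware options.
        self.app = app
        # Hypothetical static tokens, stored as bytes to match raw ASGI headers.
        self.tokens = [t.encode("utf-8") for t in tokens]

    def _token_ok(self, token: bytes) -> bool:
        # Constant-time comparison against every accepted token.
        return any(hmac.compare_digest(token, t) for t in self.tokens)

    async def __call__(self, scope, receive, send):
        # Let non-HTTP traffic (e.g. "lifespan" events) through untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        # ASGI servers deliver headers as (name, value) byte pairs, names lowercased.
        headers = dict(scope.get("headers", []))
        scheme, _, token = headers.get(b"authorization", b"").partition(b" ")
        if scheme.lower() == b"bearer" and self._token_ok(token):
            await self.app(scope, receive, send)
            return
        # Otherwise answer 401 without calling the wrapped app.
        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [(b"content-type", b"text/plain"),
                        (b"www-authenticate", b"Bearer")],
        })
        await send({"type": "http.response.body", "body": b"Unauthorized"})
```

In server.py it would be registered right after the FastAPI app is created, e.g. `app.add_middleware(BearerTokenMiddleware, tokens=["replace-me"])`. Middleware added this way wraps the whole router, so it also guards the app mounted at "/". Writing it as raw ASGI rather than with `@app.middleware("http")` avoids buffering the response body through an extra layer, though either approach should work for simple checks.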
