MLflow provides a very cool tracking server; however, this server does not provide authentication or RBAC, both of which I need.
I would like to add my own authentication and RBAC functionality. I think one way to accomplish this is to import the MLflow WSGI application object and add some middleware layers to perform authentication / authorization before passing requests through to the tracking server, essentially proxying requests through my custom middleware stack.
How do I go about doing this? I can see from these docs that I can use FastAPI to import another WSGI application and add custom middleware, but I'm not sure about a few things.
Edit: it looks like the Flask application can be found here: https://github.com/mlflow/mlflow/blob/master/mlflow/server/__init__.py#L28
This was actually very simple. Below is an example that uses FastAPI to import and mount the MLflow WSGI application.
import argparse
import os
import shlex
import subprocess

from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware
from mlflow.server import app as mlflow_app
# ASGI application that gunicorn loads as "server:app" (see run_server).
app = FastAPI()
# Wrap the MLflow WSGI (Flask) application so it can be served by the ASGI
# app, and mount it at the root path so every request is passed through to it.
app.mount("/", WSGIMiddleware(mlflow_app))
# Names of the environment variables that run_server() fills in from the
# command-line options before starting gunicorn.
# NOTE(review): presumably these are the names mlflow.server reads at start-up;
# confirm against the installed MLflow version.
ARTIFACT_ROOT_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_ROOT"
ARTIFACTS_DESTINATION_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_DESTINATION"
BACKEND_STORE_URI_ENV_VAR = "_MLFLOW_SERVER_FILE_STORE"
SERVE_ARTIFACTS_ENV_VAR = "_MLFLOW_SERVER_SERVE_ARTIFACTS"

# Declared for completeness; nothing in this script sets these two.
ARTIFACTS_ONLY_ENV_VAR = "_MLFLOW_SERVER_ARTIFACTS_ONLY"
PROMETHEUS_EXPORTER_ENV_VAR = "prometheus_multiproc_dir"
def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``host``, ``port``,
        ``backend_store_uri``, ``serve_artifacts``, ``artifacts_destination``,
        ``default_artifact_root``, ``gunicorn_opts`` and ``n_workers``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=str, default="5000")
    parser.add_argument("--backend-store-uri", type=str, default="sqlite:///mlflow.db")
    parser.add_argument("--serve-artifacts", action="store_true", default=False)
    parser.add_argument("--artifacts-destination", type=str)
    parser.add_argument("--default-artifact-root", type=str)
    parser.add_argument("--gunicorn-opts", type=str, default="")
    # Parsed as an int so the value has the same type as the default and a
    # non-numeric worker count is rejected here rather than by gunicorn.
    parser.add_argument("--n-workers", type=int, default=1)
    return parser.parse_args(argv)
def run_command(cmd, env, cwd=None):
    """Run ``cmd`` as a child process and block until it exits.

    Args:
        cmd: Argument list handed to ``subprocess.Popen``.
        env: Mapping of extra environment variables layered on top of a copy
            of the current process environment. May be ``None`` or empty, in
            which case the child simply inherits the current environment.
        cwd: Optional working directory for the child process.

    Returns:
        The child's exit code, which is always 0 when this function returns.

    Raises:
        Exception: If the child exits with a non-zero status.
    """
    cmd_env = os.environ.copy()
    # Guard on the overrides themselves (not on the copied environment) so
    # that env=None is accepted and overrides are never silently dropped.
    if env:
        cmd_env.update(env)
    with subprocess.Popen(
        cmd, env=cmd_env, cwd=cwd, text=True, stdin=subprocess.PIPE
    ) as child:
        # No input is supplied: communicate() closes the child's stdin and
        # waits for the process to terminate.
        child.communicate()
        exit_code = child.returncode
    if exit_code != 0:
        raise Exception("Non-zero exitcode: %s" % (exit_code))
    return exit_code
def run_server(args):
    """Start gunicorn with uvicorn workers serving ``server:app``.

    The relevant command-line options are exported as environment variables
    for the child process, then gunicorn is launched and waited on.

    Args:
        args: Parsed options as returned by ``parse_args`` (reads ``host``,
            ``port``, ``n_workers``, ``gunicorn_opts``, ``backend_store_uri``,
            ``serve_artifacts``, ``artifacts_destination`` and
            ``default_artifact_root``).

    Raises:
        Exception: Propagated from ``run_command`` when gunicorn exits with a
            non-zero status.
    """
    env_map = {}
    if args.backend_store_uri:
        env_map[BACKEND_STORE_URI_ENV_VAR] = args.backend_store_uri
    if args.serve_artifacts:
        env_map[SERVE_ARTIFACTS_ENV_VAR] = "true"
    if args.artifacts_destination:
        env_map[ARTIFACTS_DESTINATION_ENV_VAR] = args.artifacts_destination
    if args.default_artifact_root:
        env_map[ARTIFACT_ROOT_ENV_VAR] = args.default_artifact_root
    print(f"Envmap: {env_map}")

    # --gunicorn-opts arrives as one string; split it shell-style so quoted
    # values survive, and actually pass the result on to gunicorn (it was
    # previously parsed but never added to the command).
    extra_opts = shlex.split(args.gunicorn_opts) if args.gunicorn_opts else []
    cmd = [
        "gunicorn",
        *extra_opts,
        "-b", f"{args.host}:{args.port}",
        "-w", f"{args.n_workers}",
        "-k", "uvicorn.workers.UvicornWorker",
        # "server:app" is the ``app`` object of this file, so the script must
        # be saved as server.py and be importable from the working directory.
        "server:app",
    ]
    run_command(cmd, env_map)
def main():
    """Script entry point: parse the CLI options and launch the server."""
    run_server(parse_args())


# gunicorn imports this module to load "server:app", so the launcher must only
# run when the file is executed directly as a script.
if __name__ == "__main__":
    main()
Run it like this:
python server.py --artifacts-destination s3://mlflow-mr --default-artifact-root s3://mlflow-mr --serve-artifacts
Then open the server's address in your browser (port 5000 by default) and you will see the tracking server running! This setup lets you insert custom FastAPI middleware in front of the tracking server.
The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.