
How to convert scikit-learn OneVsRestClassifier predict method output to a dense array for Google Cloud ML?

I have a model, trained using a scikit-learn Pipeline and OneVsRestClassifier, that I'm trying to deploy to Cloud ML Engine. However, when I run the command:

gcloud ml-engine predict --model $MODEL_NAME --version $VERSION_NAME --json-instances $INPUT_FILE

I receive the error:

{ "error": "Prediction failed: Bad output type returned.The predict function should return either a numpy ndarray or a list." }

This leads me to believe the problem is that the OneVsRestClassifier's predict method returns a sparse matrix, when it should return a numpy array. How can I convert its output to a dense array in my Pipeline?

The pipeline's architecture looks like this:

Pipeline([('tfidf', tfidf), ('clf', OneVsRestClassifier(XGBClassifier()))])
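One way to check this hypothesis is to fit the same kind of pipeline on toy data and test the predictions with scipy.sparse.issparse. This is a minimal reproduction under stated assumptions: the data are made up, LogisticRegression stands in for XGBClassifier so that xgboost is not required, and the labels are assumed to be encoded as a sparse indicator matrix with MultiLabelBinarizer(sparse_output=True). In the scikit-learn versions I have looked at, OneVsRestClassifier.predict returns a sparse matrix when the multi-label y passed to fit was sparse.

```python
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy multi-label data, for illustration only.
texts = ["red apple", "green apple", "red car", "blue car", "green tree", "blue sky"]
labels = [["red", "fruit"], ["green", "fruit"], ["red", "vehicle"],
          ["blue", "vehicle"], ["green"], ["blue"]]

# Assumption: labels were binarized into a *sparse* indicator matrix.
Y_sparse = MultiLabelBinarizer(sparse_output=True).fit_transform(labels)

# LogisticRegression is a stand-in for XGBClassifier.
pipe = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("clf", OneVsRestClassifier(LogisticRegression())),
])
pipe.fit(texts, Y_sparse)

pred = pipe.predict(texts)
# With sparse training labels, predict() hands back a scipy sparse matrix,
# which is not the ndarray/list that the prediction service accepts.
print(type(pred), sp.issparse(pred))
```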

Thanks!

I've tried using the methods here (Google Cloud ML-engine scikit-learn prediction probability 'predict_proba()') to replace the OneVsRestClassifier's predict method with its predict_proba method. However, this results in the following error when I try to pickle the new pipeline:

PicklingError: Can't pickle <function OneVsRestClassifier.predict_proba at 0x10a8f9d08>: it's not the same object as sklearn.multiclass.OneVsRestClassifier.predict_proba
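An alternative to patching the method at runtime is a small subclass, defined at module level, that overrides predict. pickle stores classes and functions by their qualified name, so a class that can be imported under its own name round-trips cleanly. This is a sketch under assumptions: the class name ProbaOneVsRestClassifier is hypothetical, the toy data are made up, and LogisticRegression stands in for XGBClassifier. Deploying such a subclass still means shipping the module that defines it, which is the custom-code requirement the answer discusses.

```python
import pickle

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer


class ProbaOneVsRestClassifier(OneVsRestClassifier):
    """Hypothetical subclass whose predict() returns per-label probabilities."""

    def predict(self, X):
        # predict_proba returns a dense ndarray of shape (n_samples, n_labels).
        return self.predict_proba(X)


# Hypothetical toy multi-label data, for illustration only.
texts = ["red apple", "green apple", "red car", "blue car", "green tree", "blue sky"]
labels = [["red", "fruit"], ["green", "fruit"], ["red", "vehicle"],
          ["blue", "vehicle"], ["green"], ["blue"]]
Y = MultiLabelBinarizer().fit_transform(labels)

pipe = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("clf", ProbaOneVsRestClassifier(LogisticRegression())),
])
pipe.fit(texts, Y)

# Round-trip through pickle: the class is found again by its qualified name.
restored = pickle.loads(pickle.dumps(pipe))
proba = restored.predict(["red apple"])
print(type(proba), proba.shape)
```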

AI Platform (formerly known as Cloud Machine Learning Engine) serves your model and expects both the input and the output to be JSON-serializable. If your model returns a sparse matrix, you need to convert it to a dense array (for example with the sparse matrix's toarray() method; see this for more information).
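The conversion itself is two steps: sparse matrix to dense ndarray with toarray(), then ndarray to nested Python lists with tolist(), which json can serialize. The helper name to_json_safe is hypothetical; this is a minimal sketch, not part of any Google or scikit-learn API.

```python
import json

import numpy as np
import scipy.sparse as sp


def to_json_safe(pred):
    """Hypothetical helper: turn a predict() result into JSON-serializable lists."""
    if sp.issparse(pred):
        pred = pred.toarray()  # sparse matrix -> dense numpy ndarray
    return np.asarray(pred).tolist()  # ndarray -> nested Python lists


sparse_pred = sp.csr_matrix(np.array([[0, 1, 0], [1, 0, 1]]))
# json.dumps(sparse_pred) would raise TypeError; the converted lists serialize fine.
print(json.dumps(to_json_safe(sparse_pred)))  # [[0, 1, 0], [1, 0, 1]]
```

If you would rather avoid custom code entirely, another option worth testing is refitting with a dense label matrix (for example MultiLabelBinarizer() with its default sparse_output=False): in the scikit-learn versions I have looked at, OneVsRestClassifier.predict then returns a dense ndarray directly.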

If you choose to override predict (for example, with predict_proba), you are deploying your model with custom code (the code that overrides the function). You then need to package that custom code and deploy it alongside your model. For more information on deploying models with custom code, see Custom prediction routines on AI Platform.
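A custom prediction routine is a class with a predict(self, instances, **kwargs) method that returns a JSON-serializable list and a from_path(cls, model_dir) classmethod that loads the model. The sketch below follows that shape and does the sparse-to-dense conversion inside predict. Assumptions: the class name DensePredictor is hypothetical, the pickled pipeline is assumed to be saved as model.pkl in the model directory, and packaging and deployment steps are not shown.

```python
import os
import pickle

import numpy as np
import scipy.sparse as sp


class DensePredictor:
    """Hypothetical custom prediction routine that always returns dense lists."""

    def __init__(self, model):
        self._model = model

    def predict(self, instances, **kwargs):
        # instances: list of inputs as decoded from the JSON request.
        pred = self._model.predict(instances)
        if sp.issparse(pred):
            pred = pred.toarray()  # sparse -> dense ndarray
        return np.asarray(pred).tolist()  # ndarray -> JSON-serializable list

    @classmethod
    def from_path(cls, model_dir):
        # Assumption: the fitted pipeline was pickled to model_dir/model.pkl.
        with open(os.path.join(model_dir, "model.pkl"), "rb") as f:
            return cls(pickle.load(f))
```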

The technical posts on this site follow the CC BY-SA 4.0 license. If you reprint, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.
