Python scikit-learn: exporting trained classifier

I am using a DBN (deep belief network) from nolearn based on scikit-learn.

I have already built a network which classifies my data very well; now I am interested in exporting the model for deployment, but I don't know how (currently I retrain the DBN every time I want to predict something). In MATLAB I would just export the weight matrix and import it on another machine.

Does anyone know how to export the model/the weight matrix so it can be imported without having to train the whole model again?

You can use:

>>> from sklearn.externals import joblib
>>> joblib.dump(clf, 'my_model.pkl', compress=9)

And then later, on the prediction server:

>>> from sklearn.externals import joblib
>>> model_clone = joblib.load('my_model.pkl')

This is basically a Python pickle with optimized handling of large numpy arrays. It has the same limitations as regular pickle with respect to code changes: if the class structure of the pickled object changes, you might no longer be able to unpickle it with newer versions of nolearn or scikit-learn.

If you want a robust, long-term way of storing your model parameters, you might need to write your own IO layer (e.g. using binary serialization tools such as Protocol Buffers or Avro, or an inefficient yet portable text/JSON/XML representation such as PMML).
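As a rough illustration of such a hand-rolled IO layer (not from the original answer): the idea is to persist only the learned numpy arrays rather than the Python object. The sketch below uses a plain scikit-learn LogisticRegression, whose fitted parameters live in coef_, intercept_ and classes_; the file name params.npz and the choice of estimator are illustrative only, and a nolearn DBN would need its own weight/bias arrays extracted instead.

>>> import numpy as np
>>> from sklearn import datasets
>>> from sklearn.linear_model import LogisticRegression
>>> iris = datasets.load_iris()
>>> clf = LogisticRegression().fit(iris.data, iris.target)
>>> # store only the raw parameter arrays, independent of pickle
>>> np.savez('params.npz', coef=clf.coef_, intercept=clf.intercept_,
...          classes=clf.classes_)
>>> # on the target machine: rebuild an estimator and restore the arrays
>>> params = np.load('params.npz')
>>> clf2 = LogisticRegression()
>>> clf2.coef_ = params['coef']
>>> clf2.intercept_ = params['intercept']
>>> clf2.classes_ = params['classes']
>>> clf2.predict(iris.data[:1])
array([0])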

Pickling/unpickling has the disadvantage that it only works across matching Python versions (major, and possibly also minor, versions) and matching scikit-learn and joblib versions.
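One common mitigation (an assumption on my part, not something from the answers above) is to record the relevant versions next to the dumped model so a mismatch can at least be detected at load time; the file name my_model.versions.json is just an example:

>>> import sys, json
>>> import sklearn
>>> from sklearn.externals import joblib
>>> meta = {'python': sys.version.split()[0],
...         'sklearn': sklearn.__version__,
...         'joblib': joblib.__version__}
>>> with open('my_model.versions.json', 'w') as f:
...     json.dump(meta, f)
>>> # at load time, read this file back and warn/refuse if the
>>> # currently installed versions differ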

There are alternative descriptive output formats for machine learning models, such as those developed by the Data Mining Group: the Predictive Model Markup Language (PMML) and the Portable Format for Analytics (PFA). Of the two, PMML is much better supported.

So you have the option of saving a model from scikit-learn as PMML (for example using sklearn2pmml), and then deploying and running it in Java, Spark, or Hive using JPMML (of course there are more choices); see the sketch below.
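A rough sketch of the export step (the pipeline layout and the file name model.pmml are illustrative; sklearn2pmml also needs a Java runtime installed on the exporting machine):

>>> from sklearn import datasets
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn2pmml import sklearn2pmml
>>> from sklearn2pmml.pipeline import PMMLPipeline
>>> iris = datasets.load_iris()
>>> # wrap the estimator in a PMMLPipeline so it can be converted
>>> pipeline = PMMLPipeline([("classifier", DecisionTreeClassifier())])
>>> pipeline.fit(iris.data, iris.target)
>>> # writes an XML PMML document that JPMML can load on the JVM side
>>> sklearn2pmml(pipeline, "model.pmml")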

Section 3.4, Model persistence, in the scikit-learn documentation covers pretty much everything.

In addition to sklearn.externals.joblib, which ogrisel pointed to, it shows how to use the regular pickle module:

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
  kernel='rbf', max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])
array([0])
>>> y[0]
0

and gives a few warnings, such as that models saved with one version of scikit-learn might not load in another version.
