
Python scikit-learn: exporting trained classifier

I am using a DBN (deep belief network) from nolearn, which is based on scikit-learn.

I have already built a network that classifies my data very well, and now I want to export the model for deployment, but I don't know how (currently I retrain the DBN every time I want to predict something). In MATLAB I would just export the weight matrix and import it on another machine.

Does anyone know how to export the model or its weight matrix so that it can be imported without training the whole model again?

You can use:

>>> import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
>>> joblib.dump(clf, 'my_model.pkl', compress=9)

And then later, on the prediction server:

>>> import joblib
>>> model_clone = joblib.load('my_model.pkl')

This is basically a Python pickle with optimized handling of large numpy arrays. It has the same limitations as a regular pickle with respect to code changes: if the class structure of the pickled object changes, you might no longer be able to unpickle it with newer versions of nolearn or scikit-learn.

If you want a robust long-term way of storing your model parameters, you might need to write your own IO layer (e.g. using binary serialization tools such as Protocol Buffers or Avro, or an inefficient yet portable text / JSON / XML representation such as PMML).
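As a rough illustration of such a hand-written IO layer, here is a minimal sketch that stores the parameters of a linear classifier as JSON and predicts from them in plain Python. Everything in it is an assumption made for the example: the function names `export_params`, `load_params` and `predict`, the `format_version` field, and the weight values are hypothetical, and a DBN would need one weight matrix per layer plus its activation functions.

```python
import json


def export_params(weights, bias, classes):
    """Serialize plain-Python model parameters to a JSON string."""
    return json.dumps({
        "format_version": 1,  # hypothetical field, lets the loader reject unknown layouts
        "classes": list(classes),
        "weights": [[float(w) for w in row] for row in weights],
        "bias": [float(b) for b in bias],
    })


def load_params(text):
    """Parse the JSON string back into a parameter dict."""
    params = json.loads(text)
    if params["format_version"] != 1:
        raise ValueError("unsupported format version")
    return params


def predict(params, x):
    """Return the class whose linear score w . x + b is largest."""
    scores = [
        sum(w_i * x_i for w_i, x_i in zip(row, x)) + b
        for row, b in zip(params["weights"], params["bias"])
    ]
    best = max(range(len(scores)), key=scores.__getitem__)
    return params["classes"][best]


# Made-up parameters for a 2-class, 3-feature linear classifier.
weights = [[0.2, -0.5, 1.0], [-0.3, 0.8, -0.1]]
bias = [0.1, -0.2]
text = export_params(weights, bias, classes=["cat", "dog"])

# On the prediction server only the JSON text is needed.
restored = load_params(text)
label = predict(restored, [1.0, 2.0, 3.0])
```

The loader depends only on the JSON layout, not on any class definition, so it keeps working when nolearn or scikit-learn change internally. The cost is that the prediction logic has to be reimplemented and kept in sync by hand.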

Pickling/unpickling has the disadvantage that it only works with matching Python versions (major and possibly also minor) and matching sklearn and joblib library versions.
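One way to make such a mismatch fail loudly rather than obscurely is to record the versions next to the pickle and compare them at load time. The sketch below is only an illustration under assumptions: the helper names `dump_with_versions` and `load_with_versions` are hypothetical, a plain dict stands in for the fitted model, and the version strings are made up. In real use they could be taken from `sklearn.__version__` and `joblib.__version__`.

```python
import pickle
import sys


def dump_with_versions(model, library_versions):
    """Pickle a model together with the versions it was saved under."""
    bundle = {
        "python": list(sys.version_info[:2]),
        "libraries": dict(library_versions),
        # Nested pickle: the metadata stays readable even if the model
        # itself can no longer be unpickled.
        "model_bytes": pickle.dumps(model),
    }
    return pickle.dumps(bundle)


def load_with_versions(blob, library_versions):
    """Unpickle the model only if the recorded versions match."""
    bundle = pickle.loads(blob)
    current_python = list(sys.version_info[:2])
    if bundle["python"] != current_python:
        raise RuntimeError(
            f"saved under Python {bundle['python']}, running {current_python}"
        )
    if bundle["libraries"] != dict(library_versions):
        raise RuntimeError(
            f"saved with {bundle['libraries']}, running {dict(library_versions)}"
        )
    return pickle.loads(bundle["model_bytes"])


# A plain dict stands in for a fitted model; the version strings are made up.
saved_under = {"scikit-learn": "1.0.0", "joblib": "1.1.0"}
blob = dump_with_versions({"weights": [0.2, -0.5]}, saved_under)

# Same versions: the model loads normally.
model = load_with_versions(blob, saved_under)

# Different library version: the loader refuses before touching the model.
try:
    load_with_versions(blob, {"scikit-learn": "1.1.0", "joblib": "1.1.0"})
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
```

An exact-match check like this is deliberately strict. Whether a looser rule is safe (for example, allowing patch-level differences) depends on the libraries involved.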

There are alternative descriptive output formats for machine learning models, such as those developed by the Data Mining Group: the Predictive Model Markup Language (PMML) and the Portable Format for Analytics (PFA). Of the two, PMML is much better supported.

So you have the option of saving a scikit-learn model to PMML (for example using sklearn2pmml) and then deploying and running it in Java, Spark, or Hive using jpmml (of course, you have more choices).

Section 3.4, Model persistence, of the scikit-learn documentation covers pretty much everything.

In addition to the joblib approach that ogrisel pointed to, it shows how to use the regular pickle package:

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)
SVC()

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])  # predict expects a 2D array, so slice instead of indexing
array([0])
>>> y[0]
0

and it gives a few warnings, for example that models saved with one version of scikit-learn might not load in another version.

