繁体   English   中英

scikit-learn模型持久性:pickle vs pmml vs ...?

[英]scikit-learn model persistence: pickle vs pmml vs …?

我构建了一个scikit-learn模型,我想在每日python cron作业中重用( 注意 :没有涉及其他平台 - 没有R,没有Java和c)。

腌制它(实际上,我腌制了我自己的对象,其中一个字段是GradientBoostingClassifier ),我在cron作业中取消它。 到目前为止一直这么好(并且已经在Scikit-Learn中将保存分类器讨论到磁盘在Scikit-Learn中进行模型持久性讨论了吗? )。

但是,我升级了sklearn ,现在我收到了这些警告:

.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator DecisionTreeRegressor from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator PriorProbabilityEstimator from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator GradientBoostingClassifier from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)

现在我该怎么做?

  • 我可以降级到0.18.1并坚持下去,直到我准备重建模型。 由于各种原因,我觉得这是不可接受的。

  • 我可以取消文件并重新腌制它。 这与0.18.2一起工作,但以0.19打破 NFG。 joblib看起来并没有更好。

  • 我希望我能以与版本无关的ASCII格式(例如,JSON或XML)保存数据。 这显然是最佳解决方案,但似乎没有办法做到这一点(另请参阅Sklearn - 没有pkl文件的模型持久性 )。

  • 我可以保存模型PMML ,但它的支持是不冷不热,充其量:我可以用sklearn2pmml 保存模型(尽管不容易),和augustus / lightpmmlpredictor 申请 (虽然没有加载)的模型。 然而,这些都不是提供给pip直接,从而使部署的噩梦。 此外, augustuslightpmmlpredictor项目似乎已经死了。 将PMML模型导入Python(Scikit-learn) - 不。

  • 上述变体:使用sklearn2pmml保存PMML,并使用openscoring进行评分。 需要与外部进程连接。 育。

建议?

不同版本的scikit-learn的模型持久性通常是不可能的。 原因很明显:你用一个定义来挑选Class1 ,并希望用另一个定义将它解开为Class2

您可以:

  • 仍然试着坚持一个版本的sklearn。
  • 忽略警告并希望对Class1有效的方法也适用于Class2
  • 编写自己的类,可以序列化您的GradientBoostingClassifier并从这个序列化的表单中恢复它,并希望它比pickle更好。

我举了一个例子,说明如何将单个DecisionTreeRegressor转换为纯粹的list-and-dict格式,完全兼容JSON,并将其恢复。

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_classification

### Code to serialize and deserialize trees

LEAF_ATTRIBUTES = ['children_left', 'children_right', 'threshold', 'value', 'feature', 'impurity', 'weighted_n_node_samples']
TREE_ATTRIBUTES = ['n_classes_', 'n_features_', 'n_outputs_']

def serialize_tree(tree):
    """ Convert a sklearn.tree.DecisionTreeRegressor into a json-compatible format """
    encoded = {
        'nodes': {},
        'tree': {},
        'n_leaves': len(tree.tree_.threshold),
        'params': tree.get_params()
    }
    for attr in LEAF_ATTRIBUTES:
        encoded['nodes'][attr] = getattr(tree.tree_, attr).tolist()
    for attr in TREE_ATTRIBUTES:
        encoded['tree'][attr] = getattr(tree, attr)
    return encoded

def deserialize_tree(encoded):
    """ Restore a sklearn.tree.DecisionTreeRegressor from a json-compatible format """
    x = np.arange(encoded['n_leaves'])
    tree = DecisionTreeRegressor().fit(x.reshape((-1,1)), x)
    tree.set_params(**encoded['params'])
    for attr in LEAF_ATTRIBUTES:
        for i in range(encoded['n_leaves']):
            getattr(tree.tree_, attr)[i] = encoded['nodes'][attr][i]
    for attr in TREE_ATTRIBUTES:
        setattr(tree, attr, encoded['tree'][attr])
    return tree

## test the code

X, y = make_classification(n_classes=3, n_informative=10)
tree = DecisionTreeRegressor().fit(X, y)
encoded = serialize_tree(tree)
decoded = deserialize_tree(encoded)
assert (decoded.predict(X)==tree.predict(X)).all()

有了这个,你可以继续序列化和反序列化整个GradientBoostingClassifier

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble.gradient_boosting import PriorProbabilityEstimator

def serialize_gbc(clf):
    encoded = {
        'classes_': clf.classes_.tolist(),
        'max_features_': clf.max_features_, 
        'n_classes_': clf.n_classes_,
        'n_features_': clf.n_features_,
        'train_score_': clf.train_score_.tolist(),
        'params': clf.get_params(),
        'estimators_shape': list(clf.estimators_.shape),
        'estimators': [],
        'priors':clf.init_.priors.tolist()
    }
    for tree in clf.estimators_.reshape((-1,)):
        encoded['estimators'].append(serialize_tree(tree))
    return encoded

def deserialize_gbc(encoded):
    x = np.array(encoded['classes_'])
    clf = GradientBoostingClassifier(**encoded['params']).fit(x.reshape(-1, 1), x)
    trees = [deserialize_tree(tree) for tree in encoded['estimators']]
    clf.estimators_ = np.array(trees).reshape(encoded['estimators_shape'])
    clf.init_ = PriorProbabilityEstimator()
    clf.init_.priors = np.array(encoded['priors'])
    clf.classes_ = np.array(encoded['classes_'])
    clf.train_score_ = np.array(encoded['train_score_'])
    clf.max_features_ = encoded['max_features_']
    clf.n_classes_ = encoded['n_classes_']
    clf.n_features_ = encoded['n_features_']
    return clf

# test on the same problem
clf = GradientBoostingClassifier()
clf.fit(X, y);
encoded = serialize_gbc(clf)
decoded = deserialize_gbc(encoded)
assert (decoded.predict(X) == clf.predict(X)).all()

这适用于scikit-learn v0.19,但不要问我下一个版本会破坏这段代码。 我既不是先知也不是sklearn的开发者。

如果你想完全独立于sklearn的新版本,最安全的事情是编写一个遍历序列化树并进行预测的函数,而不是重新创建一个sklearn树。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM