
Save sklearn classifier fit by cross_val

I've got a classifier that I'm fitting using cross_val and getting good results. Essentially all I'm doing is:

clf = RandomForestClassifier(class_weight="balanced")
scores = cross_val_score(clf, data, target, cv=8)
predict_RF = cross_val_predict(clf, data, target, cv=8)

from sklearn.externals import joblib
joblib.dump(clf, 'churnModel.pkl')

Essentially, what I want to do is take the model that's being fit by cross_val and export it with joblib. However, when I try to load it in a separate project I get:

sklearn.exceptions.NotFittedError: Estimator not fitted, call `fit` before exploiting the model.

So I'm guessing cross_val is not actually saving the fit to my clf? How do I persist the model fit that cross_val is generating?

juanpa.arrivillaga is right. I'm afraid you have to do it manually, but scikit-learn makes that quite easy. cross_val_score fits copies of the estimator internally and does not return them, so your clf itself is never fitted. The code below collects the trained models in a list (i.e. clf_models):

from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from copy import deepcopy

kf = StratifiedKFold(n_splits=8)
clf = RandomForestClassifier(class_weight="balanced")
clf_models = []

# X_data and y_data must be NumPy arrays with rows in the same order
# (for pandas objects, index with .iloc[train_index] instead)
for train_index, test_index in kf.split(X_data, y_data):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X_data[train_index], X_data[test_index]
    y_train, y_test = y_data[train_index], y_data[test_index]
    tmp_clf = deepcopy(clf)
    tmp_clf.fit(X_train, y_train)

    print("Got a score of {}".format(tmp_clf.score(X_test, y_test)))
    clf_models.append(tmp_clf)
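
To get back to the original goal of persisting a model, here is a minimal sketch of how the fitted models could be saved and reloaded with joblib. It rests on several assumptions that are not in the question: the data is a synthetic stand-in from make_classification, the file names and temporary directory are placeholders, and refitting one final model on all the data is only one common choice. It also uses sklearn.base.clone in place of deepcopy, and imports joblib directly, since recent scikit-learn releases no longer ship sklearn.externals.joblib.

```python
import os
import tempfile

import joblib
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in for the real data (assumption for illustration).
X_data, y_data = make_classification(n_samples=200, n_features=10, random_state=0)

clf = RandomForestClassifier(n_estimators=20, class_weight="balanced", random_state=0)
kf = StratifiedKFold(n_splits=4)

# One fitted model per fold, as in the loop above.
clf_models = []
for train_index, test_index in kf.split(X_data, y_data):
    fold_clf = clone(clf)  # unfitted copy with the same hyperparameters
    fold_clf.fit(X_data[train_index], y_data[train_index])
    clf_models.append(fold_clf)

# A single model refitted on all of the data, for use in another project.
final_clf = clone(clf).fit(X_data, y_data)

# Placeholder paths; any writable location works.
model_dir = tempfile.mkdtemp()
fold_path = os.path.join(model_dir, "churnFoldModels.pkl")
final_path = os.path.join(model_dir, "churnModel.pkl")
joblib.dump(clf_models, fold_path)
joblib.dump(final_clf, final_path)

# In the other project: load the fitted model and predict.
loaded_folds = joblib.load(fold_path)
loaded_final = joblib.load(final_path)
predictions = loaded_final.predict(X_data[:5])
```

Because the dumped objects were fitted before saving, loading them should not raise NotFittedError. The project that loads them should use a compatible scikit-learn version.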

Edit: following juanpa.arrivillaga's advice, StratifiedKFold is the better choice, so it is used above. The setup here is just for demonstration purposes.
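
As an alternative to writing the loop by hand, scikit-learn's cross_validate accepts return_estimator=True (available since version 0.20), which hands back the estimators fitted on each split. This is a sketch, not part of the original answer, and it again assumes synthetic data from make_classification in place of the real data.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate

# Synthetic stand-in for the real data (assumption for illustration).
X_data, y_data = make_classification(n_samples=200, n_features=10, random_state=0)

clf = RandomForestClassifier(n_estimators=20, class_weight="balanced", random_state=0)

cv_results = cross_validate(
    clf,
    X_data,
    y_data,
    cv=StratifiedKFold(n_splits=4),
    return_estimator=True,  # keep the estimator fitted on each training split
)

fold_scores = cv_results["test_score"]  # one score per fold
fold_models = cv_results["estimator"]   # one fitted classifier per fold

# The estimator passed in is cloned internally, so it stays unfitted.
original_is_fitted = hasattr(clf, "estimators_")
```

Each entry of fold_models is already fitted and can be passed to joblib.dump, whereas clf itself still has to be fitted before it can be saved and used.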
