
Stratified Cross Validation with StratifiedKFold after cross_validation.StratifiedKFold deprecation

I am following some example scripts from 3 years ago and ran into a function definition that uses a deprecated function (cross_validation.StratifiedKFold).

Here is the original code snippet from 3 years ago:

def stratified_cv(X, y, clf_class, shuffle=True, n_folds=10, **kwargs):
    stratified_k_fold = cross_validation.StratifiedKFold(y, n_folds=n_folds, shuffle=shuffle)
    y_pred = y.copy()
    # ii -> train
    # jj -> test indices
    for ii, jj in stratified_k_fold: 
        X_train, X_test = X[ii], X[jj]
        y_train = y[ii]
        clf = clf_class(**kwargs)
        clf.fit(X_train,y_train)
        y_pred[jj] = clf.predict(X_test)
    return y_pred

I have tried to update it by following the documentation for sklearn.model_selection.StratifiedKFold ( https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html ), and this is what I have so far:

## Attempt to modernize with StratifiedKFold from sklearn.model_selection
def stratified_cv(X, y, clf_class, shuffle=True, n_splits=10, **kwargs):
    stratified_k_fold = StratifiedKFold(n_splits=n_splits)
    y_pred = y.copy()
    # ii -> train
    # jj -> test indices
    for ii, jj in stratified_k_fold: 
        X_train, X_test = X[ii], X[jj]
        y_train = y[ii]
        clf = clf_class(**kwargs)
        clf.fit(X_train,y_train)
        y_pred[jj] = clf.predict(X_test)
    return y_pred

I then tried to run the following block and got the error below:

print('Gradient Boosting Classifier:  {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, ensemble.GradientBoostingClassifier))))
print('Support vector machine(SVM):   {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, svm.SVC))))
print('Random Forest Classifier:      {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, ensemble.RandomForestClassifier))))
print('K Nearest Neighbor Classifier: {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, neighbors.KNeighborsClassifier))))
print('Logistic Regression:           {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, linear_model.LogisticRegression))))

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-122-a61be22f8ca9> in <module>
----> 1 print('Gradient Boosting Classifier:  {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, ensemble.GradientBoostingClassifier))))
      2 print('Support vector machine(SVM):   {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, svm.SVC))))
      3 print('Random Forest Classifier:      {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, ensemble.RandomForestClassifier))))
      4 print('K Nearest Neighbor Classifier: {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, neighbors.KNeighborsClassifier))))
      5 print('Logistic Regression:           {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, linear_model.LogisticRegression))))

<ipython-input-121-e373d74b2cca> in stratified_cv(X, y, clf_class, shuffle, n_splits, **kwargs)
      5     # ii -> train
      6     # jj -> test indices
----> 7     for ii, jj in stratified_k_fold:
      8         X_train, X_test = X[ii], X[jj]
      9         y_train = y[ii]

TypeError: 'StratifiedKFold' object is not iterable

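You need to call the split method of StratifiedKFold to split your data. In the new API the splitter object is created without the labels and is not iterable itself; split(X, y) yields the train and test indices. Without changing your code much, the following should work: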

from sklearn.model_selection import StratifiedKFold
from sklearn import datasets
from sklearn import metrics
from sklearn import svm

iris = datasets.load_iris()
X = iris.data
y = iris.target

def stratified_cv(X, y, clf_class, shuffle=True, n_splits=10, **kwargs):
    # The splitter is configured without data; pass shuffle through as the old code did.
    stratified_k_fold = StratifiedKFold(n_splits=n_splits, shuffle=shuffle)
    y_pred = y.copy()

    # ii -> train indices, jj -> test indices
    for ii, jj in stratified_k_fold.split(X, y):
        X_train, X_test = X[ii], X[jj]
        y_train = y[ii]
        clf = clf_class(**kwargs)
        clf.fit(X_train, y_train)
        y_pred[jj] = clf.predict(X_test)
    return y_pred

print('Support vector machine(SVM):   {:.2f}'.format(metrics.accuracy_score(y, stratified_cv(X, y, svm.SVC))))
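
To see what split(X, y) actually yields, here is a minimal sketch on the same iris data. The names fold_class_counts and all_test_indices, and the choice of random_state=0, are only illustrative assumptions and not part of the original code. It records the class counts of each test fold and checks that every sample lands in exactly one test fold:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import StratifiedKFold

iris = datasets.load_iris()
X, y = iris.data, iris.target

# random_state is fixed here only to make the shuffled folds reproducible (assumption).
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

fold_class_counts = []   # per-fold count of each class in the test part
all_test_indices = []    # every test index seen across all folds

# split(X, y) yields one (train indices, test indices) pair per fold
for train_idx, test_idx in skf.split(X, y):
    fold_class_counts.append(np.bincount(y[test_idx], minlength=3))
    all_test_indices.extend(test_idx)

for i, counts in enumerate(fold_class_counts):
    print('fold', i, 'test class counts:', counts)
print('samples used as test exactly once:', len(set(all_test_indices)) == len(y))
```

Since iris has 50 samples per class, each of the 10 test folds should hold the same number of samples from every class, which is the property that makes y_pred in stratified_cv a complete out-of-fold prediction.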
