
Predicted values of each fold in K-Fold Cross Validation in sklearn

I have performed 10-fold cross-validation on my dataset using sklearn in Python:

result = cross_val_score(best_svr, X, y, cv=10, scoring='r2')
print(result.mean())

I was able to get the mean r2 score as the final result. Is there a way to print the predicted values for each fold (in this case, 10 sets of values)?

I believe you are looking for the cross_val_predict function.
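
For the regression setup in the question, a minimal sketch might look like the following. The data here is synthetic and the SVR settings are made up for illustration; they only stand in for the asker's X, y and best_svr, which are not shown in the post.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_predict
from sklearn.svm import SVR

# Synthetic stand-in for the asker's data (assumption: the real X and y are not shown).
X, y = make_regression(n_samples=100, n_features=5, noise=10.0, random_state=0)

# Hypothetical stand-in for best_svr; the asker's tuned parameters are unknown.
best_svr = SVR(kernel="linear", C=10.0)

# Every sample is predicted by a model fitted on the folds that do not contain it.
y_pred = cross_val_predict(best_svr, X, y, cv=10)

# One out-of-fold prediction per sample, in the same order as y.
print(len(y_pred) == len(y))
print(y_pred[:5])
```

Note that the result is a single array aligned with y, so it does not say which fold produced which prediction.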

A late answer, just to add to @jh314's: cross_val_predict does return all the predictions, but it does not tell us which fold each prediction belongs to. To recover that, we pass the fold splitter as cv instead of an integer:

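import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict, StratifiedKFold

# Iris data: the four numeric columns are features, the target is "is versicolor"
iris = sns.load_dataset('iris')
X = iris.iloc[:, :4]
y = (iris['species'] == "versicolor").astype('int')

rfc = RandomForestClassifier()
# A fixed random_state makes the shuffled splits reproducible across calls to split()
skf = StratifiedKFold(n_splits=10, random_state=111, shuffle=True)

# Out-of-fold predictions, one per row of X, in the original row order
pred = cross_val_predict(rfc, X, y, cv=skf)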

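Now we iterate over the splits of the StratifiedKFold object and pull out the predictions for each fold's test indices. This works because random_state is fixed, so split() yields the same folds that cross_val_predict used: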

fold_pred = [pred[test_idx] for _, test_idx in skf.split(X, y)]
fold_pred

[array([0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0]),
 array([0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]),
 array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0])]
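
If it is more convenient to keep one flat array, the same idea can be turned into a fold label per sample. This is only a sketch: it uses synthetic data from make_classification instead of iris (so it runs without seaborn), and the name fold_id is my own.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic binary problem standing in for the iris example (assumption).
X, y = make_classification(n_samples=150, n_features=4, n_informative=2,
                           n_redundant=0, random_state=0)

rfc = RandomForestClassifier(random_state=0)
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=111)
pred = cross_val_predict(rfc, X, y, cv=skf)

# fold_id[i] is the index of the fold in which sample i was in the test set.
fold_id = np.empty(len(y), dtype=int)
for fold, (_, test_idx) in enumerate(skf.split(X, y)):
    fold_id[test_idx] = fold

# Per-fold summary: number of test samples and accuracy of the out-of-fold predictions.
for fold in range(skf.get_n_splits()):
    mask = fold_id == fold
    acc = (pred[mask] == y[mask]).mean()
    print(f"fold {fold}: {mask.sum()} samples, accuracy {acc:.2f}")
```

With fold_id in hand, pred[fold_id == k] gives the predictions of fold k. This relies on the integer random_state; with shuffle=True and random_state=None, a second call to split() would produce different folds.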

To print the mean r2 score and the cross-validated predictions for several fold counts (k from 2 to 9):

for k in range(2,10):
    result = cross_val_score(best_svr, X, y, cv=k, scoring='r2')
    print(k, result.mean())
    y_pred = cross_val_predict(best_svr, X, y, cv=k)
    print(y_pred)
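
Note that the loop above prints one combined prediction array per k, not one array per fold. To see the predictions fold by fold, one option is to run the folds by hand. The sketch below assumes synthetic data and a made-up SVR in place of the asker's best_svr; it uses an unshuffled KFold, which is what an integer cv resolves to for a regressor.

```python
from sklearn.base import clone
from sklearn.datasets import make_regression
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold
from sklearn.svm import SVR

# Synthetic stand-in for the asker's data and model (assumptions).
X, y = make_regression(n_samples=100, n_features=5, noise=10.0, random_state=0)
best_svr = SVR(kernel="linear", C=10.0)

kf = KFold(n_splits=10)  # unshuffled, like cv=10 for a regressor
fold_results = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    # Fit a fresh copy on the training folds, then predict the held-out fold.
    model = clone(best_svr)
    model.fit(X[train_idx], y[train_idx])
    y_pred = model.predict(X[test_idx])
    fold_results.append((test_idx, y_pred))
    print(f"fold {fold}: r2 = {r2_score(y[test_idx], y_pred):.3f}")
    print(y_pred)
```

Each entry of fold_results pairs the test indices of a fold with its predictions, so the values can be matched back to rows of X.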


 