
Nested cross-validation: How does cross_validate handle GridSearchCV as its input estimator?

The following code combines cross_validate with GridSearchCV to perform a nested cross-validation for an SVC on the iris dataset.

(Modified example from the following documentation page: https://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html#sphx-glr-auto-examples-model-selection-plot-nested-cross-validation-iris-py.)


from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, cross_validate, KFold
import numpy as np
np.set_printoptions(precision=2)

# Load the dataset
iris = load_iris()
X_iris = iris.data
y_iris = iris.target

# Set up possible values of parameters to optimize over
p_grid = {"C": [1, 10],
          "gamma": [.01, .1]}

# We will use a Support Vector Classifier with "rbf" kernel
svm = SVC(kernel="rbf")

# Choose techniques for the inner and outer loop of nested cross-validation
inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)

# Perform nested cross-validation
# Note: the iid parameter was removed in scikit-learn 0.24, so it is omitted here
clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)
clf.fit(X_iris, y_iris)
best_estimator = clf.best_estimator_

cv_dic = cross_validate(clf, X_iris, y_iris, cv=outer_cv, scoring=['accuracy'], return_estimator=False, return_train_score=True)
mean_val_score = cv_dic['test_accuracy'].mean()

print('nested_train_scores: ', cv_dic['train_accuracy'])
print('nested_val_scores:   ', cv_dic['test_accuracy'])
print('mean score:            {0:.2f}'.format(mean_val_score))

cross_validate splits the data set into a training set and a test set in each fold. The input estimator is then trained on the training set associated with that fold. The estimator passed in here is clf, a parameterized GridSearchCV estimator, i.e. an estimator that cross-validates itself again.
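
To make this concrete, here is a minimal sketch of what cross_validate effectively does with clf in each outer fold (not the actual scikit-learn implementation; the index variables are illustrative):

from sklearn.base import clone

train_scores, test_scores = [], []
for train_idx, test_idx in outer_cv.split(X_iris, y_iris):
    # cross_validate works on a fresh, unfitted clone of clf for every fold
    fold_clf = clone(clf)
    # Fitting the clone runs the entire inner grid-search CV
    # on this fold's training portion only
    fold_clf.fit(X_iris[train_idx], y_iris[train_idx])
    train_scores.append(fold_clf.score(X_iris[train_idx], y_iris[train_idx]))
    test_scores.append(fold_clf.score(X_iris[test_idx], y_iris[test_idx]))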

I have three questions about the whole thing:

  1. If clf is used as the estimator for cross_validate, does it (in the course of the GridSearchCV cross-validation) split the above-mentioned training set into a sub-training set and a validation set in order to determine the best hyperparameter combination?
  2. Out of all models tested via GridSearchCV, does cross_validate validate only the model stored in the best_estimator_ attribute?
  3. Does cross_validate train a model at all (and if so, why?), or is the model stored in best_estimator_ validated directly via the test set?

To make it clearer what the questions mean, here is an illustration of how I currently imagine the double cross-validation.

[Image: illustration of the double cross-validation described above]

If clf is used as the estimator for cross_validate, does it split the above-mentioned training set into a sub-training set and a validation set in order to determine the best hyperparameter combination?

Yes. As you can see in the scikit-learn source at line 230 (and specifically at line 240), the training set is again split into a sub-training set and a validation set.

Update: Yes, when you pass the GridSearchCV classifier into cross_validate, it will again split the training set into a train and a test set. Here is a link describing this in more detail. Your diagram and assumption are correct.
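
A quick way to see that inner split in action is the following sketch, reusing the objects defined in the question and looking at just the first outer fold:

# Take the first outer training fold...
outer_train_idx, outer_test_idx = next(outer_cv.split(X_iris, y_iris))
X_outer_train = X_iris[outer_train_idx]

# ...and let inner_cv divide it further into sub-training and validation
# parts, exactly as GridSearchCV does internally during fitting
for sub_train_idx, val_idx in inner_cv.split(X_outer_train):
    print(len(sub_train_idx), 'sub-training samples /',
          len(val_idx), 'validation samples')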

Out of all models tested via GridSearchCV, does cross_validate train and validate only the model stored in the best_estimator_ attribute?

Yes. As you can see from the answers here and here, GridSearchCV returns the best estimator in best_estimator_ (since the refit parameter is True by default in your case). However, this best estimator still has to be trained again.
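
For illustration, a sketch of what refit=True means, using the clf defined in the question (the printed parameter values are only an example):

clf.fit(X_iris, y_iris)
print(clf.best_params_)     # e.g. {'C': 10, 'gamma': 0.1}
print(clf.best_estimator_)  # an SVC refitted on all data passed to fit

# The refit step is roughly equivalent to doing this manually:
manual_best = SVC(kernel="rbf", **clf.best_params_).fit(X_iris, y_iris)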

Does cross_validate train a model at all (and if so, why?), or is the model stored in best_estimator_ validated directly via the test set?

As for your third and final question: yes, cross_validate trains an estimator and returns it if return_estimator is set to True (see this line in the source). This makes sense, since how else could it return the scores without training an estimator in the first place?

Update: The reason the model is trained again is that the default use case for cross_validate does not assume that you pass in the best classifier with the optimal parameters. In this case specifically, you are passing in a classifier from GridSearchCV, but if you pass in any untrained classifier, it is supposed to be trained. In other words: yes, in your case it shouldn't need to train it again, since you are already doing cross-validation via GridSearchCV and using the best estimator. However, there is no way for cross_validate to know this, so it assumes that you are passing in an un-optimized, or rather untrained, estimator; thus it has to train it again and return the scores for that.
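
You can observe this per-fold retraining yourself (a sketch, again assuming the objects from the question): set return_estimator=True and inspect the fitted GridSearchCV clone that cross_validate returns for each outer fold; the winning hyperparameters may even differ from fold to fold.

cv_dic = cross_validate(clf, X_iris, y_iris, cv=outer_cv,
                        scoring=['accuracy'], return_estimator=True)
for i, fold_estimator in enumerate(cv_dic['estimator']):
    # Each entry is a GridSearchCV fitted from scratch on one outer training fold
    print('fold', i, 'best params:', fold_estimator.best_params_)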
