
GridSearchCV example from the scikit-learn user guide raises an error

I was trying to run the grid search code exactly as given in the scikit-learn user guide, but it raises an error, which surprised me.

from sklearn.model_selection import GridSearchCV
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import cross_val_score
from sklearn.datasets import load_iris
X,y=make_moons()
calibrated_forest=CalibratedClassifierCV(base_estimator=RandomForestClassifier(n_estimators=10))
paramgrid={'base_estimator_max_depth':[2,4,6,8]}
search=GridSearchCV(calibrated_forest,paramgrid,cv=5)
search.fit(X,y)

The error message is:

ValueError: Invalid parameter base_estimator_max_depth for estimator CalibratedClassifierCV(base_estimator=RandomForestClassifier(n_estimators=10)). Check the list of available parameters with `estimator.get_params().keys()`.
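The error message's own suggestion can be followed directly: the keys of `get_params()` are exactly the names GridSearchCV accepts. A minimal sketch, assuming only scikit-learn. It also assumes that scikit-learn 1.2 renamed CalibratedClassifierCV's `base_estimator` argument to `estimator` (the old name was removed in 1.4), so it picks whichever spelling the installed version uses. `WRAPPED` and `depth_keys` are illustrative names.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier

# Assumption: scikit-learn >= 1.2 calls the wrapped model `estimator`,
# older versions call it `base_estimator`; detect which one this install uses.
WRAPPED = "estimator" if "estimator" in CalibratedClassifierCV().get_params() else "base_estimator"

calibrated_forest = CalibratedClassifierCV(**{WRAPPED: RandomForestClassifier(n_estimators=10)})

# get_params() (deep=True by default) also lists the wrapped forest's parameters,
# each prefixed with the wrapper's argument name plus a double underscore.
depth_keys = [k for k in calibrated_forest.get_params().keys() if k.endswith("max_depth")]
print(depth_keys)

# Setting a value through the prefixed name changes the forest held inside the calibrator.
calibrated_forest.set_params(**{WRAPPED + "__max_depth": 4})
print(getattr(calibrated_forest, WRAPPED).max_depth)  # 4
```

The single-underscore key `base_estimator_max_depth` never appears in that list, which is why GridSearchCV rejects it.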

I also tried the Iris dataset, and it gave the same error.

Then I used the make_moons data (X, y) with a plain random forest classifier:

clf = RandomForestClassifier(n_estimators=10, max_depth=2)
cross_val_score(clf, X, y, cv=5)

This gave the following output:

array([0.8 , 0.8 , 0.9 , 0.95, 0.95])

This looks strange, and I'm not sure what is happening or where I went wrong. Any help would be appreciated.
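The direct call works because `max_depth` is a parameter of RandomForestClassifier itself. A prefix is only needed when the forest is wrapped inside another estimator. A minimal sketch of the same grid search on the bare forest. `random_state=0` is an added assumption for reproducibility and is not in the original code.

```python
from sklearn.datasets import make_moons
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = make_moons(random_state=0)

# The forest is the top-level estimator here, so its own parameter name is used
# unchanged: no "<wrapper>__" prefix, because nothing wraps it.
search = GridSearchCV(
    RandomForestClassifier(n_estimators=10, random_state=0),
    {"max_depth": [2, 4, 6, 8]},
    cv=5,
)
search.fit(X, y)
print(search.best_params_)
```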

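Note the double underscore __ between base_estimator and the parameter name. scikit-learn addresses a parameter of a nested estimator as <argument name>__<parameter>, so the key must be base_estimator__max_depth, not base_estimator_max_depth: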

from sklearn.model_selection import GridSearchCV
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import cross_val_score
from sklearn.datasets import load_iris
X,y=make_moons()
calibrated_forest=CalibratedClassifierCV(base_estimator=RandomForestClassifier(n_estimators=10))
paramgrid={'base_estimator__max_depth':[2,4,6,8]}  # double underscore: <wrapper argument>__<inner parameter>
search=GridSearchCV(calibrated_forest,paramgrid,cv=5)
search.fit(X,y)
# Output (the fitted search object):
# GridSearchCV(cv=5,
#              estimator=CalibratedClassifierCV(base_estimator=RandomForestClassifier(n_estimators=10)),
#              param_grid={'base_estimator__max_depth': [2, 4, 6, 8]})
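The code above uses `base_estimator`, which works on older scikit-learn. Assuming the 1.2 rename of that argument to `estimator` (with `base_estimator` removed in 1.4), the grid key becomes `estimator__max_depth` on newer versions. Here is a version-tolerant sketch of the same fix. The `WRAPPED` detection is an illustrative helper rather than a scikit-learn API, and `random_state=0` is added only for reproducibility.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_moons
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Hypothetical helper: pick "estimator" (scikit-learn >= 1.2) or "base_estimator" (older).
WRAPPED = "estimator" if "estimator" in CalibratedClassifierCV().get_params() else "base_estimator"

X, y = make_moons(random_state=0)
calibrated_forest = CalibratedClassifierCV(
    **{WRAPPED: RandomForestClassifier(n_estimators=10, random_state=0)}
)

# "<wrapper argument>__<inner parameter>": the double underscore routes max_depth
# to the RandomForestClassifier held inside the calibrator.
param_grid = {f"{WRAPPED}__max_depth": [2, 4, 6, 8]}
search = GridSearchCV(calibrated_forest, param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)
```

Detecting the name keeps one script working across versions. If you only target a single scikit-learn release, hard-coding `estimator__max_depth` (or `base_estimator__max_depth` on older releases) is simpler.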
