
GridSearch over RegressorChain using Scikit-Learn?

I am currently working on a multi-output regression problem, where I am trying to predict multiple output values at once. I am aware there are standard regressors that natively support this task.

However, I would like to use a RegressorChain and tune the hyperparameters of the regressor inside it using GridSearchCV. I wrote the following code for this:

from sklearn.multioutput import RegressorChain
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from sklearn.model_selection import GridSearchCV

# set up the pipeline
pipeline = Pipeline(steps=[('scale', StandardScaler(with_mean=True, with_std=True)),
                           ('estimator', RegressorChain(SVR()))])

# set up the parameter grid
param_grid = {'estimator__estimator__C': [0.1, 1, 10, 100]}

# set up the grid search
grid = GridSearchCV(pipeline,
                    param_grid,
                    scoring='neg_mean_squared_error')
# fit model (X: feature matrix, y: 2-D array with one column per target)
grid.fit(X, y)

I tried:

param_grid = {'estimator__C': [0.1,1,10,100]}  

and:

param_grid = {'estimator__estimator__C': [0.1,1,10,100]}

Both times, I got the following ValueError:

ValueError: Invalid parameter C for estimator RegressorChain(base_estimator=SVR(C=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.1, gamma='auto_deprecated', kernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=False), cv=None, order=None, random_state=None). Check the list of available parameters with estimator.get_params().keys().
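GridSearchCV applies each candidate parameter combination through set_params(), so the failure can be reproduced without any data by calling set_params() directly. A minimal sketch; the helper try_set is hypothetical and only exists for illustration:

```python
from sklearn.multioutput import RegressorChain
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

pipeline = Pipeline(steps=[("scale", StandardScaler()),
                           ("estimator", RegressorChain(SVR()))])


def try_set(estimator, **params):
    """Hypothetical helper: return None if set_params() accepts the
    parameters, otherwise the ValueError message it raised."""
    try:
        estimator.set_params(**params)
    except ValueError as err:
        return str(err)
    return None


# "estimator__C" routes C to the RegressorChain itself, which has no
# parameter called C, so set_params() rejects it with a ValueError.
print(try_set(pipeline, estimator__C=10))
```

The message names the RegressorChain rather than the SVR, which is the hint that one more level of nesting is missing from the key.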

Does anyone know how to set up this pipeline correctly? Thank you!

As the error message suggests, print the result of RegressorChain(SVR()).get_params() and you will get:

{
    'base_estimator__C': 1.0, 
    'base_estimator__cache_size': 200, 
    'base_estimator__coef0': 0.0, 
    'base_estimator__degree': 3,
    ...
}
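The same lookup can be run on the whole pipeline, which yields the fully qualified key directly. A minimal sketch that filters the deep parameter names for the SVR's C; on the scikit-learn version used in the question, where the chain's argument is called base_estimator, it should list estimator__base_estimator__C:

```python
from sklearn.multioutput import RegressorChain
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

pipeline = Pipeline(steps=[("scale", StandardScaler()),
                           ("estimator", RegressorChain(SVR()))])

# get_params() is deep by default: nested names are joined with "__" as
# <pipeline step>__<chain parameter>__<SVR parameter>.
c_keys = [key for key in pipeline.get_params() if key.endswith("__C")]
print(c_keys)
```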

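Nested parameter names are built as <step name>__<parameter of the chain>__<parameter of SVR>. Given the pipeline you defined, where the chain is the step named 'estimator', this means you should use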

param_grid = {'estimator__base_estimator__C': [0.1, 1, 10, 100]} 

to set the candidate values of C for your SVR during the grid search.
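Putting it together, here is a runnable sketch of the corrected grid search. It assumes synthetic data from make_regression in place of the question's X and y, and it looks the C key up from get_params() instead of hard-coding it, so the script still works if the chain's constructor argument is named differently in your scikit-learn version (on the version in the question the key is estimator__base_estimator__C):

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.multioutput import RegressorChain
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# Synthetic stand-in for the question's data: 200 samples, 5 features and
# 3 targets, so y has shape (200, 3).
X, y = make_regression(n_samples=200, n_features=5, n_targets=3,
                       noise=0.1, random_state=0)

pipeline = Pipeline(steps=[("scale", StandardScaler()),
                           ("estimator", RegressorChain(SVR()))])

# Find the fully qualified name of the SVR's C rather than hard-coding it.
c_key = next(key for key in pipeline.get_params() if key.endswith("__C"))
param_grid = {c_key: [0.1, 1, 10, 100]}

grid = GridSearchCV(pipeline, param_grid,
                    scoring="neg_mean_squared_error", cv=5)
grid.fit(X, y)

print(c_key, grid.best_params_)
```

Once you know the key for your installed version, writing it out literally in param_grid, as above, is equivalent.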

