
Getting an error with random forest model using sklearn

I ran the following code to fit a random forest model. I used a Kaggle data set:

Data link: https://www.kaggle.com/arnavr10880/winedataset-eda-ml/data?select=WineQT.csv

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold,cross_val_score,GridSearchCV
from sklearn import linear_model
from sklearn.ensemble import  RandomForestRegressor
import numpy as np


data= pd.read_csv("C:/Users/Downloads/Model Test Data.csv")

y=data.loc[: ,["y"]]
x=data.iloc[:,1:]

x_train, x_test,y_train, y_test = train_test_split(x,y)


rf=RandomForestRegressor()


params = {
    'n_estimators'      : [300,500],
    'max_depth'         : np.array([8,9,12]),
    'random_state'      : [0],
    
}

scoring = ["neg_mean_absolute_error","neg_mean_squared_error"]

for score in scoring:
    print("score %s" % score)
    clf= GridSearchCV(rf,param_grid=params,scoring="%s" %score,verbose=False)
    clf.fit(x_train,y_train)
    print("Best parameters:")
    print(clf.best_params_)
    means=clf.cv_results_["mean_test_score"]
    stds=clf.cv_results_["std_test_score"]

    for mean,sd,params in zip(means,stds, clf.cv_results_["params"]):
        print("%0.3f (+/-%0.3f) for %r" %(mean,2*sd,params) )

However, I got the following error:

 Parameter grid for parameter (max_depth) needs to be a list or numpy array,
 but got (<class 'int'>). Single values need to be wrapped in a list with one element.

Could anyone help me to fix this?

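When you run your example, the first score in the for loop fits and prints just fine; the error only appears on the second pass. Examining the params variable at that point shows {'max_depth': 12, 'n_estimators': 500, 'random_state': 0}, so you've accidentally overwritten the parameter grid with one specific parameter combination. GridSearchCV then receives max_depth as a plain int instead of a list or array, which is exactly what the error message complains about.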
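The rebinding can be reproduced without scikit-learn. This is a minimal sketch: `param_grid` and `tried` are made-up stand-ins for the question's grid and for `clf.cv_results_["params"]`, not values taken from a real run.

```python
# Hypothetical stand-ins for the question's grid and for clf.cv_results_["params"].
param_grid = {"n_estimators": [300, 500], "max_depth": [8, 9, 12]}
tried = [
    {"max_depth": 8, "n_estimators": 300},
    {"max_depth": 12, "n_estimators": 500},
]

params = param_grid
for params in tried:  # rebinds the name "params" on every iteration
    pass

# After the loop, "params" is the last combination, not the grid.
print(params)
print(params is param_grid)
```

A Python for loop does not create a new scope, so the loop variable keeps its last value after the loop ends and replaces whatever the name pointed to before.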

Looking again at your code, the overwrite happens in the printing loop at the end of the outer loop:

    for mean,sd,params in zip(means,stds, clf.cv_results_["params"]):   # <-- "params" is rebound here
        print("%0.3f (+/-%0.3f) for %r" %(mean,2*sd,params) )

so just use a different variable name in that inner loop (p, for example), and the params grid stays intact for the next score.
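One way the corrected loop might look is sketched below. Several things here are assumptions made so the sketch is self-contained and quick to run: synthetic data from `make_regression` replaces the CSV, the grid is smaller than the one in the question, `cv=3` is set explicitly, and the names `param_grid`, `combo`, and `best` are illustrative.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, train_test_split

# Synthetic data stands in for the question's CSV (an assumption for this sketch).
X, y = make_regression(n_samples=120, n_features=5, noise=0.1, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Smaller grid than in the question so the sketch runs quickly.
param_grid = {
    "n_estimators": [10, 20],
    "max_depth": [2, 4],
    "random_state": [0],
}

scoring = ["neg_mean_absolute_error", "neg_mean_squared_error"]
best = {}  # best parameter combination found for each scoring metric

for score in scoring:
    print("score %s" % score)
    clf = GridSearchCV(RandomForestRegressor(), param_grid=param_grid,
                       scoring=score, cv=3)
    clf.fit(x_train, y_train)
    best[score] = clf.best_params_
    print("Best parameters:", clf.best_params_)
    results = clf.cv_results_
    # "combo" is used as the loop variable, so "param_grid" is never rebound.
    for mean, sd, combo in zip(results["mean_test_score"],
                               results["std_test_score"],
                               results["params"]):
        print("%0.3f (+/-%0.3f) for %r" % (mean, 2 * sd, combo))
```

Because the grid and the per-combination loop variable now have different names, the second pass through `scoring` receives the same dictionary of lists as the first.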
