
How to use GridSearchCV, cross_val_score and a model

I need to find the best hyperparameters for an ANN and then run predictions with the best model. I use KerasRegressor. I have found conflicting examples and advice. Please help me understand the right sequence of steps and which parameters to use when.

  1. I split my data into Train and Test datasets.
  2. I look for the best hyperparameters using GridSearchCV on the Train dataset: GridSearchCV.fit(X_Train, Y_Train).
  3. I take GridSearchCV.best_estimator_ and use it in cross_val_score on the Test dataset, i.e. cross_val_score(model.best_estimator_, X_Test, Y_Test, scoring='r2').
    • I'm not sure whether I need this step. In theory, it should show r2 scores similar to those GridSearchCV reported for this best_estimator_, shouldn't it?
  4. I use model.best_estimator_.predict(X_Test) on the Test data to predict the results, i.e. I pass the best_estimator_ from GridSearchCV to run the actual prediction.
    • Is this correct? Do I need to fit model.best_estimator_ on the Train data again before predicting, or does it keep all the weights found during GridSearchCV? Do I need to save the weights to be able to reuse them later?

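Usually, when you run GridSearchCV on your training set, you get back an object that contains the best model, by default retrained on the whole training set with the best parameters (here estimator is your model, e.g. the KerasRegressor, and param_grid is the dict of hyperparameter values to try):

gs = GridSearchCV(estimator, param_grid).fit(X_train, y_train)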

You can inspect gs.best_params_ to see the best parameters found by cross-validation. You can then predict on your test set directly with gs.predict(X_test), which uses the best selected model (gs.best_estimator_); gs.score(X_test, y_test) gives that model's score on the test set.
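A minimal end-to-end sketch of this sequence, under some assumptions: synthetic data from make_regression replaces the real dataset, scikit-learn's MLPRegressor (inside a scaling pipeline) stands in for KerasRegressor, and the grid values are illustrative only. The test set is used exactly once, after the search has finished.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data stands in for the real dataset (assumption).
X, y = make_regression(n_samples=300, n_features=8, noise=10.0, random_state=0)

# 1. Split once; the test set is never seen by the search.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# MLPRegressor is a stand-in for KerasRegressor; grid values are illustrative.
model = make_pipeline(StandardScaler(), MLPRegressor(max_iter=2000, random_state=0))
param_grid = {
    "mlpregressor__hidden_layer_sizes": [(16,), (32,)],
    "mlpregressor__alpha": [1e-4, 1e-2],
}

# 2. Cross-validated search on the training set only. With refit=True (the
#    default) the best combination is retrained on all of X_train.
gs = GridSearchCV(model, param_grid, cv=5, scoring="r2").fit(X_train, y_train)
print("best params:", gs.best_params_)
print("mean CV r2 :", round(gs.best_score_, 3))

# 3. Predict with the refitted best model; predict takes only X.
y_pred = gs.predict(X_test)

# 4. Score once on the held-out test set (uses the search's scoring, r2 here).
test_r2 = gs.score(X_test, y_test)
print("test r2    :", round(test_r2, 3))
```

With a KerasRegressor the structure should stay the same; only the estimator and the parameter names in param_grid change.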

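For question 3, you don't need to use cross_val_score again. It is a helper function that performs cross-validation on a dataset and returns the score of each fold. It also clones the estimator and refits it on the folds of whatever data you pass, so running it on X_test would train and score new models on parts of the test set rather than evaluate the model GridSearchCV selected.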
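To make this concrete, here is a small sketch that assumes a deterministic Ridge model in place of the neural network (so repeated fits with the same parameters are identical) and synthetic data. With the same splits and scoring, cross_val_score on the training data only reproduces gs.best_score_, while calling it on the test set fits fresh clones on test folds and leaves the selected model untouched.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score, train_test_split

X, y = make_regression(n_samples=200, n_features=5, noise=5.0, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

# Ridge is a deterministic stand-in for the neural network (assumption).
cv = KFold(n_splits=5)
gs = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0, 10.0]}, cv=cv, scoring="r2")
gs.fit(X_train, y_train)

# Same splits, same scoring, training data: this repeats work the search already did.
train_cv = cross_val_score(gs.best_estimator_, X_train, y_train, cv=cv, scoring="r2")
print(train_cv.mean(), gs.best_score_)  # the two means match

# On the test set, cross_val_score clones the estimator and fits the clones on
# test folds, so it scores new models, not the one selected above.
coef_before = gs.best_estimator_.coef_.copy()
test_cv = cross_val_score(gs.best_estimator_, X_test, y_test, cv=cv, scoring="r2")
assert np.array_equal(coef_before, gs.best_estimator_.coef_)  # best model untouched

# The single number that actually evaluates the selected model on unseen data:
print(gs.score(X_test, y_test))
```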

For question 4, I believe this answer is quite explanatory: https://stats.stackexchange.com/a/26535
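On question 4: GridSearchCV's refit parameter defaults to True, which means best_estimator_ has already been retrained on the full training set passed to fit, so no second fit is needed before predicting. The sketch below checks this and saves the fitted model with joblib for later reuse. Ridge and the synthetic data are stand-ins (assumptions), and whether a Keras-backed wrapper can be saved the same way depends on the wrapper version, so check its documentation.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils.validation import check_is_fitted

X, y = make_regression(n_samples=200, n_features=5, noise=5.0, random_state=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=2)

# refit=True is the default; written out only to make the behaviour explicit.
gs = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0, 10.0]}, cv=5, refit=True)
gs.fit(X_train, y_train)

# best_estimator_ is already fitted on all of X_train: no second fit needed.
check_is_fitted(gs.best_estimator_)  # raises NotFittedError if it were not fitted
y_pred = gs.best_estimator_.predict(X_test)

# Persist the fitted model so it can be reused without re-running the search.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_model.joblib")
    joblib.dump(gs.best_estimator_, path)
    restored = joblib.load(path)

print(np.allclose(restored.predict(X_test), y_pred))  # True
```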
