
How to use a single validation set for hyperparameter optimization of a Keras model with GridSearchCV?

I am trying to perform hyperparameter optimization on a large dataset, and I want to avoid k-fold cross-validation to speed up the search. Instead, I want to use a single validation set taken from the training data with validation_split=0.2.

   grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1, cv=3)
   grid_result = grid.fit(X_train, y_train)

How should I modify the GridSearchCV() arguments above so that the hyperparameter search uses a single validation set (validation_split=0.2) instead of cross-validation?

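With PredefinedSplit you can reuse the same validation set for every candidate in the hyperparameter search. In the array you pass to it, -1 marks a sample that always stays in the training data, while 0 marks a sample that belongs to validation fold 0.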
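To make the -1/0 convention concrete, here is a small sketch on five samples (the fold array is made up for illustration). Because 0 is the only fold label besides -1, PredefinedSplit yields exactly one (train, validation) pair:

```python
import numpy as np
from sklearn.model_selection import PredefinedSplit

# Illustrative fold labels: -1 = always train, 0 = validation fold 0
test_fold = np.array([-1, -1, 0, -1, 0])
ps = PredefinedSplit(test_fold)

print(ps.get_n_splits())  # prints: 1
for train_idx, val_idx in ps.split():
    # rows labelled -1 go to training, rows labelled 0 to validation
    print(train_idx, val_idx)  # prints: [0 1 3] [2 4]
```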

import numpy as np
from sklearn.model_selection import PredefinedSplit, GridSearchCV

# model and param_grid are defined as in the question
X_train = np.random.uniform(0, 1, (10000, 30))
y_train = np.random.uniform(0, 1, 10000)

# -1 = always in the training set, 0 = validation fold 0 (roughly 20% of rows)
val_split = np.random.choice([-1, 0], len(y_train), p=[0.8, 0.2])

grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1,
                    cv=PredefinedSplit(val_split))
grid_result = grid.fit(X_train, y_train)

You can check the split manually:

ps = PredefinedSplit(val_split)
# a single iteration: one (train, validation) pair
for train_index, val_index in ps.split():
    print("TRAIN:", len(train_index), "VAL:", len(val_index))
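The random labels above give only approximately 20% validation rows. Keras' own validation_split takes the last fraction of the samples (before shuffling), so a deterministic split in that spirit may be closer to what the question asks for. The sketch below is an assumption-laden illustration, not a drop-in answer: it uses scikit-learn's Ridge on synthetic data as a lightweight stand-in for the wrapped Keras model (in practice you would pass a scikit-learn-compatible Keras wrapper, for example scikeras' KerasRegressor), and the alpha grid is made up:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, PredefinedSplit

# Synthetic regression data (illustrative only)
rng = np.random.default_rng(0)
X_train = rng.uniform(0, 1, (1000, 30))
y_train = X_train @ rng.uniform(-1, 1, 30) + rng.normal(0, 0.1, 1000)

# Deterministic split: the last 20% of rows form validation fold 0,
# every other row is -1 (always used for training)
validation_split = 0.2
n_samples = len(y_train)
n_val = round(n_samples * validation_split)
val_split = np.full(n_samples, -1)
val_split[n_samples - n_val:] = 0

grid = GridSearchCV(
    estimator=Ridge(),  # stand-in for the wrapped Keras model
    param_grid={"alpha": [0.01, 0.1, 1.0, 10.0]},
    cv=PredefinedSplit(val_split),
)
grid_result = grid.fit(X_train, y_train)

# Each candidate was trained once and scored once on the validation rows
print(grid_result.best_params_, grid_result.best_score_)
```

Note that with the default refit=True, GridSearchCV refits best_estimator_ on all of X_train (training plus validation rows) after the search; pass refit=False if you only need the validation scores and want to skip that extra fit.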
