
How can I get the trained model from xgboost CV?

I am running the following code:

import xgboost as xgb

params = {"objective": "reg:squarederror", "colsample_bytree": 0.3,
          "learning_rate": 0.15, "max_depth": 5, "alpha": 15}

data_dmatrix = xgb.DMatrix(data=X_train, label=y_train)
cv_results = xgb.cv(dtrain=data_dmatrix, params=params, nfold=3,
                    num_boost_round=50, early_stopping_rounds=10,
                    metrics="rmse", as_pandas=True, seed=0)

The results look great, and I would like to test the best model from the cross-validation on the data I held back. But how can I get the model?

Unlike, say, scikit-learn's GridSearchCV, which returns a model (refitted on the whole data when refit=True, the default), xgb.cv does not return any model, only the evaluation history. From the docs:

Returns evaluation history

In this sense, it is similar to scikit-learn's cross_validate, which by default also returns only the metrics, not the fitted models.

So, provided you are happy with the CV results and want to proceed with fitting the model on all the data, you must do so separately:

bst = xgb.train(dtrain=data_dmatrix, params=params, num_boost_round=50)

Alternatively, the XGBoost API provides a callback mechanism. Callbacks let you run custom code before and after each boosting round, and before and after training.

Since you need the final models after CV, you can define a callback like this:

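import xgboost as xgb

class SaveBestModel(xgb.callback.TrainingCallback):
    """Collect the per-fold boosters once xgb.cv has finished training."""

    def __init__(self, cvboosters):
        # Caller-owned list, filled in place so the caller sees the result
        self._cvboosters = cvboosters

    def after_training(self, model):
        # For xgb.cv, model holds one CVPack per fold; each exposes its Booster as .bst
        self._cvboosters[:] = [cvpack.bst for cvpack in model.cvfolds]
        return model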

In the case of xgb.cv, the model argument of after_training is an instance of xgb.training._PackedBooster (a private class, so this detail may change between XGBoost versions). Now pass the callback to xgb.cv:

cvboosters = []

cv_results = xgb.cv(dtrain=data_dmatrix, params=params, nfold=3,
                    num_boost_round=50, early_stopping_rounds=10,
                    metrics="rmse", as_pandas=True, seed=0,
                    callbacks=[SaveBestModel(cvboosters)])

After xgb.cv returns, cvboosters holds one trained Booster per fold (three here, since nfold=3).

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM