
Understanding Cross-Validation for Machine Learning

Is the following understanding of cross-validation correct?

The training data is divided into different groups; all but one of the groups are used to train the model. Once the model is trained, the 'left-out' group is used to perform hyperparameter tuning. Once the optimal hyperparameters have been chosen, the test data is applied to the model to give a result, which is then compared with the results of other models that have undergone the same process but with different combinations of the training groups. The model with the best results on the test data is then chosen.


I don't think it is correct. You wrote:

Once the model is trained, the 'left-out' group is used to perform hyperparameter tuning

You tune the model by picking (manually, or with a method like grid search or random search) a set of hyperparameters: parameters whose values are set by you before the model is ever fit to the data. Then, for a selected set of hyperparameter values, you calculate the validation-set error using cross-validation.
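As a minimal sketch, this is what that could look like with scikit-learn's GridSearchCV (the SVC estimator, the synthetic dataset, and the candidate grid are illustrative assumptions, not part of the original question):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Illustrative data; in practice this is your own training set.
X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Hyperparameter values are chosen *before* fitting; these candidates are assumptions.
param_grid = {"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]}

# For every combination in the grid, 5-fold cross-validation is run on the
# training data and the mean validation score is recorded.
search = GridSearchCV(SVC(), param_grid, cv=5)
search.fit(X_train, y_train)
print(search.best_params_, search.best_score_)
```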

So it should be like this:

The training data is divided into different groups; all but one of the groups are used to train the model. Once the model is trained, the 'left-out' group is used to...

... calculate the error. At the end of cross-validation you will have k errors, one for each of the k left-out groups. Next you take the mean of these k errors, which gives you a single value: the validation-set error.
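Written out by hand, one cross-validation pass for a single, fixed set of hyperparameters could look like this (the estimator and its hyperparameter values are illustrative assumptions); the k per-fold errors and their mean are made explicit:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, random_state=0)

fold_errors = []
for train_idx, val_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
    model = SVC(C=1.0, gamma=0.1)          # hyperparameters fixed in advance
    model.fit(X[train_idx], y[train_idx])  # train on the k-1 remaining groups
    accuracy = model.score(X[val_idx], y[val_idx])
    fold_errors.append(1 - accuracy)       # error on the left-out group

validation_error = np.mean(fold_errors)    # mean of the k errors: one number
print(fold_errors, validation_error)
```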

If you have n sets of hyperparameters, you simply repeat the procedure n times, which gives you n validation-set errors. You then pick the set that gave the smallest validation error (sketched below).
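A hand-rolled sketch of that outer loop, assuming three candidate values of a single hyperparameter C (the candidate values are illustrative); it is essentially what grid search automates:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, random_state=0)

# n candidate hyperparameter sets (values are illustrative assumptions).
candidates = [{"C": 0.1}, {"C": 1.0}, {"C": 10.0}]

errors = []
for params in candidates:
    scores = cross_val_score(SVC(**params), X, y, cv=5)  # k fold scores
    errors.append(1 - scores.mean())   # one validation-set error per candidate

best = candidates[int(np.argmin(errors))]  # smallest validation error wins
print(best, min(errors))
```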

At the end, you will typically calculate the test-set error to see how the model performs on unseen data, which simulates putting the model into production, and to check whether there is a difference between the test-set error and the validation-set error. A significant difference suggests over-fitting.
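A sketch of this final check, under the same illustrative setup as above; note that the test data is touched exactly once, after all tuning is finished:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Tuning touches only the training data.
search = GridSearchCV(SVC(), {"C": [0.1, 1, 10]}, cv=5)
search.fit(X_train, y_train)

validation_error = 1 - search.best_score_      # mean CV error of the best set
test_error = 1 - search.score(X_test, y_test)  # error on truly unseen data

# A test error much larger than the validation error suggests over-fitting.
print(validation_error, test_error)
```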

Just to add something on cross-validation itself: the reason we use k-fold CV or LOOCV is that it gives a good estimate of the test-set error. That means that when I change the hyperparameters and the validation-set error drops, I know I have genuinely improved the model, rather than getting lucky and simply fitting the training set better.
