
How does sklearn's cross_val_score use KFold?

I am new to machine learning and am trying to understand how cross_val_score uses KFold to split the data into k folds.

from sklearn.model_selection import KFold, cross_val_score
kf = KFold(n_splits=2)
cv_results = cross_val_score(model, X_train, Y_train, cv=kf)

I know KFold splits the data, but when I tried printing it out:

from sklearn.model_selection import KFold
dataset = [[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[6,6,6],[7,7,7],[8,8,8]]
kf = KFold(n_splits=2)
print(kf)

>>> KFold(n_splits=2, random_state=None, shuffle=False)

It doesn't show the k folds, so how does cross_val_score get all of them?

You need to call kf.split(dataset) to actually split the data. See the scikit-learn KFold documentation for details on how it works.

Just to make it clear, KFold is a class and not a function.

kf = KFold(n_splits=2) creates a KFold object, and print(kf) just prints that object (its parameters), not the folds.

When you call cross_val_score(model, X_train, Y_train, cv=kf), you pass the kf object to cross_val_score, which internally calls kf.split(X_train, Y_train) to generate the train/test indices for the 2 folds. The same indices are used to split Y_train, so each row of X_train stays paired with its label.
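To make the previous paragraph concrete, here is a rough sketch that reproduces cross_val_score by hand. It assumes a LinearRegression model and synthetic data, both hypothetical stand-ins for the question's model, X_train and Y_train. It is a simplification: the real cross_val_score also handles custom scoring, parallel jobs and error handling, but the fold scores should match.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

# Hypothetical toy data: 20 samples, 3 features, a noisy linear target.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(20, 3))
Y_train = X_train @ np.array([1.0, 2.0, 3.0]) + rng.normal(scale=0.5, size=20)

model = LinearRegression()  # hypothetical stand-in for the question's model
kf = KFold(n_splits=2)

# What the question does: let cross_val_score drive the splitting.
cv_results = cross_val_score(model, X_train, Y_train, cv=kf)

# Roughly the same thing by hand: kf.split() supplies index arrays, and the
# same indices select rows from both X_train and Y_train.
manual_results = []
for train_idx, test_idx in kf.split(X_train, Y_train):
    fold_model = clone(model)  # fresh, unfitted copy for each fold
    fold_model.fit(X_train[train_idx], Y_train[train_idx])
    # With scoring=None, cross_val_score uses the estimator's own score() (R^2 here).
    manual_results.append(fold_model.score(X_train[test_idx], Y_train[test_idx]))

print(cv_results)
print(np.array(manual_results))
```

Both printed arrays should hold the same two per-fold scores, which is the point: the KFold object only supplies indices, and cross_val_score does the fitting and scoring.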

Try this:

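from sklearn.model_selection import KFold

dataset = [[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[6,6,6],[7,7,7],[8,8,8]]
kf = KFold(n_splits=2)
generator = kf.split(dataset)  # yields (train_indices, test_indices) once per fold
for train, test in generator:
    print("*" * 20)
    print("Training Data:")
    for i in train:
        print(dataset[i])
    print("Test Data:")
    for j in test:
        print(dataset[j])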

kf.split(dataset) returns a generator. Iterating through it gives you all the folds, one (train indices, test indices) pair per fold.
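A small sketch, reusing the question's dataset, that inspects the generator directly instead of printing the rows. The index values in the comments assume KFold's default shuffle=False.

```python
import types
from sklearn.model_selection import KFold

dataset = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4],
           [5, 5, 5], [6, 6, 6], [7, 7, 7], [8, 8, 8]]
kf = KFold(n_splits=2)

folds = kf.split(dataset)
print(isinstance(folds, types.GeneratorType))  # True: nothing is split until you iterate

# Each step yields index arrays, not the rows themselves.
first_train, first_test = next(folds)
print(first_train, first_test)  # [4 5 6 7] [0 1 2 3]

# A generator can only be consumed once, so call split() again for a fresh pass.
all_folds = list(kf.split(dataset))
print(len(all_folds))  # 2

# Without shuffling, the test folds are consecutive blocks covering every index once.
test_indices = sorted(int(i) for _, test in all_folds for i in test)
print(test_indices)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

This laziness is also why print(kf) shows only the parameters: the folds do not exist until split() is iterated.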
