
Does scikit-learn train_test_split preserve relationships?

I am trying to understand this code. I do not understand how, if you do:

x_validation, x_test, y_validation, y_test = 
  train_test_split(x_validation_and_test, y_validation_and_test...

you can later do:

len(x_validation[y_validation == 0])

Surely train_test_split means x_validation and y_validation are no longer related? What am I missing?

EDIT: There are some good answers already, but I just want to clarify: are x_validation and y_validation guaranteed to be in the correct order, i.e. in the same order as each other? Obviously you could add a row to either and mess things up, but is there an underlying index that means order is preserved? I come from a non-Python background, and sometimes you cannot guarantee the order of things like SQL columns.

You absolutely do want x_validation to be related to y_validation, i.e. to correspond to the same rows you had in your original dataset. For example, if the validation split takes rows 1, 3 and 7 from the input x, you want rows 1, 3 and 7, in the same order, in both x_validation and y_validation.
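Whether there is an "underlying index" depends on what you pass in. NumPy arrays have no row labels, only positions, and train_test_split selects the same positions from every array. If you pass pandas objects, the returned pieces keep their original index, so you can see the correspondence directly. A minimal sketch, assuming a small hypothetical DataFrame x and Series y (the data is made up for illustration):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 12 rows, one feature, and a label derived from the row.
x = pd.DataFrame({"feature": np.arange(12) * 10})
y = pd.Series(np.arange(12) % 2, name="label")

x_validation, x_test, y_validation, y_test = train_test_split(
    x, y, test_size=0.25, random_state=0
)

# pandas inputs come back as pandas objects that keep their original index,
# so the row labels show which original rows each split picked, and in what order.
print(x_validation.index.tolist())
print(y_validation.index.tolist())  # same labels, same order as the line above
```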

The idea of the train_test_split function is to divide your dataset into two parts, each made up of a set of features (the xs) and the corresponding labels (the ys). So you want, and require,

len(x_validation) == len(y_validation)

and

len(x_test) == len(y_test)
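Both conditions can be checked directly. This sketch uses hypothetical NumPy data built so that each label can be recomputed from its feature row, which also confirms the rows were paired correctly, not just counted correctly:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 20 samples with 3 features each, and one label per sample.
# Row i is [3i, 3i+1, 3i+2] and its label is i % 2, so the pairing can be checked.
x_validation_and_test = np.arange(60).reshape(20, 3)
y_validation_and_test = np.arange(20) % 2

x_validation, x_test, y_validation, y_test = train_test_split(
    x_validation_and_test, y_validation_and_test, test_size=0.5, random_state=42
)

print(x_validation.shape, y_validation.shape)  # (10, 3) (10,)
print(x_test.shape, y_test.shape)              # (10, 3) (10,)
# Recover each row's original position from its first column and check its label.
assert np.array_equal((x_validation[:, 0] // 3) % 2, y_validation)
```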

Looking at other parts of your question that might be causing confusion:

y_validation == 0

will generate a boolean mask of True and False values that you can use to select only the matching rows from any array or data frame of the same length (for pandas objects, with the same index), so in this case it also works with x_validation.
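To make the masking step concrete, here is a small sketch with hypothetical, already-split NumPy arrays. The mask is computed from y_validation but applied to x_validation, which only makes sense because row i of one corresponds to row i of the other:

```python
import numpy as np

# Hypothetical validation split: 6 rows of 2 features and one label per row.
x_validation = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
y_validation = np.array([0, 1, 0, 0, 1, 1])

mask = y_validation == 0
print(mask)                     # [ True False  True  True False False]
print(x_validation[mask])       # rows 0, 2 and 3 of x_validation
print(len(x_validation[mask]))  # 3
```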

As an aside,

len(x_validation[y_validation == 0])

seems a slightly confusing way of counting the number of examples of class 0. I would have gone for

(y_validation == 0).sum()

myself, and then you can write the percentage of negatives as

100*(y_validation == 0).sum()/len(y_validation)

which is a bit neater to me.
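A quick side-by-side of the two counting approaches, using hypothetical labels and feature rows. Both give the same count; the second never touches x_validation, and summing a boolean array counts its True values:

```python
import numpy as np

# Hypothetical validation labels and matching feature rows.
y_validation = np.array([0, 1, 0, 0, 1, 1, 0, 1])
x_validation = np.arange(16).reshape(8, 2)

# Original approach: select the class-0 rows of x, then count them.
count_via_x = len(x_validation[y_validation == 0])
# Suggested approach: count the True values in the mask (True sums as 1).
count_via_y = (y_validation == 0).sum()
pct_negative = 100 * (y_validation == 0).sum() / len(y_validation)

print(count_via_x, count_via_y, pct_negative)  # 4 4 50.0
```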

train_test_split shuffles the row positions of your data, splits them into a set of training indices and a set of test indices, and returns the values at those same indices from every array you pass (xtrain, ytrain, or whatever else). Look at this simple demonstration:

import numpy as np
from sklearn.model_selection import train_test_split

data = np.random.randint(0, 10, 15) 
target = data**2
indices = np.arange(15)

xtrain, xtest, ytrain, ytest, indicestrain, indicestest = \
    train_test_split(data, target, indices)

print(data)
print(target)
print(indices)
[7 7 4 1 1 3 1 6 8 2 2 1 9 9 9] # random values from 0 to 9 (randint excludes the upper bound)
[49 49 16  1  1  9  1 36 64  4  4  1 81 81 81] # their squared values
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]  # the index in the original array (data)

Now split the arrays, which picks random indices, and check whether the values are still related:

xtrain, xtest, ytrain, ytest, indicestrain, indicestest = \
    train_test_split(data, target, indices)

print(xtrain)
print(ytrain)
print(indicestrain)
[9 9 2 1 8 2 4 7 1 6 7]
[81 81  4  1 64  4 16 49  1 36 49]
[12 14 10 11  8  9  2  1  3  7  0]

As you can see, each value in the second row is the square of the value above it, and the third row gives each value's position in the original array, meaning that the order was preserved.
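The same check can be done with assertions instead of by eye. The sketch below also includes simple_split, a hypothetical, simplified stand-in written only to illustrate the idea (it is not scikit-learn's source): shuffle the row positions once, then take the same positions from every array. Both simple_split and the real train_test_split are run on the same data and must satisfy the same relationships:

```python
import numpy as np
from sklearn.model_selection import train_test_split


def simple_split(*arrays, test_size=0.25, seed=None):
    """Hypothetical, simplified stand-in for train_test_split (not sklearn's source).

    It shuffles the row positions once and applies the same positions to every
    array, which is what keeps rows aligned across data, target, indices, etc.
    """
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must have the same number of rows")
    perm = np.random.default_rng(seed).permutation(n)
    n_test = int(np.ceil(test_size * n))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    out = []
    for a in arrays:
        a = np.asarray(a)
        out.extend([a[train_idx], a[test_idx]])  # same positions for every array
    return out


data = np.random.randint(0, 10, 15)
target = data ** 2
indices = np.arange(15)

for split in (simple_split, train_test_split):
    xtrain, xtest, ytrain, ytest, indicestrain, indicestest = split(data, target, indices)
    # Check the relationships directly instead of eyeballing the printout.
    assert np.array_equal(ytrain, xtrain ** 2)
    assert np.array_equal(xtrain, data[indicestrain])
    assert np.array_equal(ytest, target[indicestest])
    # Every original position ends up in exactly one of the two splits.
    assert np.array_equal(np.sort(np.concatenate([indicestrain, indicestest])), indices)
print("order preserved")
```

With 15 rows and the default test fraction of 0.25, both functions put 4 rows in the test set and 11 in the training set, matching the 11 values shown in the output above.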

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.
