
How can I find X_train indexes in the main dataset?

We can split a dataset into X_train, X_test, y_train and y_test with scikit-learn's train_test_split function in Python:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, test_size=0.3)

My question is: how can we find the indexes of the X_train or y_train rows in the original dataset?
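A minimal sketch of one common approach, assuming X and y are NumPy arrays (the toy data and random_state below are hypothetical): train_test_split accepts any number of equal-length arrays and splits them all with the same shuffle, so an extra array of row positions comes back split into the train and test indexes.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 10 rows with 2 features each
X = np.arange(20).reshape((10, 2))
y = np.arange(10) % 2

# Row positions in the original dataset
indices = np.arange(len(X))

# All arrays passed in are split with the same permutation
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, y, indices, shuffle=True, test_size=0.3, random_state=0
)

# idx_train[i] is the row of X that became X_train[i]
print(idx_train)
print(idx_test)
```

If X is a pandas DataFrame instead, train_test_split returns DataFrames that keep their original index, so X_train.index already holds the row labels.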

Suppose we get the predictions with:

prediction = model.predict(X_test)

Also, how can we find the dataset indexes that the predictions correspond to?

I am asking because I would like to see each row's values when I get inaccurate results.
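model.predict returns one prediction per row of X_test, in the same order, so a boolean mask of wrong predictions can be applied to the test indexes. The sketch below tracks row positions by passing np.arange(len(X)) through train_test_split, and assumes the iris dataset and a LogisticRegression model purely as stand-ins for the actual data and model.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Stand-in data and model (assumptions, not the asker's setup)
X, y = load_iris(return_X_y=True)
indices = np.arange(len(X))

X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, y, indices, shuffle=True, test_size=0.3, random_state=0
)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
prediction = model.predict(X_test)

# prediction[i] belongs to X_test[i], i.e. to original row idx_test[i]
wrong = prediction != y_test
wrong_rows = idx_test[wrong]

# Show the original values of every row that was predicted wrongly
for row, pred in zip(wrong_rows, prediction[wrong]):
    print(row, X[row], "true:", y[row], "predicted:", pred)
```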

In other words, if data is the main dataset and subset is a subset of data:

data = array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
subset = array([2, 4, 5, 6])

How can I find the indexes of subset's elements in data?
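For the toy arrays above, NumPy can locate the values directly. This sketch matches by value, so it assumes the values in data are unique; with duplicates, np.isin marks every matching position. When the subset comes from a split, tracking row positions during the split is more robust than matching values afterwards.

```python
import numpy as np

data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
subset = np.array([2, 4, 5, 6])

# Positions in data whose value occurs anywhere in subset (in data's order)
positions = np.flatnonzero(np.isin(data, subset))
print(positions)  # [2 4 5 6]

# If data is sorted, searchsorted gives one position per subset element,
# in subset's order (it does not check that each value is actually present)
positions_sorted = np.searchsorted(data, subset)
print(positions_sorted)  # [2 4 5 6]
```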

As documented in sklearn.model_selection.train_test_split, it is a quick utility that wraps sklearn.model_selection.ShuffleSplit:

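import numpy as np
from sklearn.model_selection import ShuffleSplit, train_test_split

# Assumed toy data: 5 samples with 2 features each
X, y = np.arange(10).reshape((5, 2)), list(range(5))

x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=1, test_size=1)
x_train
array([[2, 3],
       [8, 9],
       [0, 1],
       [6, 7]])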

The same rows are obtained by indexing X with the train indices that ShuffleSplit yields:

train_ind, test_ind = next(ShuffleSplit(n_splits=1, test_size=1, random_state=1).split(X, y))
X[train_ind]
array([[2, 3],
       [8, 9],
       [0, 1],
       [6, 7]])

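So you can use the train_ind and test_ind arrays produced by ShuffleSplit: with the same random_state and split sizes, they select exactly the same rows as train_test_split, and they are the indexes of those rows in the original dataset.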
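The equivalence can be checked on a slightly larger hypothetical array, passing the split size explicitly so that ShuffleSplit reproduces train_test_split beyond a tiny example. This sketch relies on train_test_split taking the first split of a ShuffleSplit internally, as its documentation describes; n_splits=1 is set because only one split is needed.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, train_test_split

# Hypothetical data: 20 rows with 2 features each
X = np.arange(40).reshape((20, 2))
y = np.arange(20)

x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Same test_size and random_state as the call above
splitter = ShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_ind, test_ind = next(splitter.split(X, y))

print(np.array_equal(X[train_ind], x_train))  # True
print(np.array_equal(X[test_ind], x_test))    # True
```

When stratify is set, train_test_split uses StratifiedShuffleSplit instead, so that splitter (with the same test_size and random_state) should be the one to reproduce it.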

