How to select indices of rows in a NumPy array?

I have the following NumPy array y_train:

y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])

I need to randomly select n row indices for each class, with n = 2:

n row indices where y == 0
n row indices where y == 1
n row indices where y == 2
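To make the task concrete, the candidate rows for each class in the example y_train can be listed directly; the random step then only has to draw n of them per class. A small sketch using np.flatnonzero, which returns the positions where a condition holds:

```python
import numpy as np

y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])

# For each distinct label, list the row positions that hold that label.
# The random selection later picks n of these positions per label.
candidates = {int(c): np.flatnonzero(y_train == c).tolist() for c in np.unique(y_train)}
print(candidates)
# {0: [3, 7, 8], 1: [2, 4, 5], 2: [0, 1, 6]}
```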

I use the following code:

import numpy as np

n = 2
idx = [y_train[np.random.choice(np.where(y_train == np.unique(y_train)[i])[0], n)].index.tolist()
       for i in np.unique(y_train).astype(int)]

With my real y_train, this raises:

KeyError: '[70798 63260 64755 ...  7012 65605 45218] not in index'
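The `.index.tolist()` call and the "not in index" wording suggest that the real y_train is a pandas Series rather than a plain NumPy array (an assumption; the question does not say so). np.where returns positions 0..len-1, but `y_train[...]` on a Series looks those numbers up as index labels, so any position that is not also a label raises KeyError. A minimal sketch, assuming a Series with a shuffled integer index (the labels below are made up), selects by position with `.iloc` instead; `groupby(...).sample` (available since pandas 1.1) is shown as another option:

```python
import numpy as np
import pandas as pd

n = 2
# Hypothetical stand-in for the real y_train: a Series whose index labels are
# not 0..len-1, e.g. after a shuffled train/test split. Labels are made up.
y_train = pd.Series([2, 2, 1, 0, 1, 1, 2, 0, 0],
                    index=[70, 63, 64, 7, 65, 45, 12, 30, 99])

rng = np.random.default_rng(0)
values = y_train.to_numpy()

idx = []
for label in np.unique(values):
    # np.where gives *positions* (0..len-1), not index labels
    positions = np.where(values == label)[0]
    chosen = rng.choice(positions, size=n, replace=False)
    # .iloc selects by position, so no KeyError; .index maps the chosen
    # positions back to the Series' own index labels
    idx.append(y_train.iloc[chosen].index.tolist())
print(idx)

# Alternative: let pandas sample n rows per class and return their labels
idx_flat = y_train.groupby(y_train).sample(n=n, random_state=0).index.tolist()
print(idx_flat)
```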

If your expected output is a list of randomly selected row indices for each unique value in y_train:

idx = [np.random.choice(np.where(y_train == i)[0], size=2, replace=False)
       for i in np.unique(y_train)]

OUTPUT (the selection is random, so yours will differ):

[array([7, 8]), array([5, 4]), array([1, 0])]
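Two details of this answer matter in practice. replace=False guarantees that the indices drawn within a class are distinct (the question's code used the default replace=True, so it could pick the same row twice), but it raises ValueError when a class has fewer rows than the requested size. The sketch below wraps the same idea in a hypothetical helper, sample_per_class, that caps the sample size per class and takes a seed through np.random.default_rng for reproducible draws; the capping behavior is an assumption about what you want, not part of the answer:

```python
import numpy as np

def sample_per_class(y, n, seed=None):
    """Hypothetical helper: up to n distinct row positions for each label in y."""
    rng = np.random.default_rng(seed)
    out = []
    for label in np.unique(y):
        positions = np.flatnonzero(y == label)
        # replace=False cannot draw more than positions.size items, so cap k
        k = min(n, positions.size)
        out.append(rng.choice(positions, size=k, replace=False))
    return out

y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])
idx = sample_per_class(y_train, n=2, seed=42)
print(idx)
```

Calling it again with the same seed returns the same indices; with n=5, each class in this example simply returns all three of its rows in some order.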

If you want to flatten the arrays into a single array:

idx = np.array(idx).flatten()

OUTPUT:

array([7, 8, 5, 4, 1, 0])
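np.array(idx).flatten() works here because every class contributes exactly 2 indices, so idx stacks into a regular 2-D array. If the per-class arrays can have different lengths (for example, a class with too few rows), recent NumPy versions refuse to build the ragged array, while np.concatenate joins the pieces either way. The per-class arrays below are made-up values, and X_train is a hypothetical feature matrix used only to show how the flat indices pick out a subset:

```python
import numpy as np

# Made-up per-class results; the last class contributed only one index
idx = [np.array([7, 8]), np.array([5, 4]), np.array([1])]
flat = np.concatenate(idx)
print(flat)  # [7 8 5 4 1]

# Hypothetical data: use the flat indices to take the matching rows
y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])
X_train = np.arange(18).reshape(9, 2)  # placeholder features, one row per label
X_sub, y_sub = X_train[flat], y_train[flat]
print(y_sub)  # [0 0 1 1 2]
```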

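An alternative is to use nonzero and loop over the labels themselves. Looping over range(n+1) would only work here by coincidence (n + 1 = 3 happens to equal the number of classes); range(y_train.max() + 1) states the intent, assuming the labels are the consecutive integers 0, 1, 2, ...:

import numpy as np

n = 2
y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])

# One array of n distinct row indices per label 0..max
indices = [np.random.choice((y_train == i).nonzero()[0], n, replace=False)
           for i in range(y_train.max() + 1)]
print(indices)
# [array([7, 3]), array([5, 4]), array([0, 1])]

print(np.array(indices).ravel())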
# [7 3 5 4 0 1]
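For a 1-D boolean mask, mask.nonzero()[0], np.where(mask)[0] (the single-argument form of np.where is documented as shorthand for np.asarray(mask).nonzero()) and np.flatnonzero(mask) all return the same positions, so the two answers differ only in spelling. A quick check on the example data:

```python
import numpy as np

y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])
mask = y_train == 0  # boolean mask for label 0

a = mask.nonzero()[0]      # nonzero returns a tuple, one index array per dimension
b = np.where(mask)[0]      # single-argument np.where returns the same tuple
c = np.flatnonzero(mask)   # positions in the flattened array, no tuple
print(a, b, c)  # [3 7 8] [3 7 8] [3 7 8]
```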

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM