简体   繁体   中英

2D numpy argsort index returns 3D when used in the original matrix

I am trying to obtain the top 2 values from each row in a matrix using argsort. The indexing is working, as in argsort is returning the correct values. However, when I put the argsort result as an index, it returns a 3 dimensional result.

For example:

test_mat = np.matrix([[0 for i in range(5)] for j in range(5)])
for i in range(5):
    for j in range(5):
        test_mat[i, j] = i * j
test_mat[range(2,3)] = test_mat[range(2,3)] * -1

last_two = range(-1, -3, -1)
index = np.argsort(test_mat, axis=1)
index = index[:, last_k]

This gives:

index.shape
Out[402]: (5L, 5L)

test_mat[index].shape
Out[403]: (5L, 5L, 5L)

Python is new to me and I find indexing to be very confusing in general even after reading the various array manuals. I spend more time trying to get the right values out of objects than actually solving problems. I'd welcome any tips on where to properly learn what is going on. Thanks.

You can use linear indexing to solve your case, like so -

# Say A is your 2D input array 

# Get sort indices for the top 2 values in each row
idx = A.argsort(1)[:,::-1][:,:2]

# Get row offset numbers
row_offset = A.shape[1]*np.arange(A.shape[0])[:,None]

# Add row offsets with top2 sort indices giving us linear indices of 
# top 2 elements in each row. Index into input array with those for output.
out = np.take( A, idx + row_offset )

Here's a step-by-step sample run -

In [88]: A
Out[88]: 
array([[34, 45, 16, 20, 24],
       [37, 13, 49, 37, 21],
       [42, 36, 35, 24, 18],
       [26, 28, 21, 13, 44]])

In [89]: idx = A.argsort(1)[:,::-1][:,:2]

In [90]: idx
Out[90]: 
array([[1, 0],
       [2, 3],
       [0, 1],
       [4, 1]])

In [91]: row_offset = A.shape[1]*np.arange(A.shape[0])[:,None]

In [92]: row_offset
Out[92]: 
array([[ 0],
       [ 5],
       [10],
       [15]])

In [93]: np.take( A, idx + row_offset )
Out[93]: 
array([[45, 34],
       [49, 37],
       [42, 36],
       [44, 28]])

You can directly get the top 2 values from each row with just sorting along the second axis and some slicing , like so -

out = np.sort(A,1)[:,:-3:-1]

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM