简体   繁体   English

如何从包含最多3个值的2d numpy数组中获取列的索引

[英]How to get index of columns from 2d numpy array containing max 3 values

I have a array: 我有一个数组:

a = np.array([[22,11,44,33,66],
              [22,11,2,1,66],
              [1,11,44,22,4],
              [22,11,88,99,66]])

and as a output I want a array containing index of max 3 values as 2d array. 作为输出,我想要一个包含最多3个值的索引的数组作为2d数组。 For example for above array output would be: 例如,上面的数组输出将是:

array([[4,2,3],
       [4,0,1],
       [2,3,1],
       [3,2,4]])

To get the top k elements of an array, partition it. 要获取数组的前k元素,请对其进行分区 Since partitioning normally gives you the k lowest elements, use the reverse indices: 由于分区通常为您提供k最低的元素,因此请使用反向索引:

k = 3
top = np.argpartition(a, -k, axis=1)[:, -k:]

If you need to have the indices sorted in descending order, use np.argsort with the result: 如果需要按降序对索引进行排序,请对结果使用np.argsort

rows = np.arange(a.shape[0])[:, None]
s = np.argsort(a[rows, top], axis=1)[:, ::-1]
top = top[rows, s]

rows is necessary to make sure that all the indices are selected properly when you do fancy indexing with top and s . 使用tops进行华丽索引时,必须使用rows以确保正确选择了所有索引。 The indices for each row have to be reversed ( [:, ::-1] ) to get ascending order. 每行的索引必须颠倒( [:, ::-1] )以得到升序。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM