繁体   English   中英

查找 numpy 数组的 k 个最小值的索引

[英]Find the index of the k smallest values of a numpy array

为了找到最小值的索引,我可以使用argmin

import numpy as np
A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
print A.argmin()     # 4 because A[4] = 0.1

但是我怎样才能找到k 最小值的索引呢?

我正在寻找类似的东西:

 print A.argmin(numberofvalues=3) # [4, 0, 7] because A[4] <= A[0] <= A[7] <= all other A[i]

注意:在我的用例中,A 有大约 10 000 到 100 000 个值,我只对 k=10 最小值的索引感兴趣。 k 永远不会 > 10。

使用np.argpartition 它不会对整个数组进行排序。 它只保证kth元素处于排序位置,并且所有较小的元素都将移动到它之前。 因此,前k元素将是 k 个最小的元素。

import numpy as np

A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
k = 3

idx = np.argpartition(A, k)
print(idx)
# [4 0 7 3 1 2 6 5]

这将返回 k 最小值。 请注意,这些可能不是按排序顺序排列的。

print(A[idx[:k]])
# [ 0.1  1.   1.5]

要获得 k 最大值,请使用

idx = np.argpartition(A, -k)
# [4 0 7 3 1 2 6 5]

A[idx[-k:]]
# [  9.  17.  17.]

警告:不要(重新)使用idx = np.argpartition(A, k); A[idx[-k:]] idx = np.argpartition(A, k); A[idx[-k:]]获得 k 最大。 那不会总是奏效。 例如,这些不是x的 3 个最大值:

x = np.array([100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0])
idx = np.argpartition(x, 3)
x[idx[-3:]]
array([ 70,  80, 100])

这是与np.argsort的比较,它也有效,但只是对整个数组进行排序以获得结果。

In [2]: x = np.random.randn(100000)

In [3]: %timeit idx0 = np.argsort(x)[:100]
100 loops, best of 3: 8.26 ms per loop

In [4]: %timeit idx1 = np.argpartition(x, 100)[:100]
1000 loops, best of 3: 721 µs per loop

In [5]: np.alltrue(np.sort(np.argsort(x)[:100]) == np.sort(np.argpartition(x, 100)[:100]))
Out[5]: True

您可以将numpy.argsort与切片numpy.argsort使用

>>> import numpy as np
>>> A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
>>> np.argsort(A)[:3]
array([4, 0, 7], dtype=int32)

对于n 维数组,此函数运行良好。 indecies 以可调用的形式返回。 如果要返回索引列表,则需要在创建列表之前转置数组。

要检索最大的k ,只需传入-k

def get_indices_of_k_smallest(arr, k):
    idx = np.argpartition(arr.ravel(), k)
    return tuple(np.array(np.unravel_index(idx, arr.shape))[:, range(min(k, 0), max(k, 0))])
    # if you want it in a list of indices . . . 
    # return np.array(np.unravel_index(idx, arr.shape))[:, range(k)].transpose().tolist()

例子:

r = np.random.RandomState(1234)
arr = r.randint(1, 1000, 2 * 4 * 6).reshape(2, 4, 6)

indices = get_indices_of_k_smallest(arr, 4)
indices
# (array([1, 0, 0, 1], dtype=int64),
#  array([3, 2, 0, 1], dtype=int64),
#  array([3, 0, 3, 3], dtype=int64))

arr[indices]
# array([ 4, 31, 54, 77])

%%timeit
get_indices_of_k_smallest(arr, 4)
# 17.1 µs ± 651 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

numpy.partition(your_array, k)是另一种选择。 不需要切片,因为它给出了排序到第kth元素的值。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM