
How to find the indices of the k smallest numbers in a multidimensional array?

I want to get an array containing the indices of the k smallest values in a 2-D array. For a single row, this works:

import heapq
import numpy as np

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])

# enumerate yields (index, value) pairs; compare them on the value
[t[0] for t in heapq.nsmallest(2, enumerate(a[1]), key=lambda t: t[1])]
# -> [3, 4]

But this fails:

[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]

Traceback (most recent call last):
  File "<pyshell#19>", line 1, in <module>
    [t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
TypeError: 'numpy.bool_' object is not iterable

Your problem is in a.all():

[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]

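a.all() tests whether every element of the array is truthy and returns a single boolean, which here is False because the array contains a 0. enumerate() cannot iterate over a single boolean, hence the TypeError.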
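To keep the heapq approach from the question, one option (a sketch that assumes indices into the flattened array are acceptable) is to enumerate over a.ravel() instead of a.all(). The lambda is written in Python 3 syntax and passed as key=:

```python
import heapq
import numpy as np

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])

# a.all() collapses the whole array to one boolean
print(a.all())  # False, because a contains a 0

# a.ravel() exposes every element, so enumerate yields (flat_index, value) pairs
k = 2
flat_idx = [i for i, v in heapq.nsmallest(k, enumerate(a.ravel()), key=lambda t: t[1])]
print(flat_idx)  # [11, 0]
```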

If the arrays are not too large compared with k, you can get the indices with .argsort(), which sorts each row in ascending order. Here I select the positions of the two smallest values in each row:

a.argsort()[:, :2]

array([[0, 3],
       [3, 4],
       [1, 0]])
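The same argsort result can also be used to fetch the corresponding values. This sketch assumes NumPy 1.15 or newer, where np.take_along_axis is available; k is just a local name chosen for illustration:

```python
import numpy as np

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])

k = 2
# Column indices of the k smallest values in each row, in ascending order of value
idx = a.argsort(axis=1)[:, :k]
# Pick out the matching values row by row
vals = np.take_along_axis(a, idx, axis=1)
print(idx.tolist())   # [[0, 3], [3, 4], [1, 0]]
print(vals.tolist())  # [[1, 2], [2, 4], [0, 2]]
```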

If you want the positions of the k smallest values over the whole array, flatten it first (the returned indices refer to the flattened array):

a.flatten().argsort()[:2]
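The flattened argsort returns positions in the 1-D view of the array. A sketch of mapping them back to (row, column) pairs with np.unravel_index follows; it uses a.ravel(), which gives the same element order as a.flatten() but avoids a copy when it can:

```python
import numpy as np

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])

k = 2
flat = a.ravel().argsort()[:k]  # positions in the flattened array
rows, cols = np.unravel_index(flat, a.shape)  # back to 2-D coordinates
coords = [(int(r), int(c)) for r, c in zip(rows, cols)]
print(flat.tolist())  # [11, 0]
print(coords)         # [(2, 1), (0, 0)]
```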

If the arrays are very large, np.argpartition can give better performance, because it performs only a partial sort (roughly linear time instead of O(n log n)). Note that the k indices it returns are not in sorted order.
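A possible argpartition version for the whole-array case is sketched below. The helper name k_smallest_indices is hypothetical, and it assumes 1 <= k <= a.size. Because the partition leaves the k candidates unordered, they are sorted afterwards; when several values tie at the cutoff, which of them get included is not specified:

```python
import numpy as np

def k_smallest_indices(arr, k):
    """Hypothetical helper: N-d index tuples of the k smallest values, ascending by value."""
    if not 1 <= k <= arr.size:
        raise ValueError("k must satisfy 1 <= k <= arr.size")
    flat = arr.ravel()
    # Partial sort: positions [:k] now hold the k smallest values, in no particular order
    part = np.argpartition(flat, k - 1)[:k]
    # Sort only those k candidates by their values
    part = part[np.argsort(flat[part])]
    # Convert flat positions back to N-d coordinates as plain Python ints
    return [tuple(int(i) for i in coord)
            for coord in zip(*np.unravel_index(part, arr.shape))]

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])
print(k_smallest_indices(a, 2))  # [(2, 1), (0, 0)]
```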

You can use numpy.ndenumerate with a heap, or a partial sort as suggested by David:

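import heapq
import numpy as np

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])
# Pair each value with its (row, col) index so tuples compare by value first
pairs = [(v, k) for k, v in np.ndenumerate(a)]
# nsmallest accepts any iterable, so an explicit heapify() is not required
heapq.nsmallest(10, pairs)  # for k = 10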

And you get:

[(0, (2, 1)),
 (1, (0, 0)),
 (2, (0, 3)),
 (2, (1, 3)),
 (2, (2, 0)),
 (3, (0, 1)),
 (3, (0, 4)),
 (4, (1, 4)),
 (4, (2, 4)),
 (5, (0, 2))]
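If only the indices are needed, a variant (a sketch, not taken from the answer above) passes the ndenumerate iterator straight to heapq.nsmallest with a key, so no list of (value, index) tuples is built. nsmallest is documented as equivalent to sorted(iterable, key=key)[:n], so ties keep row-major order:

```python
import heapq
import numpy as np

a = np.array([[1, 3, 5, 2, 3],
              [7, 6, 5, 2, 4],
              [2, 0, 5, 6, 4]])

k = 3
# ndenumerate yields (index_tuple, value); the key compares on the value only
smallest = heapq.nsmallest(k, np.ndenumerate(a), key=lambda t: t[1])
indices = [idx for idx, _ in smallest]
print(indices)  # [(2, 1), (0, 0), (0, 3)]
```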
