繁体   English   中英

查找非零元素的索引并按值分组

[英]Find the indices of non-zero elements and group by values

我在python中编写了一个代码,它接受一个numpy矩阵作为输入,并返回一个由相应值分组的索引列表(即output [3]返回值为3的所有索引)。 但是,我缺乏编写矢量化代码的知识,必须使用ndenumerate来完成。 此操作仅花费大约9秒,这太慢了。

我的第二个想法是使用numpy.nonzero如下:

for i in range(1, max_value):
   current_array = np.nonzero(input == i)
   # save in an array

这需要5.5秒,所以这是一个很好的改进,但仍然很慢。 有没有循环或优化方式来获得每个值的索引对的任何方法?

这是针对您的问题的O(n log n)算法。 显而易见的循环解决方案是O(n),因此对于足够大的数据集,这将更慢:

>>> a = np.random.randint(3, size=10)
>>> a
array([1, 2, 2, 0, 1, 0, 2, 2, 1, 1])

>>> index = np.arange(len(a))
>>> sort_idx = np.argsort(a)
>>> cnt = np.bincount(a)
>>> np.split(index[sort_idx], np.cumsum(cnt[:-1]))
[array([3, 5]), array([0, 4, 8, 9]), array([1, 2, 6, 7])]

它取决于数据的大小,但对于较大的数据集来说速度相当快:

In [1]: a = np.random.randint(1000, size=1e6)

In [2]: %%timeit
   ...: indices = np.arange(len(a))
   ...: sort_idx = np.argsort(a)
   ...: cnt = np.bincount(a)
   ...: np.split(indices[sort_idx], np.cumsum(cnt[:-1]))
   ...: 
10 loops, best of 3: 140 ms per loop

如果你愿意使用一些额外的内存,你可以通过广播进行矢量化:

import numpy as np
input = np.random.randint(1,max_value, 100)
indices = np.arange(1, max_value)

matches = input == indices[:,np.newaxis]  # broadcasts across each index

然后,每个索引i的匹配只是np.nonzero(matches[i])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM