簡體   English   中英

查找非零元素的索引並按值分組

[英]Find the indices of non-zero elements and group by values

我在python中編寫了一個代碼,它接受一個numpy矩陣作為輸入,並返回一個由相應值分組的索引列表(即output [3]返回值為3的所有索引)。 但是,我缺乏編寫矢量化代碼的知識,必須使用ndenumerate來完成。 此操作僅花費大約9秒,這太慢了。

我的第二個想法是使用numpy.nonzero如下:

for i in range(1, max_value):
   current_array = np.nonzero(input == i)
   # save in an array

這需要5.5秒,所以這是一個很好的改進,但仍然很慢。 有沒有循環或優化方式來獲得每個值的索引對的任何方法?

這是針對您的問題的O(n log n)算法。 顯而易見的循環解決方案是O(n),因此對於足夠大的數據集,這將更慢:

>>> a = np.random.randint(3, size=10)
>>> a
array([1, 2, 2, 0, 1, 0, 2, 2, 1, 1])

>>> index = np.arange(len(a))
>>> sort_idx = np.argsort(a)
>>> cnt = np.bincount(a)
>>> np.split(index[sort_idx], np.cumsum(cnt[:-1]))
[array([3, 5]), array([0, 4, 8, 9]), array([1, 2, 6, 7])]

它取決於數據的大小,但對於較大的數據集來說速度相當快:

In [1]: a = np.random.randint(1000, size=1e6)

In [2]: %%timeit
   ...: indices = np.arange(len(a))
   ...: sort_idx = np.argsort(a)
   ...: cnt = np.bincount(a)
   ...: np.split(indices[sort_idx], np.cumsum(cnt[:-1]))
   ...: 
10 loops, best of 3: 140 ms per loop

如果你願意使用一些額外的內存,你可以通過廣播進行矢量化:

import numpy as np
input = np.random.randint(1,max_value, 100)
indices = np.arange(1, max_value)

matches = input == indices[:,np.newaxis]  # broadcasts across each index

然后,每個索引i的匹配只是np.nonzero(matches[i])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM