
faster alternative to numpy.where?

I have a 3d array filled with integers from 0 to N. I need a list of the indices corresponding to where the array equals 1, 2, 3, ... N. I can do it with np.where as follows:

N = 300
shape = (1000,1000,10)
data = np.random.randint(0,N+1,shape)
indx = [np.where(data == i_id) for i_id in range(1,data.max()+1)]

but this is quite slow. According to this question: fast python numpy where functionality? it should be possible to speed up the index search quite a lot, but I haven't been able to transfer the methods proposed there to my problem of getting the actual indices. What would be the best way to speed up the above code?

As an add-on: I want to store the indices later, for which it makes sense to use np.ravel_multi_index to reduce the storage from three indices per element to only one, i.e. using:

indx = [np.ravel_multi_index(np.where(data == i_id), data.shape) for i_id in range(1, data.max()+1)]

which is closer to e.g. Matlab's find function. Can this be directly incorporated in a solution that doesn't use np.where?

I think that a standard vectorized approach to this problem would end up being very memory intensive: for int64 data, it would require O(8 * N * data.size) bytes, or ~22 GiB of memory for the example you gave above. I'm assuming that is not an option.
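
For concreteness, here is the back-of-the-envelope arithmetic behind that figure (a quick check I added, using the sizes from the question):

N = 300
size = 1000 * 1000 * 10              # data.size for shape (1000, 1000, 10)
print(8 * N * size / 2**30)          # -> ~22.4, i.e. roughly 22 GiB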

You might make some progress by using a sparse matrix to store the locations of the unique values. For example:

import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    # Build a sparse matrix whose entry (v, j) stores j for every flat
    # position j at which data equals v
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    # Row v of M holds the flat indices where data == v; unravel them to 3D
    M = compute_M(data)
    return [np.unravel_index(row.data, data.shape) for row in M]

This takes advantage of fast code within the sparse matrix constructor to organize the data in a useful way, constructing a sparse matrix whose row i contains just the indices where the flattened data equals i.
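
A tiny illustration may help (the example array and printed output are my own, not from the original answer). Each entry of the result lists the 3D indices at which data equals that value:

small = np.array([[1, 0],
                  [2, 1]])
for value, idx in enumerate(get_indices_sparse(small)):
    print(value, idx)
# 0 (array([0]), array([1]))
# 1 (array([0, 1]), array([0, 1]))
# 2 (array([1]), array([0]))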

To test it out, I'll also define a function that does your straightforward method:

def get_indices_simple(data):
    return [np.where(data == i) for i in range(0, data.max() + 1)]

The two functions give the same results for the same input:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_sparse(data_small)))
# True

And the sparse method is an order of magnitude faster than the simple method for your dataset:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.1 s, sys: 638 ms, total: 14.7 s
# Wall time: 14.8 s

%time ind = get_indices_sparse(data)
# CPU times: user 881 ms, sys: 301 ms, total: 1.18 s
# Wall time: 1.18 s

%time M = compute_M(data)
# CPU times: user 216 ms, sys: 148 ms, total: 365 ms
# Wall time: 363 ms

The other benefit of the sparse method is that the matrix M ends up being a very compact and efficient way to store all the relevant information for later use, as mentioned in the add-on part of your question. Hope that's useful!
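
If you do want to persist M, something along these lines should work (a minimal sketch I added; save_npz and load_npz live in scipy.sparse, and the filename is just an example):

from scipy.sparse import save_npz, load_npz

save_npz('indices.npz', compute_M(data))   # write the index matrix to disk
M = load_npz('indices.npz')                # ...and read it back later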


Edit: I realized there was a bug in the initial version: it failed if any values in the range didn't appear in the data. That's now fixed above.

I was mulling this over and realized that there's a more intuitive (but slightly slower) approach to solving this using Pandas groupby(). Consider this:

import numpy as np
import pandas as pd

def get_indices_pandas(data):
    d = data.ravel()
    # x.index holds the flat positions of each group; unravel them to 3D
    f = lambda x: np.unravel_index(x.index, data.shape)
    return pd.Series(d).groupby(d).apply(f)
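
To see the shape of the result on a tiny example (the array and printed output are my own illustration): the return value is a Series keyed by the unique values, each entry holding a tuple of index arrays:

small = np.array([[1, 0],
                  [2, 1]])
print(get_indices_pandas(small)[1])   # indices where small == 1
# (array([0, 1]), array([0, 1]))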

This returns the same result as get_indices_simple from my previous answer:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_pandas(data_small)))
# True

And this Pandas approach is just slightly slower than the less intuitive matrix approach:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.2 s, sys: 665 ms, total: 14.8 s
# Wall time: 14.9 s

%time ind = get_indices_sparse(data)
# CPU times: user 842 ms, sys: 277 ms, total: 1.12 s
# Wall time: 1.12 s

%time ind = get_indices_pandas(data)
# CPU times: user 1.16 s, sys: 326 ms, total: 1.49 s
# Wall time: 1.49 s

Here's one vectorized approach -

# Number of distinct IDs; N here is data.max(), as in the question's setup
N = data.max()

# Mask of matches for data elements against all IDs from 1 to N
mask = data == np.arange(1, N + 1)[:, None, None, None]

# Indices of matches across all IDs and their linear indices
idx = np.argwhere(mask.reshape(N,-1))

# Get cut indices where IDs shift
_,cut_idx = np.unique(idx[:,0],return_index=True)

# Cut at shifts to give us the final indx output
out = np.hsplit(idx[:,1],cut_idx[1:])
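
As a quick sanity check I added (not part of the original answer), the chunks in out match the ravel_multi_index version from the question. Note this relies on every ID from 1 to N actually occurring in data, since np.unique only yields cut points for IDs that appear:

ref = [np.ravel_multi_index(np.where(data == i), data.shape)
       for i in range(1, N + 1)]
print(all(np.array_equal(a, b) for a, b in zip(out, ref)))   # True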

Basically, most answers to the other question have the message "use indirect sorting".

We can get the linear indices (similar to MATLAB's find) corresponding to i = [0..N] with a call to numpy.argsort over the flattened array:

flat = data.ravel()
# stable sort, so indices within each value stay in C order
lin_idx = np.argsort(flat, kind='mergesort')

But then we get a single big array; which indices belong to which i? We just split the indices array based on the counts of each i:

ans = np.split(lin_idx, np.cumsum(np.bincount(flat)[:-1]))  # one chunk per value 0..data.max()

If you still need the 3D indices somewhere, you can use numpy.unravel_index.
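
Putting the pieces together into a reusable function (my own assembly of the above; the name get_indices_argsort is mine):

import numpy as np

def get_indices_argsort(data):
    flat = data.ravel()
    lin_idx = np.argsort(flat, kind='mergesort')       # stable: preserves C order within each value
    counts = np.bincount(flat)                         # occurrences of each value 0..data.max()
    return np.split(lin_idx, np.cumsum(counts[:-1]))   # one chunk of linear indices per value

# 3D indices, if needed:
# ind3d = [np.unravel_index(i, data.shape) for i in get_indices_argsort(data)]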
