Indexing multi-dimensional array with tuple of indices from an indexing array - NumPy / Python

I have a 3-D NumPy array a with shape (6, m, n). I also have a 6-D boolean NumPy array b with shape (20, 20, 20, 20, 20, 20) that effectively works as a mask.

I would like to use the 6 values at each position (i, j) of the first array as an index tuple to retrieve the corresponding value from the second array. Effectively, this compresses the 3-D integer array into a 2-D boolean array of shape (m, n). I thought the solution would involve np.where, but I don't think it can use array values as indices.

A naive implementation would look something like this:

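import numpy as np
m, n = a.shape[1:]
new_arr = np.empty((m, n), dtype=bool)
for i in range(m):
    for j in range(n):
        new_arr[i, j] = b[tuple(a[:, i, j])]  # the 6 values at (i, j) address one element of b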
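For experimenting, the loop can be wrapped in a small reference function and run on scaled-down data, which gives a ground truth to compare vectorized versions against. This is a sketch under assumptions: the name loop_lookup is hypothetical, and the axes of b are shrunk from 20 to 4 so the example runs instantly.

```python
import numpy as np

def loop_lookup(a, b):
    """Hypothetical reference (slow) version: one scalar lookup into b per (i, j)."""
    _, m, n = a.shape
    out = np.empty((m, n), dtype=b.dtype)
    for i in range(m):
        for j in range(n):
            # a[:, i, j] holds 6 integers; as a tuple they address one element of b
            out[i, j] = b[tuple(a[:, i, j])]
    return out

rng = np.random.default_rng(0)
m, n = 3, 5
b = rng.random((4, 4, 4, 4, 4, 4)) > 0.5   # small 6-D boolean mask (axes of 4, not 20)
a = rng.integers(0, 4, size=(6, m, n))     # every value is a valid index into b

new_arr = loop_lookup(a, b)
print(new_arr.shape, new_arr.dtype)        # (3, 5) bool
```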

Is there any way to implement this without a loop?

Approach #1

Reshape a to 2-D, keeping the length of the first axis the same. Converting this 2-D array to a tuple gives six 1-D arrays, one per row, and indexing b with that tuple uses integer-array (advanced) indexing: the six values in each column together select one element of b. Finally, a reshape restores the 2-D (m, n) output. Hence, the implementation would look something like this:

b[tuple(a.reshape(6,-1))].reshape(m,n)
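To see what each step of the reshape version does, here is a sketch on small random data (b's axes shrunk to 4 and m, n chosen arbitrarily) that prints the intermediate shapes and checks the result against explicit per-position lookups.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 3, 5
b = rng.random((4,) * 6) > 0.5
a = rng.integers(0, 4, size=(6, m, n))

flat = a.reshape(6, -1)          # shape (6, m*n): column k holds the 6 indices of position k
idx = tuple(flat)                # six 1-D arrays of length m*n, one per axis of b
picked = b[idx]                  # element k is b[idx[0][k], idx[1][k], ..., idx[5][k]]
new_arr = picked.reshape(m, n)   # restore the (m, n) layout (C order matches the first reshape)

print(flat.shape, len(idx), picked.shape, new_arr.shape)
# (6, 15) 6 (15,) (3, 5)

# Same result as looking up each position one at a time
expected = np.array([[b[tuple(a[:, i, j])] for j in range(n)] for i in range(m)])
assert np.array_equal(new_arr, expected)
```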

Or skip all that reshaping and simply do:

b[tuple(a)]

Since tuple(a) splits a along its first axis into six (m, n) index arrays, this builds the same indexer directly, and the result already has shape (m, n).
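A quick check of that claim on small random data (sizes are arbitrary): iterating over a yields a[0], ..., a[5], and with advanced indexing the output takes the common shape of the index arrays, so no reshape is needed.

```python
import numpy as np

rng = np.random.default_rng(2)
m, n = 3, 5
b = rng.random((4,) * 6) > 0.5
a = rng.integers(0, 4, size=(6, m, n))

idx = tuple(a)                        # a[0], a[1], ..., a[5], each of shape (m, n)
print(len(idx), idx[0].shape)         # 6 (3, 5)

direct = b[idx]                       # output shape = shape of the index arrays
via_reshape = b[tuple(a.reshape(6, -1))].reshape(m, n)
print(direct.shape, direct.dtype)     # (3, 5) bool
assert np.array_equal(direct, via_reshape)
```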

Approach #2

Alternatively, we can compute the flattened (linear) indices with np.ravel_multi_index, then use them to index into the flattened b and extract the relevant boolean values:

b.ravel()[np.ravel_multi_index(a,b.shape)]
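np.ravel_multi_index maps each 6-tuple to its position in the C-order (row-major) flattened array, i.e. i0*d1*d2*d3*d4*d5 + i1*d2*d3*d4*d5 + ... + i5. As an illustration, the sketch below computes the same flat indices by hand with Horner's rule and checks both against Approach #1 on small random data (sizes are arbitrary).

```python
import numpy as np

rng = np.random.default_rng(3)
m, n = 3, 5
b = rng.random((4,) * 6) > 0.5
a = rng.integers(0, 4, size=(6, m, n))

flat_idx = np.ravel_multi_index(a, b.shape)   # shape (m, n), values in [0, b.size)

# Same computation by hand: row-major flattening via Horner's rule,
# acc = (((i0*d1 + i1)*d2 + i2)*d3 + ...)
manual = np.zeros((m, n), dtype=np.intp)
for dim, idx in zip(b.shape, a):
    manual = manual * dim + idx
assert np.array_equal(flat_idx, manual)

result = b.ravel()[flat_idx]                  # 1-D lookup into the flattened mask
assert np.array_equal(result, b[tuple(a)])
print(result.shape)                           # (3, 5)
```

One behavioral difference worth knowing: with its default mode='raise', np.ravel_multi_index rejects negative or out-of-range coordinates, whereas b[tuple(a)] treats negative values as counting from the end of each axis. If a can contain such values, the two approaches are not interchangeable.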

Timings on a large dataset:

In [89]: np.random.seed(0)
    ...: m,n = 500,500
    ...: b = np.random.rand(20,20,20,20,20,20)>0.5
    ...: a = np.random.randint(0,20,(6,m,n))

In [90]: %timeit b[tuple(a)]
14.6 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [91]: %timeit b.ravel()[np.ravel_multi_index(a,b.shape)]
7.35 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
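The %timeit lines above are IPython magics. Outside IPython, a comparable measurement can be sketched with the standard-library timeit module. The setup mirrors In [89], except that SIDE (a name introduced here) defaults to 10 rather than 20 to keep memory low; absolute and relative numbers depend on the machine, NumPy version and sizes, so none are claimed here.

```python
import timeit
import numpy as np

SIDE = 10   # the session above uses 20; 20**6 random floats need ~0.5 GB temporarily
np.random.seed(0)
m, n = 500, 500
b = np.random.rand(*([SIDE] * 6)) > 0.5
a = np.random.randint(0, SIDE, (6, m, n))

def approach1():
    return b[tuple(a)]

def approach2():
    return b.ravel()[np.ravel_multi_index(a, b.shape)]

# Both approaches must agree before their speed is compared
assert np.array_equal(approach1(), approach2())

for name, fn in [("b[tuple(a)]", approach1), ("ravel_multi_index", approach2)]:
    # best of 5 repeats, each timing 10 calls; report the per-call time
    best = min(timeit.repeat(fn, number=10, repeat=5)) / 10
    print(f"{name:>20}: {best * 1e3:.2f} ms per call")
```

Reporting the minimum over repeats, rather than a mean, follows the timeit documentation's advice that the lowest value is the most useful lower bound on what the machine can do.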

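Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.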

© 2020-2024 STACKOOM.COM