繁体   English   中英

基于参考 n-dim 数组对 n-dim 数组进行操作的最有效方法

[英]Most efficient way to operate on a n-dim array based on a reference n-dim array

我有两个相同形状的 numpy arrays : dat_araref_ara

我想在op_funcaxis = -1上执行dat_ara操作,但是我只想对每个数组中的选定值切片进行操作,当参考数组thres超过阈值 thres 时指定ref_ara

为了说明,在 arrays 只是 2-dim 的简单情况下,我有:

thres = 4

op_func = np.average

ref_ara = array([[1, 2, 1, 4, 3, 5, 1, 5, 2, 5],
                 [1, 2, 2, 1, 1, 1, 2, 7, 5, 8],
                 [2, 3, 2, 5, 1, 6, 5, 2, 7, 3]]) 

dat_ara = array([[1, 0, 0, 1, 1, 1, 1, 0, 1, 1],
                 [1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
                 [1, 0, 1, 1, 1, 1, 0, 1, 1, 1]]) 

我们看到 thres 在thresaxis=0中的第一个、第二个和第三个数组的第 5 个、第 7 个和第 3 个索引中被ref_ara 因此我想要的结果是

out_ara = array([op_func(array([1, 0, 0, 1, 1, 1]), 
                 op_func(array([1, 1, 1, 1, 1, 1, 1, 0]),
                 op_func(array([1, 0, 1, 1])])

这个问题很困难,因为它需要引用ref_ara 如果不是这样,我可以简单地使用numpy.apply_along_axis

我尝试扩展两个 arrays 的尺寸以将它们关联起来进行计算,即:

assos_ara = np.append(np.expand_dims(dat_ara, axis=-1), np.expand_dims(ref_ara, axis=-1), axis=-1)

But again, numpy.apply_along_axis requires the input function to only operate on 1-dim arrays, and thus I still cannot utilise the function.

我知道的唯一其他方法是明智地遍历 arrays 索引,但是,arrays 具有两个 arrays 的不断变化的尺寸,这不是计算效率高的问题,而是一个棘手的问题。

我很想使用矢量化函数来帮助这个过程。 go 最有效的方法是什么?

这是屏蔽 arrays 的一个很好的用例,因为它们允许您对部分数据执行正常的 numpy 操作。

假设每一行至少包含一个大于阈值的值。 您可以将断点的索引计算为

breaks = np.argmax(ref_ara > thres, axis=-1)   # 5, 7, 3

然后,您可以使用我之前链接的问题答案创建一个掩码。 掩码通常是处理 numpy 中不规则形状数据的最佳方式。

mask = np.arange(ref_ara.shape[-1]) <= breaks.reshape(*breaks.shape, 1)

在这里,我们不需要对arange做任何花哨的事情,因为它位于最后一个维度。 如果不是这种情况,您可能希望在范围为 go 的中断形状中插入一个 1,并用 1 填充范围形状的尾部。

现在掩码数组和 ufunc 解决方案略有不同。 掩码数组版本更通用,所以它首先出现:

data = np.ma.array(data_ara, mask=~mask)

掩码的 arrays 从正常的 boolean 索引的方式向后解释掩码,因此我们反转掩码。 或者,您可以使用>而不是<=来计算掩码。 计算现在很简单:

out_ara = np.ma.average(data, axis=-1).data

一个不太通用的替代方法是将您的操作分解为 ufunc,并使用它们提供的掩码。 这对于np.average来说很容易,它只是np.sumnp.divide ,但对于更复杂的操作可能更难。

从 numpy 1.17.0 开始, np.sum有一个where关键字:

out_ara = np.sum(dat_ara, where=mask, axis=-1) / breaks

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM