[英]Increasing performance of highly repeated numpy array index operations
在我的程序代码中,我有numpy
值数组和numpy
索引数组。 两者都是在程序初始化期间预先分配和预定义的。
程序的每一部分都有一个用于执行计算的数组values
,以及三个索引数组idx_from_exch
、 idx_values
和idx_to_exch
。 有一个全局值数组来交换几个部分的值: exch_arr
。
大多数情况下,索引数组有 2 到 5 个索引,很少(很可能永远不会)需要更多索引。 dtype=np.int32
, shape
和值在整个程序运行期间是恒定的。 因此我在初始化后设置ndarray.flags.writeable=False
,但这是可选的。 索引数组idx_values
和idx_to_exch
的索引值按数字顺序排序, idx_source
可以排序,但无法定义。 对应于一个值数组/部分的所有索引数组具有相同的shape
。
values
数组和exch_arr
通常有 50 到 1000 个元素。 shape
和dtype=np.float64
在整个程序运行过程中保持不变,数组的值在每次迭代中都会发生变化。
以下是示例数组:
import numpy as np
import numba as nb
values = np.random.rand(100) * 100 # just some random numbers
exch_arr = np.random.rand(60) * 3 # just some random numbers
idx_values = np.array((0, 4, 55, -1), dtype=np.int32) # sorted but varying steps
idx_to_exch = np.array((7, 8, 9, 10), dtype=np.int32) # sorted and constant steps!
idx_from_exch = np.array((19, 4, 7, 43), dtype=np.int32) # not sorted and varying steps
示例索引操作如下所示:
values[idx_values] = exch_arr[idx_from_exch] # get values from exchange array
values *= 1.1 # some inplace array operations, this is just a dummy for more complex things
exch_arr[idx_to_exch] = values[idx_values] # pass some values back to exchange array
由于这些操作在几百万次迭代中每次迭代应用一次,因此速度至关重要。 在我之前的问题中,我一直在研究提高索引速度的许多不同方法,但是考虑到我的应用程序(尤其是通过使用常量索引数组进行索引并将它们传递给另一个索引数组来获取值),我忘记了足够具体。
到目前为止,最好的方法似乎是花哨的索引。 我目前也在试验numba
guvectorize
,但似乎不值得付出努力,因为我的数组很小。 memoryviews
会很好,但由于索引数组不一定具有一致的步骤,我知道没有办法使用memoryviews
。
那么有没有更快的方法来进行重复索引? 为每个索引操作预定义内存地址数组的某种方法,因为dtype
和shape
总是恒定的? ndarray.__array_interface__
给了我一个内存地址,但我无法将它用于索引。 我想过这样的事情:
stride_exch = exch_arr.strides[0]
mem_address = exch_arr.__array_interface__['data'][0]
idx_to_exch = idx_to_exch * stride_exch + mem_address
那可行吗?
我也一直在寻找到使用strides
直接与as_strided
,但据我所知,只有一致的步伐是允许的,我的问题就需要不一致strides
。
任何帮助表示赞赏! 提前致谢!
编辑:
我刚刚在我的示例计算中纠正了一个巨大的错误!
操作values = values * 1.1
更改数组的内存地址。 我在程序代码中的所有操作都不会改变数组的内存地址,因为很多其他操作都依赖于使用内存视图。 因此,我用正确的就地操作替换了虚拟操作: values *= 1.1
使用 numpy 布尔数组绕过昂贵的花哨索引的一种解决方案是使用 numba 并跳过 numpy 布尔数组中的 False 值。
示例实现:
@numba.guvectorize(['float64[:], float64[:,:], float64[:]'], '(n),(m,n)->(m)', nopython=True, target="cpu")
def test_func(arr1, arr2, inds, res):
for i in range(arr1.shape[0]):
if not inds[i]:
continue
for j in range(arr2.shape[0]):
res[j, i] = arr1[i] + arr2[j, i]
当然,使用 numpy 数据类型(较小的字节大小会运行得更快)并且目标是"cpu"
或"parallel"
。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.