简体   繁体   English

将索引的 function 应用于 numpy 数组的所有元素

[英]Apply a function of the indices to all elements of a numpy array

I am looking for a way to apply a function to all elements of a numpy array.我正在寻找一种将 function 应用于 numpy 数组的所有元素的方法。 The function receives not only the value but also the indices of the element as arguments. function 不仅接收值,还接收元素的索引,如 arguments。 The goal is perfomance on large 2- or 3-dim.目标是在大的 2 或 3 维度上表现出色。 arrays. arrays。

(I know there are several ways to do that with a function that receives the value of an element only) (我知道有几种方法可以使用仅接收元素值的 function 来做到这一点)

The code to be replaced is要替换的代码是

def foo(x, i, j)
    return (i*x)**2 - (j*x)**3  # or some other fancy stuff

...

arr = np.zeros((nx, ny))

...

# nested loop to be replaced, e.g. via vectorization
for i in range(nx):
    for j in range(ny):
        arr[i,j] = foo(arr[i,j], i, j)               

You can do this with simple broadcasting rules, by using suitably generated indices with the proper shapes so that standard broadcasting rules match the shape of the input.您可以使用简单的广播规则来做到这一点,方法是使用适当生成的具有适当形状的索引,以便标准广播规则与输入的形状相匹配。

This can be generated either manually (eg with a combination of np.arange() and np.reshape() ) or more concisely with np.ogrid() .这可以手动生成(例如,使用np.arange()np.reshape()的组合)或更简洁地使用np.ogrid()

import numpy as np


import numpy as np


np.random.seed(0)


def foo(x):
    n, m = arr.shape
    i, j = np.ogrid[:n, :m]
    return (i * x) ** 2 - (j * x) ** 3


n, m = 2, 3
arr = np.random.random((n, m))


foo(arr)
# array([[ 0.        , -0.36581638, -1.7519857 ],
#        [ 0.29689768,  0.10344439, -1.73844954]])

This approach would require potentially large temporary memory arrays for the intermediate results.这种方法可能需要较大的临时 memory arrays 以获得中间结果。


A more efficient approach can be obtained by keeping explicit looping to be accelerated with a JIT compiler like numba :通过使用像numba这样的 JIT 编译器来加速显式循环,可以获得更有效的方法:

import numba as nb


@nb.njit
def foo_nb(arr):
    n, m = arr.shape
    out = np.empty((n, m), dtype=arr.dtype)
    for i in range(n):
        for j in range(m):
            x = arr[i, j]
            out[i, j] = (i * x) ** 2 - (j * x) ** 3
    return out


foo_nb(arr)
# array([[ 0.        , -0.36581638, -1.7519857 ],
#        [ 0.29689768,  0.10344439, -1.73844954]])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM