
Python: Apply function to every entry in numpy 3d array

I would like to apply a (somewhat complex) function to every element of my 3D NumPy array, whose shape is (x, y, z) = (4, 4, 3). Let's assume I have the following array:

import numpy as np

array = np.arange(48)
array = array.reshape([4,4,3])

Now I would like to call the following function on each point of the array:

p(x,y,z) = a(z) + b(z)*ps(x,y)

Let's assume a and b are the following 1d arrays and ps is a 2d array:

a = np.random.randint(1,10, size=3)
b = np.random.randint(1,10, size=3)
ps = np.arange(16)
ps = ps.reshape([4,4])

My intuitive approach was to loop over my array and call the function on each point. It works, but of course it's way too slow:

def calcP(a, b, ps, x, y, z):
    p = a[z] + b[z]*ps[x,y]
    return p

def stupidLoop(array, a, b, ps):
    dummy = array.copy()  # copy so the input array is not modified in place
    nx, ny, nz = array.shape
    for z in range(nz):
        for x in range(nx):
            for y in range(ny):
                dummy[x,y,z] = calcP(a, b, ps, x, y, z)
    return dummy

updatedArray = stupidLoop(array, a, b, ps)

Is there a faster way? I know this can be done with vectorized functions, but I can't figure out how to apply that to mine.

I haven't actually run it with these numbers; they only illustrate the problem. The real function comes from meteorology and is a bit more complex.

Replace the loop with a single vectorized expression that uses broadcasting:

a.reshape([1,1,-1]) + b.reshape([1,1,-1]) * ps.reshape([4,4,1])
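A minimal sketch, assuming the shapes from the question (a and b of length 3, ps of shape (4, 4)), that prints the intermediate shapes and checks the broadcast result against a plain triple loop. The helper name loop_version is hypothetical and exists only for the comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(1, 10, size=3)
b = rng.integers(1, 10, size=3)
ps = np.arange(16).reshape(4, 4)

# a and b become shape (1, 1, 3), ps becomes (4, 4, 1); broadcasting
# stretches every length-1 axis, so the result has shape (4, 4, 3).
a3 = a.reshape([1, 1, -1])
b3 = b.reshape([1, 1, -1])
ps3 = ps.reshape([4, 4, 1])
result = a3 + b3 * ps3
print(a3.shape, ps3.shape, result.shape)  # (1, 1, 3) (4, 4, 1) (4, 4, 3)

def loop_version(a, b, ps):
    # Hypothetical reference implementation: the element-by-element loop.
    nx, ny = ps.shape
    out = np.empty((nx, ny, a.size), dtype=np.result_type(a, b, ps))
    for x in range(nx):
        for y in range(ny):
            for z in range(a.size):
                out[x, y, z] = a[z] + b[z] * ps[x, y]
    return out

print(np.array_equal(result, loop_version(a, b, ps)))  # True
```

The whole computation happens in one vectorized expression, so the per-element Python function calls of the loop disappear.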

EDIT:

Thanks to @NilsWerner for suggesting a more idiomatic way in the comments:

a + b * ps[:, :, None]
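The shorter form works because broadcasting aligns shapes starting from the last axis: a and b (shape (3,)) already line up with the trailing axis, so only ps needs an extra length-1 axis. A small sketch, using fixed a and b instead of randint so the printed values are reproducible, checking that ps[:, :, None], ps[..., np.newaxis] and the explicit reshape all agree:

```python
import numpy as np

a = np.array([8, 5, 3])
b = np.array([1, 1, 9])
ps = np.arange(16).reshape(4, 4)

# None (alias np.newaxis) inserts a length-1 axis: (4, 4) -> (4, 4, 1)
short = a + b * ps[:, :, None]
alias = a + b * ps[..., np.newaxis]
explicit = a.reshape([1, 1, -1]) + b.reshape([1, 1, -1]) * ps.reshape([4, 4, 1])

print(short.shape)  # (4, 4, 3)
print(np.array_equal(short, alias) and np.array_equal(short, explicit))  # True
print(short[0, 0].tolist())  # [8, 5, 3]
print(short[3, 3].tolist())  # [23, 20, 138]
```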

You can do this using numpy.fromfunction():

import numpy as np

a = np.random.randint(1,10, size=3)
b = np.random.randint(1,10, size=3)
ps = np.arange(16)
ps = ps.reshape([4,4])

def calcP(x,y,z,a=a,b=b,ps=ps):
    p = a[z]+b[z]*ps[x,y] + 0.0
    return p

array = np.arange(48)
array = array.reshape([4,4,3])

updatedArray = np.fromfunction(calcP, (4,4,3), a=a,b=b,ps=ps, dtype=int)
print(updatedArray)

Notice that I've modified your function calcP slightly so that it takes a, b and ps as keyword arguments. I've also added 0.0 to ensure that the output array contains floats rather than ints.
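The effect of the + 0.0 can be checked directly. A minimal sketch, assuming integer inputs as in the question, comparing the dtype of the fromfunction() result with and without the float term; calc_int and calc_float are hypothetical names used only here. Note that fromfunction()'s dtype argument only sets the type of the index arrays passed to the function, not the type of the result.

```python
import numpy as np

a = np.array([8, 5, 3])
b = np.array([1, 1, 9])
ps = np.arange(16).reshape(4, 4)

def calc_int(x, y, z):
    # Only integer operands, so the result stays integer
    return a[z] + b[z] * ps[x, y]

def calc_float(x, y, z):
    # Adding 0.0 promotes the result to float64
    return a[z] + b[z] * ps[x, y] + 0.0

# dtype=int makes the index arrays integers, which fancy indexing requires
int_result = np.fromfunction(calc_int, (4, 4, 3), dtype=int)
float_result = np.fromfunction(calc_float, (4, 4, 3), dtype=int)
print(int_result.dtype)    # an integer dtype, e.g. int64
print(float_result.dtype)  # float64
```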

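Also, notice that the second argument to fromfunction() only specifies the shape of the grid. fromfunction() does not loop over that grid: it calls calcP() once, passing integer index arrays of that shape as x, y and z, so calcP() must be written with array operations.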

Output (will vary each time due to randint):

[[[  8.   5.   3.]
  [  9.   6.  12.]
  [ 10.   7.  21.]
  [ 11.   8.  30.]]

 [[ 12.   9.  39.]
  [ 13.  10.  48.]
  [ 14.  11.  57.]
  [ 15.  12.  66.]]

 [[ 16.  13.  75.]
  [ 17.  14.  84.]
  [ 18.  15.  93.]
  [ 19.  16. 102.]]

 [[ 20.  17. 111.]
  [ 21.  18. 120.]
  [ 22.  19. 129.]
  [ 23.  20. 138.]]]
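To see that fromfunction() calls the function only once, with whole index arrays rather than one point at a time, here is a small sketch that records each call. It assumes fixed a and b for reproducibility and also rebuilds the same coordinate arrays explicitly with np.indices.

```python
import numpy as np

a = np.array([8, 5, 3])
b = np.array([1, 1, 9])
ps = np.arange(16).reshape(4, 4)

calls = []  # records the shapes of the arguments on every call

def calcP(x, y, z):
    calls.append((x.shape, y.shape, z.shape))
    return a[z] + b[z] * ps[x, y] + 0.0

result = np.fromfunction(calcP, (4, 4, 3), dtype=int)

print(len(calls))  # 1: a single call, not 48
print(calls[0])    # ((4, 4, 3), (4, 4, 3), (4, 4, 3))

# The same coordinate arrays, built explicitly
x, y, z = np.indices((4, 4, 3))
print(np.array_equal(result, a[z] + b[z] * ps[x, y]))  # True
```

Because of this, a calcP() that relies on scalar-only Python logic (for example an if on its arguments) will not work with fromfunction() as is.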

