
Applying a function to an array using Numpy when the function contains a condition

I am having difficulty applying a function to an array when the function contains a condition. I have an inefficient workaround and am looking for an efficient (fast) approach. A simple example:

import numpy as np

pts = np.linspace(0,1,11)
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

Now, if I run:

result = fun(pts, pts)

then I get the error

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

raised at the if x > y line. My inefficient workaround, which gives the correct result but is too slow, is:

result = np.full([len(pts)]*2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        result[i,j] = fun(pts[i], pts[j])

What is the best way to obtain this in a nicer (and more importantly, faster) way?

EDIT: using

def fun(x, y):
    if x > y:
        return 0
    else:
        return 1
x = np.array(range(10))
y = np.array(range(10))
xv,yv = np.meshgrid(x,y)
result = fun(xv, yv)  

still raises the same ValueError.

The error is quite explicit - suppose you have

x = np.array([1,2])
y = np.array([2,1])

such that

(x>y) == np.array([0,1])

What should be the result of your if np.array([0,1]) statement? Is it true or false? NumPy is telling you this is ambiguous. Using

(x>y).all()

or

(x>y).any()

is explicit, and thus NumPy is offering you two solutions: either any element pair fulfills the condition, or all of them do. Both give an unambiguous truth value. You have to decide for yourself exactly what you mean by "vector x is larger than vector y".
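To make the ambiguity concrete, here is a minimal sketch using the same x and y as above. It relies only on the built-in bool() conversion (which is what an if statement applies to its condition) and on the .any() / .all() methods named in the error message. The variable names mask, raised, any_true and all_true are chosen here for illustration.

```python
import numpy as np

x = np.array([1, 2])
y = np.array([2, 1])
mask = x > y                 # elementwise comparison: array([False,  True])

try:
    bool(mask)               # the same conversion `if x > y:` performs
    raised = False
except ValueError:
    raised = True            # NumPy refuses to pick a single truth value

any_true = bool(mask.any())  # True: at least one pair satisfies x > y
all_true = bool(mask.all())  # False: not every pair satisfies x > y
```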

The NumPy way to operate on all pairs of x and y such that x[i]>y[j] is to use np.meshgrid to generate all pairs:

>>> import numpy as np
>>> x=np.array(range(10))
>>> y=np.array(range(10))
>>> xv,yv=np.meshgrid(x,y)
>>> xv[xv>yv]
array([1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8,
       9, 4, 5, 6, 7, 8, 9, 5, 6, 7, 8, 9, 6, 7, 8, 9, 7, 8, 9, 8, 9, 9])
>>> yv[xv>yv]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
       2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8])

Either pass xv and yv to fun, or create the mesh inside the function, depending on what makes more sense. The masks above select all pairs xi, yj such that xi > yj. If you want the positions rather than the values, just return xv > yv. Note that with the default indexing='xy', cell [i, j] of the mesh corresponds to x[j] and y[i]; pass indexing='ij' so that cell [i, j] corresponds to x[i] and y[j]. In your case:

def fun(x, y):
    # indexing='ij' makes cell [i, j] correspond to x[i] and y[j]
    xv, yv = np.meshgrid(x, y, indexing='ij')
    return xv > yv

will return a matrix where fun(x, y)[i][j] is True if x[i] > y[j], and False otherwise. Alternatively

    return np.where(xv > yv)

will return a tuple of two index arrays (row indices and column indices), so that

for i, j in zip(*fun(x, y)):

will guarantee x[i] > y[j] as well.
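As a self-contained check of the index-pair variant, the sketch below uses a hypothetical helper named greater_pairs (not part of the original answer) and assumes indexing="ij", so that the first index refers to x and the second to y. The small input arrays are made up for illustration.

```python
import numpy as np

def greater_pairs(x, y):
    """Return (i, j) index pairs with x[i] > y[j] (hypothetical helper)."""
    xv, yv = np.meshgrid(x, y, indexing="ij")  # xv[i, j] == x[i], yv[i, j] == y[j]
    rows, cols = np.where(xv > yv)             # one index array per axis
    return list(zip(rows.tolist(), cols.tolist()))

x = np.array([3, 1, 2])
y = np.array([2, 2])
pairs = greater_pairs(x, y)   # only x[0] == 3 exceeds the entries of y
```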

In [253]: x = np.random.randint(0,10,5)
In [254]: y = np.random.randint(0,10,5)
In [255]: x
Out[255]: array([3, 2, 2, 2, 5])
In [256]: y
Out[256]: array([2, 6, 7, 6, 5])
In [257]: x>y
Out[257]: array([ True, False, False, False, False])
In [258]: np.where(x>y,0,1)
Out[258]: array([0, 1, 1, 1, 1])

For a Cartesian (all-pairs) comparison of these two 1d arrays, reshape one of them so that broadcasting applies:

In [259]: x[:,None]>y
Out[259]: 
array([[ True, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False]])
In [260]: np.where(x[:,None]>y,0,1)
Out[260]: 
array([[0, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [0, 1, 1, 1, 1]])
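Applied to the pts array from the question, the same broadcasting pattern can be checked against the original double loop. This is a sketch for verification only; the name expected is introduced here and is not in the original post.

```python
import numpy as np

def fun(x, y):          # scalar function from the question
    if x > y:
        return 0
    else:
        return 1

pts = np.linspace(0, 1, 11)

# Reference result: the double loop from the question.
expected = np.full([len(pts)] * 2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        expected[i, j] = fun(pts[i], pts[j])

# Broadcast version: pts[:, None] has shape (11, 1) and pts has shape (11,),
# so the comparison yields an (11, 11) boolean array.
result = np.where(pts[:, None] > pts, 0, 1)
```

Here result[i, j] corresponds to fun(pts[i], pts[j]), matching the index order of the loop.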

Your function, with its if statement, only works for scalar inputs. If given arrays, x > y produces a boolean array, which cannot be used in an if statement. Your iteration works because it passes scalar values. For some complex functions that's the best you can do (np.vectorize can make the iteration simpler, but not faster).
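For completeness, this is roughly what the np.vectorize route looks like with the function from the question. It is a sketch of the convenience wrapper only: fun is still called once per pair in Python, so no speed-up should be expected. The name vfun is chosen here for illustration.

```python
import numpy as np

def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

pts = np.linspace(0, 1, 11)
vfun = np.vectorize(fun)            # wraps fun; it still runs once per element pair
result = vfun(pts[:, None], pts)    # inputs broadcast to shape (11, 11)
```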

My answer is to look at the array comparison and derive the result from that. In this case, the three-argument form of np.where does a nice job of mapping the boolean array onto the desired 0/1 values. There are other ways of doing this mapping as well.
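Two such alternative mappings are sketched below, using the x and y values from the session above. The variable names are illustrative, and both variants are checked against np.where in the follow-up note rather than presented as preferred.

```python
import numpy as np

x = np.array([3, 2, 2, 2, 5])
y = np.array([2, 6, 7, 6, 5])
mask = x[:, None] > y                 # (5, 5) boolean array

via_where = np.where(mask, 0, 1)      # the three-argument form used above
via_astype = (~mask).astype(int)      # invert, then True -> 1 and False -> 0
via_arith = 1 - mask.astype(int)      # convert to 0/1 first, then flip
```

All three produce the same array. Writing x[:, None] <= y instead would also agree for these inputs, but it differs from the original if/else when NaN values are present, because every comparison with NaN is False.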

Reproducing your double loop with arrays requires one added step: inserting a new axis with None so that the comparison broadcasts over all pairs.

For a more complex example, or if the arrays you are dealing with are a bit larger, or if you can write to an already preallocated array, you could consider Numba.

Example

import numba as nb
import numpy as np

@nb.njit()
def fun(x, y):
  if x > y:
    return 0
  else:
    return 1

@nb.njit(parallel=False)
#@nb.njit(parallel=True)
def loop(x,y):
  result=np.empty((x.shape[0],y.shape[0]),dtype=np.int32)
  for i in nb.prange(x.shape[0]):
    for j in range(y.shape[0]):
      result[i,j] = fun(x[i], y[j])
  return result

@nb.njit(parallel=False)
def loop_preallocated(x,y,result):
  for i in nb.prange(x.shape[0]):
    for j in range(y.shape[0]):
      result[i,j] = fun(x[i], y[j])
  return result

Timings

x = np.array(range(1000))
y = np.array(range(1000))

#Compilation overhead of the first call is neglected

res=np.where(x[:,None]>y,0,1) -> 2.46ms
loop(single_threaded)         -> 1.23ms
loop(parallel)                -> 1.0ms
loop(single_threaded)*        -> 0.27ms
loop(parallel)*               -> 0.058ms

*Maybe influenced by cache. Test on your own examples.
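To test on your own examples, a small harness along these lines may help. It is a sketch that times only the NumPy broadcasting version with the standard-library timeit module; time_ms and broadcast_version are hypothetical helper names, and the repeat/number values are arbitrary choices.

```python
import timeit
import numpy as np

def broadcast_version(x, y):
    return np.where(x[:, None] > y, 0, 1)

def time_ms(func, *args, repeat=5, number=10):
    """Best-of-`repeat` time per call in milliseconds (hypothetical helper)."""
    best = min(timeit.repeat(lambda: func(*args), repeat=repeat, number=number))
    return best / number * 1e3

x = np.arange(1000)
y = np.arange(1000)
elapsed = time_ms(broadcast_version, x, y)
print(f"np.where broadcast: {elapsed:.3f} ms per call")
```

A Numba-compiled function can be passed to the same helper, provided it is called once beforehand so that compilation time is not included in the measurement.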
