
How to change a function according to the values in a numpy array

I am trying to make a surface plot of a function that looks like this:

def model(param,x_1,x_2,x_3,x_4):
    est=param[0]+param[1]*(x_1+x_2*x_3+x_2**2*x_4)
    return est

The catch is that x_3 and x_4 depend on x_2: x_3 = 1 when x_2 >= 0, and x_4 = 1 when x_2 < 0 (otherwise each is 0).
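Because x_3 and x_4 are just 0/1 indicators of the sign of x_2, one way to keep the original model unchanged is to build them as arrays from a boolean comparison. This is a minimal sketch; the param and x values are illustrative assumptions, not taken from the question.

```python
import numpy as np

def model(param, x_1, x_2, x_3, x_4):
    est = param[0] + param[1] * (x_1 + x_2 * x_3 + x_2**2 * x_4)
    return est

param = [10, 8]            # illustrative values
x_1 = np.array([1, 2, 3])
x_2 = np.array([-2, 0, 4])

# Indicator arrays: 1 where the condition holds, 0 elsewhere
x_3 = (x_2 >= 0).astype(int)   # 0, 1, 1
x_4 = (x_2 < 0).astype(int)    # 1, 0, 0

# The arithmetic is element-wise, so no if/else is needed
print(model(param, x_1, x_2, x_3, x_4))
```

Here the result is 50, 26 and 66: the first element uses the squared term, the other two use the linear term.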

When I tried to make a surface plot, I was unsure how to build the mesh grid, because there are two more variables besides x_1 and x_2.
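Since x_3 and x_4 are fully determined by x_2, the grid itself only needs x_1 and x_2; the indicators can be derived from the x_2 grid afterwards. A sketch, assuming an illustrative range of -3 to 3 for both axes and the same illustrative param values:

```python
import numpy as np

def model(param, x_1, x_2, x_3, x_4):
    return param[0] + param[1] * (x_1 + x_2 * x_3 + x_2**2 * x_4)

param = [10, 8]  # illustrative values

# Only x_1 and x_2 are free variables on the grid
x1_vals = np.linspace(-3, 3, 61)
x2_vals = np.linspace(-3, 3, 61)
X1, X2 = np.meshgrid(x1_vals, x2_vals)  # X1 varies along columns, X2 along rows

# x_3 and x_4 follow from X2, so they have the same 2-D shape
X3 = (X2 >= 0).astype(float)
X4 = (X2 < 0).astype(float)

Z = model(param, X1, X2, X3, X4)
print(Z.shape)
```

To draw it, matplotlib's 3-D axes should accept these arrays directly, for example ax = plt.figure().add_subplot(projection="3d") followed by ax.plot_surface(X1, X2, Z); plotting is left out of the sketch.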

To compute the z-axis values, I tried rewriting the function as:

def function(param, x_1, x_2):
    if x_2 > 0:
        est = param[0] + param[1]*(x_1 + x_2)
    else:
        est = param[0] + param[1]*(x_1 + x_2**2)
    return est

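However, this raises "ValueError: The truth value of an array with more than one element is ambiguous." As I understand it, x_2 > 0 gives one boolean per element, and Python cannot tell whether I am asking if all of them are True or only some.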
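The error can be reproduced in a few lines. An if statement needs a single True or False, but a comparison on an array returns one boolean per element; reducing it with .any() or .all() turns it back into one value. The array below is an illustrative assumption.

```python
import numpy as np

x_2 = np.array([-1.0, 0.0, 2.0])
mask = x_2 > 0          # one boolean per element: False, False, True

try:
    if mask:            # `if` calls bool() on the whole array
        pass
except ValueError as err:
    print("ValueError:", err)

# Reducing the mask to a single bool removes the ambiguity
print(mask.any(), mask.all())
```

Neither reduction is what the question needs, though: both still choose one formula for the whole array rather than per element.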

I also tried np.sign(), but it doesn't behave the way I want in this case.
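np.sign returns -1, 0 or 1, so it is not directly a 0/1 indicator, and it maps 0 to 0 rather than to the x_2 >= 0 case. One swapped-in alternative, offered as a sketch rather than the question's intended approach, is np.heaviside(x, h0), which returns 0 for x < 0, 1 for x > 0, and h0 at x == 0:

```python
import numpy as np

x_2 = np.array([-2.0, 0.0, 3.0])

# np.sign gives -1 / 0 / 1 for negative / zero / positive values
s = np.sign(x_2)

# np.heaviside with h0=1.0 gives 1 wherever x_2 >= 0
x_3 = np.heaviside(x_2, 1.0)   # indicator for x_2 >= 0
x_4 = 1.0 - x_3                # indicator for x_2 < 0
print(s, x_3, x_4)
```

The result is the same as the comparison-based (x_2 >= 0).astype(float); np.heaviside just makes the value at 0 explicit.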

Is there a way to switch the formula according to the value of each element in the array, or otherwise avoid computing the z-axis values manually with a for loop?

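If you want to check whether all values are greater than 0, use np.all(). The builtin all() only works on 1-D arrays; on a 2-D mesh grid it raises the same "ambiguous" error.

import numpy as np

def function(param, x_1, x_2):
    if np.all(x_2 > 0):
        est = param[0] + param[1]*(x_1 + x_2)
    else:
        est = param[0] + param[1]*(x_1 + x_2**2)
    return est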
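Note that a whole-array check like this still selects a single formula for the entire array: if even one x_2 value fails the test, every element uses the squared branch. A small sketch to make that visible, using a hypothetical function_all and illustrative inputs:

```python
import numpy as np

def function_all(param, x_1, x_2):
    # One test for the whole array decides which formula every element gets
    if np.all(x_2 > 0):
        return param[0] + param[1] * (x_1 + x_2)
    return param[0] + param[1] * (x_1 + x_2**2)

param = [10, 8]
x_1 = np.array([1, 2, 3])

all_positive = function_all(param, x_1, np.array([1, 4, 10]))  # linear branch everywhere
one_zero = function_all(param, x_1, np.array([0, 4, 10]))       # squared branch everywhere
print(all_positive, one_zero)
```

In the second call the values 4 and 10 are positive but still go through x_2**2, which is why a per-element selection is needed for the surface plot.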

But if you want to apply the test to each value separately, use np.where:

def function (param,x_1,x_2):
    return np.where(x_2 > 0,
                    param[0]+param[1]*(x_1+x_2),
                    param[0]+param[1]*(x_1+x_2**2))
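np.where works element-wise on arrays of any shape, so it can be applied to mesh-grid arrays as well. The sketch below, on an illustrative 5 x 5 grid, checks that it reproduces the original model with x_3 and x_4 as indicators. The test x_2 > 0 here versus x_2 >= 0 in the question only matters at x_2 == 0, where both formulas reduce to param[0] + param[1]*x_1.

```python
import numpy as np

def model(param, x_1, x_2, x_3, x_4):
    return param[0] + param[1] * (x_1 + x_2 * x_3 + x_2**2 * x_4)

def function(param, x_1, x_2):
    return np.where(x_2 > 0,
                    param[0] + param[1] * (x_1 + x_2),
                    param[0] + param[1] * (x_1 + x_2**2))

param = [10, 8]
X1, X2 = np.meshgrid(np.linspace(-2, 2, 5), np.linspace(-2, 2, 5))

Z_where = function(param, X1, X2)
Z_model = model(param, X1, X2, (X2 >= 0).astype(float), (X2 < 0).astype(float))

# Both approaches give the same surface
print(np.allclose(Z_where, Z_model))
```

np.where evaluates both branch expressions over the full array before selecting, which is harmless here because both formulas are defined for every x_2.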

I think you need numpy.where:

def function (param,x_1,x_2):
    return np.where(x_2>0, 
                    param[0]+param[1]*(x_1+x_2), 
                    param[0]+param[1]*(x_1+x_2**2))

How it works:

import numpy as np

param = [10,8]
x_1 = np.array([1,2,3])
x_2 = np.array([0,4,10])

Where the mask is True, the value is taken from param[0]+param[1]*(x_1+x_2); otherwise it is taken from param[0]+param[1]*(x_1+x_2**2):

print (x_2>0)
[False  True  True]

print (param[0]+param[1]*(x_1+x_2))
[ 18  58 114]

print (param[0]+param[1]*(x_1+x_2**2))
[ 18 154 834]

print (function(param,x_1,x_2))
[ 18  58 114]
