Passing multiple arguments to a function with scipy

The following code needs to be solved for x using a scipy optimizer (scipy.optimize.root, which finds the x where func(x) = 0). The issue is that it works when the function takes a single argument, but when the function takes multiple arguments, it can't handle them.

This code works well.

from scipy.optimize import root
b = 1
def func(x):
    # root searches for the x where result == 0; result also depends on b.
    result = x + b
    return result

sol = root(func, 0.1)
print(sol.x, sol.fun)

But this does not work:

b = [1, 2, 3, 4, 5]

def func(x, b):
    # root searches for the x where result == 0; result also depends on b.
    result = x + b
    return result

for B in b:
    sol = root(lambda x,B: func(x, B), 0.1)
    print(sol.x, sol.fun)

How can I obtain the result for each value of b while iterating through the list?
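Why the lambda version fails: root evaluates the function as fun(x, *args), and args defaults to an empty tuple, so the two-parameter lambda is called with x alone and its B parameter is never filled. The sketch below reproduces the resulting TypeError and shows the one-parameter lambda that captures B instead. It assumes the same func(x, b) as in the question; the helper names question_call and fixed_call are hypothetical, for illustration only.

```python
from scipy.optimize import root


def func(x, b):
    # Same equation as in the question: the root is x = -b.
    return x + b


def question_call(B):
    """Run the question's lambda call and return the TypeError message, if any."""
    try:
        # root evaluates fun(x0, *args) with args=() by default, so this
        # lambda receives x only and its parameter B is never supplied.
        root(lambda x, B: func(x, B), 0.1)
    except TypeError as exc:
        return str(exc)
    return None


def fixed_call(B):
    """Same call with a one-parameter lambda that captures B from the caller."""
    sol = root(lambda x: func(x, B), 0.1)
    return sol.x[0]


print(question_call(1))
for B in [1, 2, 3, 4, 5]:
    print(B, fixed_call(B))
```

The lambda is called immediately inside each root call, so capturing the loop variable this way does not run into Python's late-binding pitfall.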

As @hpaulj mentioned, root accepts an args parameter whose contents are passed on to func as extra positional arguments. So, we can make the script more flexible:

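from scipy.optimize import root

def func(x, *args):
    # Evaluate the polynomial args[0] + args[1]*x + args[2]*x**2 + ...
    result = 0
    for i, a in enumerate(args):
        result += a * x ** i
    return result

# Coefficients in ascending powers of x: (c0, c1[, c2])
coeff_list = [(6, 3), (-3, 2, 1), (-6, 1, 2)]


for coeffs in coeff_list:
    # One starting guess per expected root (degree = len(coeffs) - 1).
    sol = root(func, [-4, 4][:len(coeffs)-1], args=coeffs)
    print(*coeffs, sol.x, sol.fun)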

Output:

6 3 [-2.] [8.8817842e-16]
-3 2 1 [-3.  1.] [ 1.46966528e-09 -4.00870892e-10]
-6 1 2 [-2.   1.5] [-6.83897383e-14  4.97379915e-14]
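Applied directly to the question's two-argument func(x, b), the same args mechanism removes the need for both the lambda and a global b. A minimal sketch, assuming the question's equation x + b = 0 and its list of b values; the list name solutions is only for illustration:

```python
from scipy.optimize import root


def func(x, b):
    # Root of x + b is x = -b; b arrives through root's args tuple.
    return x + b


solutions = []
for B in [1, 2, 3, 4, 5]:
    # args is a tuple; (B,) makes root call func(x, B).
    sol = root(func, 0.1, args=(B,))
    solutions.append(sol.x[0])
    print(B, sol.x, sol.fun)
```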

Initial answer

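I don't see the need for your lambda function: func can read b from the global scope, and the loop rebinds b on each iteration, so each call to root sees the current value: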

from scipy.optimize import root

def func(x):
    # b is not a parameter: it is read from the global scope at call time.
    result = x + b
    return result

B =[ 1, 2, 3, 4, 5]

for b in B:
    sol = root(func, 0.1)
    print(b, sol.x, sol.fun)

Output:

1 [-1.] [0.]
2 [-2.] [0.]
3 [-3.] [0.]
4 [-4.] [0.]
5 [-5.] [0.]
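Reading b from the global scope works, but it ties func to a module-level name. If you would rather keep b as an explicit parameter without a lambda, functools.partial can bind it on each iteration. This is an alternative sketch, not the approach used in this answer; the results dict is a hypothetical name for collecting the roots.

```python
from functools import partial

from scipy.optimize import root


def func(x, b):
    # Same equation as in the question, with b as an explicit parameter.
    return x + b


results = {}
for b in [1, 2, 3, 4, 5]:
    # partial(func, b=b) is a one-argument callable f(x) -> func(x, b=b),
    # so root can call it with x alone and no global b is involved.
    sol = root(partial(func, b=b), 0.1)
    results[b] = sol.x[0]
    print(b, sol.x, sol.fun)
```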

I initially didn't spot how the scipy documentation passes parameters to func (it is the args parameter used in the update above). But this global-variable approach also works for multiple parameters:

from scipy.optimize import root

def func(x):
    # a, b, c are read from the global scope; depending on their values,
    # this has 0, 1 or 2 real solutions.
    result = a * x ** 2 + b * x + c
    return result

A = range(3)
B = [3, 2, 1]
C = [6, -3, -6]


for a, b, c in zip(A, B, C):
    sol = root(func, [-4, 4])
    print(a, b, c, sol.x, sol.fun)

Output:

0 3  6 [-2. -2.]  [ 8.8817842e-16   0.0000000e+00]
1 2 -3 [-3.  1.]  [ 1.46966528e-09 -4.00870892e-10]
2 1 -6 [-2.  1.5] [-6.83897383e-14  4.97379915e-14]
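Because these particular equations are polynomials, the roots found by root can be cross-checked with numpy.roots, which computes all roots directly from the coefficients. This is a swapped-in verification technique that only applies to polynomials; root itself works for general functions. The params list and the exact dict are hypothetical names for this sketch.

```python
import numpy as np

# (a, b, c) for a*x**2 + b*x + c, the same triples the loop above uses.
params = [(0, 3, 6), (1, 2, -3), (2, 1, -6)]

exact = {}
for a, b, c in params:
    # np.roots expects coefficients from the highest power down and drops
    # leading zeros, so a == 0 gives the single root of the linear equation.
    exact[(a, b, c)] = np.sort(np.roots([a, b, c]))
    print(a, b, c, exact[(a, b, c)])
```

For the linear case (a = 0) there is only one root, which is why both starting guesses passed to root converge to -2 in the output above.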
