[英]How to use scipy least_squares to get the estimation of unknow variables
我是使用 scipy.optimize 的新手。 我有以下函数调用 func。 我将 x 和 y 值作为列表给出,需要得到 a、b 和 c 的估计值。 我可以使用curve_fit 来估计a、b 和c。 但是,我想探索使用 minimum_squares 的可能性。 当我运行以下代码时,出现以下错误。 如果有人能指出我正确的方向,那就太好了。
import numpy as np
from scipy.optimize import curve_fit
from scipy.optimize import least_squares
np.random.seed(0)
x = np.random.randint(0, 100, 100) # sample dataset for independent variables
y = np.random.randint(0, 100, 100) # sample dataset for dependent variables
def func(x,a,b,c):
return a*x**2 + b*x + c
def result(list_x, list_y):
popt = curve_fit(func, list_x, list_y)
sol = least_squares(result,x, args=(y,),method='lm',jac='2-point',max_nfev=2000)
TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be
safely coerced to any supported types according to the casting rule ''safe''
要使用least_squares
您需要一个残差函数而不是curve_fit
。 此外, least_squares
需要猜测您要拟合的参数(即 a、b、c)。 在你的情况下,如果你想使用least_squares
你可以写一些类似的东西(我只是用随机值来猜测)
import numpy as np
from scipy.optimize import least_squares
np.random.seed(0)
x = np.random.randint(0, 100, 100) # sample dataset for independent variables
y = np.random.randint(0, 100, 100) # sample dataset for dependent variables
def func(x,a,b,c):
return a*x**2 + b*x + c
def residual(p, x, y):
return y - func(x, *p)
guess = np.random.rand(3)
sol = least_squares(residual, guess, args=(x, y,),method='lm',jac='2-point',max_nfev=2000)
以下代码使用 least_squares() 例程进行优化。 与您的代码相比,最重要的变化是确保 func() 返回残差向量。 我还将解决方案与线性代数结果进行了比较以确保正确性。
import numpy as np
from scipy.optimize import curve_fit
from scipy.optimize import least_squares
np.random.seed(0)
x = np.random.randint(0, 100, 100) # sample dataset for independent variables
y = np.random.randint(0, 100, 100) # sample dataset for dependent variables
def func(theta, x, y):
# Return residual = fit-observed
return (theta[0]*x**2 + theta[1]*x + theta[2]) - y
# Initial parameter guess
theta0 = np.array([0.5, -0.1, 0.3])
# Compute solution providing initial guess theta0, x input, and y input
sol = least_squares(func, theta0, args=(x,y))
print(sol.x)
#------------------- OPTIONAL -------------------#
# Compare to linear algebra solution
temp = x.reshape((100,1))
X = np.hstack( (temp**2, temp, np.ones((100,1))) )
OLS = np.linalg.lstsq(X, y.reshape((100,1)), rcond=None)
print(OLS[0])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.