![](/img/trans.png)
[英]Non linear Least Squares: Reproducing Matlabs lsqnonlin with Scipy.optimize.least_squares using Levenberg-Marquardt
[英]Program to fit a hyperbola to linear data using least squares (Levenberg-Marquardt algorithm) not working as expected
我有一个1D数组数据,我试图使用三个参数建模为双曲线。 我正在尝试使用scipy.optimize库中的leastsq函数实现Levenberg Marquardt算法。 但是,我的程序陷入了一个迭代,其中一个数字除以零,我不明白为什么。
一些背景:1D阵列数据基本上是不同盒子尺寸的空隙值。 我已经从一些声音文件生成了空隙数据,其中的上下文可以在这里找到。
在算法中,最小二乘函数需要三个输入:
(a)对三个参数的初步猜测
(b)最小二乘问题的x坐标 - 在我的问题中基本上是从1到100的1D整数数组
(c)最小二乘问题的y坐标 - 这是存储空隙度值的一维数组。 因此,空隙度值是x的函数,其中x从1到100变化。
双曲线使用三个参数a,b和c建模
该代码给出以下错误:
“OverflowError:无法将浮点无穷大转换为整数”
编码:
#import
from scipy import *
from scipy.optimize import leastsq
import matplotlib.pylab as plt
import numpy as np
import codecs, json
from math import *
# Define your function to calculate the residuals.
#The fitting function holds your parameter values.
def residuals(p, y, x):
err = y-pval(x,p)
return err
def pval(x, p):
z = x
for i in range(100):
print(x)
print(x[i]**p[1])
z[i] = p[0]/(x[i]**p[1])+p[2]
return z
#read in your data
obj_text = codecs.open('textfiles\CC1.json', 'r', encoding='utf-8').read()
b_new = json.loads(obj_text)
data = np.array(b_new)
x = np.arange(1,101)
y = data[1:101]
#guess at initial parameters
A1_0=1.0
A2_0=1.0
A3_0=0.5
#leastsq package calls the Levenberg-Marquardt algorithm
pname = (['A1','A2','A3'])
p0 = array([A1_0 , A2_0, A3_0])
plsq = leastsq(residuals, p0, args=(y, x), maxfev=2000)
# Now, plot your data
plt.plot(x,y,'xo',x,pval(x,plsq[0]),'x')
title('Least-squares fit to data')
xlabel('x')
ylabel('y')
legend(['Data', 'Fit'],loc=4)
# Your best-fit paramters are kept within plsq[0].
print(plsq[0])
根据该错误,x的值在迭代中的某个点处变为0,并且第一个参数a最终被除以零,这给出了错误。
要进行故障排除,我在执行代码时打印了值x [i] ^ b和数组x, 您可以在此处查看值 。 我看到数组x正在被修改,这不应该发生。 x应该保持从1到100的1D自然数组,并且不会在迭代中被修改。 我无法确定修改数组x的代码究竟在哪里。
我希望数组x保持不变,并且代码打印参数a,b和c的最后三个值。
编辑:我对我的代码进行了一些更改,之后它成功运行。 以下是任何人都会感兴趣的编辑:
没有将z定义为z = x,而是将其定义为z = np.arange(1,101)。 结果是数组x不再改变,这是预期的。
将数组x和y的数据类型更改为float
x = np.array(x, dtype=np.float64)
plt.plot(x,y,'red',x,pval(x,plsq[0]),'blue')
plt.show()
不是你的问题的直接答案,但由于你正在使用取幂( **
),我强烈建议你事先将你的所有数字转换为Decimal
,以避免浮点运算对大值的固有精度损失。
例如:
import decimal
decimal.getcontext().prec = 100
A1_0=Decimal("1.0")
A2_0=Decimal("1.0")
A3_0=Decimal("0.5")
x = [Decimal(f) for f in x]
y = [Decimal(f) for f in y]
也许你的零点会“变成”一个接近于零的小值......
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.