![](/img/trans.png)
[英]Non linear Least Squares: Reproducing Matlabs lsqnonlin with Scipy.optimize.least_squares using Levenberg-Marquardt
[英]Program to fit a hyperbola to linear data using least squares (Levenberg-Marquardt algorithm) not working as expected
我有一個1D數組數據,我試圖使用三個參數建模為雙曲線。 我正在嘗試使用scipy.optimize庫中的leastsq函數實現Levenberg Marquardt算法。 但是,我的程序陷入了一個迭代,其中一個數字除以零,我不明白為什么。
一些背景:1D陣列數據基本上是不同盒子尺寸的空隙值。 我已經從一些聲音文件生成了空隙數據,其中的上下文可以在這里找到。
在算法中,最小二乘函數需要三個輸入:
(a)對三個參數的初步猜測
(b)最小二乘問題的x坐標 - 在我的問題中基本上是從1到100的1D整數數組
(c)最小二乘問題的y坐標 - 這是存儲空隙度值的一維數組。 因此,空隙度值是x的函數,其中x從1到100變化。
雙曲線使用三個參數a,b和c建模
該代碼給出以下錯誤:
“OverflowError:無法將浮點無窮大轉換為整數”
編碼:
#import
from scipy import *
from scipy.optimize import leastsq
import matplotlib.pylab as plt
import numpy as np
import codecs, json
from math import *
# Define your function to calculate the residuals.
#The fitting function holds your parameter values.
def residuals(p, y, x):
err = y-pval(x,p)
return err
def pval(x, p):
z = x
for i in range(100):
print(x)
print(x[i]**p[1])
z[i] = p[0]/(x[i]**p[1])+p[2]
return z
#read in your data
obj_text = codecs.open('textfiles\CC1.json', 'r', encoding='utf-8').read()
b_new = json.loads(obj_text)
data = np.array(b_new)
x = np.arange(1,101)
y = data[1:101]
#guess at initial parameters
A1_0=1.0
A2_0=1.0
A3_0=0.5
#leastsq package calls the Levenberg-Marquardt algorithm
pname = (['A1','A2','A3'])
p0 = array([A1_0 , A2_0, A3_0])
plsq = leastsq(residuals, p0, args=(y, x), maxfev=2000)
# Now, plot your data
plt.plot(x,y,'xo',x,pval(x,plsq[0]),'x')
title('Least-squares fit to data')
xlabel('x')
ylabel('y')
legend(['Data', 'Fit'],loc=4)
# Your best-fit paramters are kept within plsq[0].
print(plsq[0])
根據該錯誤,x的值在迭代中的某個點處變為0,並且第一個參數a最終被除以零,這給出了錯誤。
要進行故障排除,我在執行代碼時打印了值x [i] ^ b和數組x, 您可以在此處查看值 。 我看到數組x正在被修改,這不應該發生。 x應該保持從1到100的1D自然數組,並且不會在迭代中被修改。 我無法確定修改數組x的代碼究竟在哪里。
我希望數組x保持不變,並且代碼打印參數a,b和c的最后三個值。
編輯:我對我的代碼進行了一些更改,之后它成功運行。 以下是任何人都會感興趣的編輯:
沒有將z定義為z = x,而是將其定義為z = np.arange(1,101)。 結果是數組x不再改變,這是預期的。
將數組x和y的數據類型更改為float
x = np.array(x, dtype=np.float64)
plt.plot(x,y,'red',x,pval(x,plsq[0]),'blue')
plt.show()
不是你的問題的直接答案,但由於你正在使用取冪( **
),我強烈建議你事先將你的所有數字轉換為Decimal
,以避免浮點運算對大值的固有精度損失。
例如:
import decimal
decimal.getcontext().prec = 100
A1_0=Decimal("1.0")
A2_0=Decimal("1.0")
A3_0=Decimal("0.5")
x = [Decimal(f) for f in x]
y = [Decimal(f) for f in y]
也許你的零點會“變成”一個接近於零的小值......
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.