[英]Bounds in scipy curve_fit
我正在嘗試擬合兩分量高斯擬合:
mu0 = sum(velo_peak * spec_peak) / sum(spec_peak)
sigma = np.sqrt(sum(spec_peak * (velo_peak - mu0)**2) / sum(spec_peak))
def Gauss(velo_peak, a, mu0, sigma):
res = a * np.exp(-(velo_peak - mu0)**2 / (2 * sigma**2))
return res
p0 = [max(spec_peak) - RMS, mu0, sigma] # a = max(spec_peak)
popt,pcov = curve_fit(Gauss, velo_peak, spec_peak, p0,maxfev=10000, bounds=((0, 0, +np.inf, +np.inf), (0, 0, +np.inf, +np.inf)))
#____________________two component gaussian fit_______________________#
def double_gaussian(velo_peak,a1, mu1, sigma1, a2, mu2, sigma2):
res_two = a1 * np.exp(-(velo_peak - mu1)**2/(2 * sigma1**2)) \
+ a2 * np.exp(-(velo_peak - mu2)**2/(2 * sigma2**2))
return res_two
##_____________________Initial guess values__________________________##
sigma1 = 0.7 * sigma
sigma2 = 0.7 * sigma
mu1 = mu0 + sigma
mu2 = mu0 - sigma
a1 = 3
a2 = 1
guess = [a1, mu1, sigma1, a2, mu2, sigma2]
popt_2,pcov_2 = curve_fit(double_gaussian, velo_peak, spec_peak, guess,maxfev=10000, bounds=((0, 0, +np.inf, +np.inf), (0, 0, +np.inf, +np.inf)))
但是我得到了一個我想避免的負面部分,但我不知道如何正確實現邊界,因為我不太了解文檔。 我收到以下錯誤:
ValueError: Inconsistent shapes between bounds and `x0`.
誰能指導我如何正確使用邊界?
它期待"2-tuple of array_like, optional"
,因此看起來像:
((lower_bound0, lower_bound1, ..., lower_boundn), (upper_bound0, upper_bound1, ..., upper_boundn))
在我看來,如果你想避免負值,那么在雙高斯中你想將a1
和a2
約束為正值。
根據您的guess
:
[a1, mu1, sigma1, a2, mu2, sigma2]
那將是:
... bounds=[(0, -np.inf, -np.inf, 0, -np.inf, -np.inf), (np.inf, np.inf, np.inf, np.inf, np.inf, np.inf)], ...
演示:
import matplotlib.pyplot as plt
def double_gaussian(velo_peak,a1, mu1, sigma1, a2, mu2, sigma2):
res_two = a1 * np.exp(-(velo_peak - mu1)**2/(2 * sigma1**2)) \
+ a2 * np.exp(-(velo_peak - mu2)**2/(2 * sigma2**2))
return res_two
x = np.linspace(0, 10, 1000)
y = double_gaussian(x, 1, 3, 1, 1, 7, 0.5) + 0.4*(np.random.random(x.shape) - 0.5)
popt, _ = curve_fit(double_gaussian, x, y, bounds=[(0, -np.inf, -np.inf, 0, -np.inf, -np.inf), (np.inf, np.inf, np.inf, np.inf, np.inf, np.inf)])
plt.plot(x, y)
plt.plot(x, double_gaussian(x, *popt))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.