![](/img/trans.png)
[英]gaussian fit with scipy.optimize.curve_fit in python with wrong results
[英]scipy.optimize.curve_fit unable to fit shifted skewed gaussian curve
我正在尝试使用scipy的curve_fit函数拟合偏斜和偏移的高斯曲线,但是我发现在某些条件下拟合度很差,通常使我接近或恰好是一条直线。
以下代码来自curve_fit
文档。 提供的代码是用于测试目的的任意数据集,但显示得很好。
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import math as math
import scipy.special as sp
#def func(x, a, b, c):
# return a*np.exp(-b*x) + c
def func(x, sigmag, mu, alpha, c,a):
#normal distribution
normpdf = (1/(sigmag*np.sqrt(2*math.pi)))*np.exp(-(np.power((x-mu),2)/(2*np.power(sigmag,2))))
normcdf = (0.5*(1+sp.erf((alpha*((x-mu)/sigmag))/(np.sqrt(2)))))
return 2*a*normpdf*normcdf + c
x = np.linspace(0,100,100)
y = func(x, 10,30, 0,0,1)
yn = y + 0.001*np.random.normal(size=len(x))
popt, pcov = curve_fit(func, x, yn,) #p0=(9,35,0,9,1))
y_fit= func(x,popt[0],popt[1],popt[2],popt[3],popt[4])
plt.plot(x,yn)
plt.plot(x,y_fit)
当我将高斯距离零(使用mu
)移得太远时,该问题似乎弹出。 我尝试给出初始值,甚至那些与原始函数相同的值,但它不能解决问题。 对于mu=10
的值, curve_fit
可以完美地工作,但是如果我使用mu>=30
它将不再适合数据。
提供最小化的起点通常会产生奇迹。 尝试给最小化器一些关于最大值位置和曲线宽度的信息:
popt, pcov = curve_fit(func, x, yn, p0=(1./np.std(yn), np.argmax(yn) ,0,0,1))
用sigma=10
和mu=50
更改代码中的这一行会产生
您可以通过随机初始猜测多次调用curve_fit
,并选择误差最小的参数。
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import math as math
import scipy.special as sp
def func(x, sigmag, mu, alpha, c,a):
#normal distribution
normpdf = (1/(sigmag*np.sqrt(2*math.pi)))*np.exp(-(np.power((x-mu),2)/(2*np.power(sigmag,2))))
normcdf = (0.5*(1+sp.erf((alpha*((x-mu)/sigmag))/(np.sqrt(2)))))
return 2*a*normpdf*normcdf + c
x = np.linspace(0,100,100)
y = func(x, 10,30, 0,0,1)
yn = y + 0.001*np.random.normal(size=len(x))
results = []
for i in xrange(50):
p = np.random.randn(5)*10
try:
popt, pcov = curve_fit(func, x, yn, p)
except:
pass
err = np.sum(np.abs(func(x, *popt) - yn))
results.append((err, popt))
if err < 0.1:
break
err, popt = min(results, key=lambda x:x[0])
y_fit= func(x, *popt)
plt.plot(x,yn)
plt.plot(x,y_fit)
print len(results)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.