简体   繁体   English

scipy curve_fit 无法拟合高帽函数

[英]scipy curve_fit cannot fit a tophat function

I am trying to fit a top hat function to some data, ie.我正在尝试将顶帽函数拟合到某些数据,即。 f(x) is constant for the entire real line, except for one segment of finite length which is equal to another constant. f(x) 对于整条实线是常数,除了一个有限长度的线段等于另一个常数。 My parameters are the two constants of the tophat function, the midpoint, and the width and I'm trying to use scipy.optimize.curve_fit to get all of these.我的参数是 tophat 函数的两个常量、中点和宽度,我正在尝试使用 scipy.optimize.curve_fit 来获得所有这些。 Unfortunately, curve_fit is having trouble obtaining the width of the hat.不幸的是,curve_fit 无法获得帽子的宽度。 No matter what I do, it refuses to test any value of the width other than the one I start with, and fits the rest of the data very badly.无论我做什么,它都拒绝测试除我开始使用的宽度值之外的任何宽度值,并且非常不适合其余数据。 The following code snippet illustrates the problem:下面的代码片段说明了这个问题:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    ret=[]
    for xx in x:
        if hat_mid-hat_width/2. < xx < hat_mid+hat_width/2.:
            ret.append(hat_level)
        else:
            ret.append(base_level)
    return np.array(ret)

x = np.arange(-10., 10., 0.01)
y = tophat(x, 1.0, 5.0, 0.0, 1.0)+np.random.rand(len(x))*0.2-0.1

guesses = [ [1.0, 5.0, 0.0, 1.0],
            [1.0, 5.0, 0.0, 0.1],
            [1.0, 5.0, 0.0, 2.0] ]

plt.plot(x,y)

for guess in guesses:
    popt, pcov = curve_fit( tophat, x, y, p0=guess )
    print popt
    plt.plot( x, tophat(x, popt[0], popt[1], popt[2], popt[3]) )

plt.show()

Why is curve_fit so extremely terrible at getting this right, and how can I fix it?为什么curve_fit 在正确处理这个问题上非常糟糕,我该如何解决?

First, the definition of tophat could use numpy.where instead of a loop:首先, tophat的定义可以使用numpy.where而不是循环:

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    return np.where((hat_mid-hat_width/2. < x) & (x < hat_mid+hat_width/2.), hat_level, base_level)

Second, the tricky discontinuous objective function resists the optimization algorithms that curve_fit calls.其次,棘手的不连续目标函数抵抗了curve_fit调用的优化算法。 The Nelder-Mead method is usually preferable for rough functions, but it looks like curve_fit cannot use it. Nelder-Mead 方法通常更适合用于粗糙函数,但看起来curve_fit不能使用它。 So I set up an objective function (just the sum of absolute values of deviations) and minimize that:所以我设置了一个目标函数(只是偏差的绝对值的总和)并将其最小化:

def objective(params, x, y):
    return np.sum(np.abs(tophat(x, *params) - y))

plt.plot(x,y)

for guess in guesses:
    res = minimize(objective, guess, args=(x, y), method='Nelder-Mead')
    print(res.x)
    plt.plot(x, tophat(x, *(res.x)))

The results are better, in that starting with a too-wide hat of width 2 makes it shrink down to the correct size (see the last of three guesses).结果更好,因为从宽度为 2 的太宽帽子开始使其缩小到正确的大小(请参阅三个猜测中的最后一个)。

[9.96041297e-01 5.00035502e+00 2.39462103e-04 9.99759984e-01]
[ 1.00115808e+00  4.94088711e+00 -2.21340843e-05  1.04924153e-01]
[9.95947108e-01 4.99871040e+00 1.26575116e-03 9.97908018e-01]

Unfortunately, when the starting guess is a too-narrow hat, the optimizer is still stuck.不幸的是,当起始猜测太窄时,优化器仍然卡住了。

拟合函数

You can try other optimization method / objective function combinations but I haven't found one that makes the hat reliably expand.您可以尝试其他优化方法/目标函数组合,但我还没有找到使帽子可靠扩展的方法。

One thing to try is not to use the parameters that are too close to the true levels;要尝试的一件事是不要使用接近真实水平的参数; this sometimes might hurt.这有时可能会造成伤害。 With

guesses = [ [1.0, 1.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 0.1],
            [1.0, 1.0, 0.0, 2.0] ]

I once managed to get我曾经设法得到

[ 1.00131181  4.99156649 -0.01109271  0.96822019]
[ 1.00137925  4.97879423 -0.05091561  1.096166  ]
[ 1.00130568  4.98679988 -0.01133717  0.99339777]

which is correct for all three widths.这对于所有三个宽度都是正确的。 However, this was only on some of several tries (there is some randomness in the initialization of the optimizing procedure).然而,这只是几次尝试中的一些(优化过程的初始化存在一些随机性)。 Some other attempts with the same initial points failed;其他一些具有相同初始点的尝试失败了; the process is not robust enough.该过程不够稳健。

By its nature, non-linear least-squares fitting as with curve_fit() works with real, floating-point numbers and is not good at dealing with discrete variables.就其性质而言,与curve_fit()一样的非线性最小二乘curve_fit()合适用于实数、浮点数,并且不擅长处理离散变量。 In the fit process, small changes (like, at the 1e-7 level) are made to each variable, and the effect of that small change on the fit result is used to decide how to change that variable to improve the fit.在拟合过程中,对每个变量进行小的更改(例如,在 1e-7 级别),并且使用该小的更改对拟合结果的影响来决定如何更改该变量以改进拟合。 With discretely sampled data, small changes to your hat_mid and/or hat_width could easily be smaller than the spacing of data points and so have no effect at all on the fit.对于离散采样的数据,您的hat_mid和/或hat_width微小变化很容易小于数据点的间距,因此对拟合完全没有影响。 That is why curve_fit is "extremely terrible" at this problem.这就是为什么curve_fit在这个问题上“非常糟糕”。

You may find that giving a finite width (that is, comparable to the step size of your discrete data) to the steps helps to better find where the edges of you hat are.您可能会发现,为步骤提供有限宽度(即,与离散数据的步长相当)有助于更好地找到您帽子的边缘所在的位置。

You could also try fitting to f = A0 (-Erf[A4*(u - A1) - A3] + Erf[A4*(u + A1) - A3]) + A2您也可以尝试拟合 f = A0 (-Erf[A4*(u - A1) - A3] + Erf[A4*(u + A1) - A3]) + A2

Here A0 will be proportional to step height, A1 will be proportional to step width, A2 will be a vertical offset of the baseline, A3 should be the horizontal distance of the middle of the step from 0 and the slope of the step will be proportional to A4.这里A0将与步高成正比,A1将与步宽成正比,A2将是基线的垂直偏移量,A3应该是步中间距0的水平距离,步的斜率将成正比到 A4。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM