
Scipy curve_fit fails for data with sine function

I'm trying to fit a curve through some data. The function I'm trying to fit is as follows:

def f(x,a,b,c):
    return a +b*x**c

When I use scipy.optimize.curve_fit, I don't get a usable result: it simply returns the (default) initial parameters together with an infinite covariance matrix:

(array([ 1.,  1.,  1.]),
 array([[ inf,  inf,  inf],
        [ inf,  inf,  inf],
        [ inf,  inf,  inf]]))

I tried to reproduce the problem with synthetic data and found that a sine term is what triggers it (my real data contains a daily variation):

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

xdata=np.random.rand(1000) + 0.002 *np.sin(np.arange(1000)/(1.5*np.pi))
ydata=0.1 + 23.4*xdata**0.56 + np.random.normal(0,2,1000)

def f(x,a,b,c):
    return a +b*x**c

fit=curve_fit(f,xdata,ydata)

fig,ax=plt.subplots(1,1)
ax.plot(xdata,ydata,'k.',markersize=3)
ax.plot(np.arange(0,1,.01), f(np.arange(0,1,.01),*fit[0]))
fig.show()

I would expect curve_fit to return something close to [0.1, 23.4, 0.56].

Note that the sine term hardly changes the values of xdata: the first term ranges between 0 and 1, and the sine only adds something between -0.002 and +0.002. Yet it does make the fit fail. An amplitude of 0.002 is close to the 'critical' value: with a smaller amplitude the fit fails less often, with a larger one more often, and at 0.002 it fails in roughly half of the runs.
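A quick check suggests why such a small shift matters. np.random.rand draws from [0, 1), so a negative shift can push a value that was already very close to 0 below zero, and NumPy evaluates a negative float raised to the non-integer power 0.56 as NaN. The sketch below (assuming a seeded np.random.default_rng(0) instead of np.random.rand, purely for reproducibility) counts the negative x-values and NaN y-values, and computes how many negative points to expect per run from the sine amplitude alone.

```python
import numpy as np

# A negative base with a non-integer exponent gives NaN for NumPy floats
with np.errstate(invalid="ignore"):  # silence the "invalid value" RuntimeWarning
    demo = np.array([-0.001, 0.5]) ** 0.56
print(demo)  # first entry is nan

# Same synthetic data as in the question, but with a seeded generator (assumption)
rng = np.random.default_rng(0)
n = 1000
shift = 0.002 * np.sin(np.arange(n) / (1.5 * np.pi))
xdata = rng.random(n) + shift
with np.errstate(invalid="ignore"):
    ydata = 0.1 + 23.4 * xdata**0.56 + rng.normal(0, 2, n)

print("negative x-values:", np.sum(xdata < 0))
print("NaN y-values:     ", np.sum(np.isnan(ydata)))

# For U ~ Uniform[0, 1): P(U + shift_i < 0) = max(0, -shift_i)
p_neg = np.clip(-shift, 0.0, None)
expected_neg = p_neg.sum()             # expected number of negative x-values per run
p_any = 1.0 - np.prod(1.0 - p_neg)     # probability of at least one negative x-value
print(f"expected negatives per run: {expected_neg:.2f}")
print(f"P(at least one negative):   {p_any:.2f}")
```

With the question's amplitude of 0.002 this predicts roughly 0.6 negative points per run and a probability of a little under one half that a run contains at least one, which is consistent with the observation that the fit fails about as often as not.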

I have tried to solve this problem by shuffling 'xdata' and 'ydata' together, without effect. I thought (for no particular reason) that removing the autocorrelation of the data might help.

So my question is: how can I fix/bypass this problem? I can change the sine contribution in the synthetic data in the snippet above, but for my real data I obviously cannot.

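The sine term can push a few x-values slightly below zero, and a negative number raised to a non-integer power c is NaN in NumPy, which breaks the fit. You can eliminate the NaNs generated by these negative x-values within the model function: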

import numpy as np

def f(x, a, b, c):
    y = a + b*x**c           # NaN wherever x < 0 and c is not an integer
    y[np.isnan(y)] = 0.0     # replace those NaNs by 0
    return y

Replacing every NaN by 0 is not necessarily the best choice. You could instead use the values of neighbouring points or some form of extrapolation.
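One way to try the neighbour-value idea is sketched below. The function name f_neighbour is hypothetical and this is only one possible approach: NaN outputs are filled with np.interp over the valid points, sorted by x. Note that np.interp does not extrapolate; outside the range of the valid points it repeats the edge value.

```python
import numpy as np

def f_neighbour(x, a, b, c):
    """Hypothetical variant of f: fill NaNs from neighbouring valid points."""
    x = np.asarray(x, dtype=float)
    with np.errstate(invalid="ignore"):  # x**c is NaN for x < 0 and non-integer c
        y = a + b * x**c
    bad = np.isnan(y)
    if bad.any() and not bad.all():
        # np.interp needs the sample points in increasing order
        order = np.argsort(x[~bad])
        x_ok = x[~bad][order]
        y_ok = y[~bad][order]
        # Linear interpolation between valid neighbours; left of the smallest
        # valid x, np.interp returns that point's value (constant, not extrapolated)
        y[bad] = np.interp(x[bad], x_ok, y_ok)
    return y

x = np.array([0.5, -0.001, 0.25, 1.0])
y = f_neighbour(x, 0.0, 1.0, 0.5)
print(y)  # approximately [0.7071, 0.5, 0.5, 1.0]
```

Since in this problem every NaN comes from a negative x while all valid points have x >= 0, each such point effectively receives the model value at the smallest non-negative x. The function has the same signature as f, so it can be passed to curve_fit in its place.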

If you feed in generated test data, you also have to make sure that ydata contains no NaNs, because curve_fit cannot handle NaNs in the data either. So directly after generating the data, add something like:

if xdata.min() < 0:
    print('expecting NaNs')
    ydata[np.isnan(ydata)] = 0.0
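Putting the model-side and data-side fixes together gives the following end-to-end sketch. It reuses the question's synthetic data with a seeded generator (np.random.default_rng(42) is an assumption, used only for reproducibility) and wraps the power in np.errstate to silence NumPy's warning; otherwise it is the f and the NaN clean-up from this answer.

```python
import numpy as np
from scipy.optimize import curve_fit

def f(x, a, b, c):
    with np.errstate(invalid="ignore"):  # NaN for x < 0 and non-integer c
        y = a + b * x**c
    y[np.isnan(y)] = 0.0                 # the answer's replacement value
    return y

# Question's synthetic data, seeded for reproducibility (assumption)
rng = np.random.default_rng(42)
n = 1000
xdata = rng.random(n) + 0.002 * np.sin(np.arange(n) / (1.5 * np.pi))
with np.errstate(invalid="ignore"):
    ydata = 0.1 + 23.4 * xdata**0.56 + rng.normal(0, 2, n)

# Data-side clean-up: negative x-values produced NaN y-values
if xdata.min() < 0:
    print("expecting NaNs")
    ydata[np.isnan(ydata)] = 0.0

popt, pcov = curve_fit(f, xdata, ydata)  # default p0 = [1, 1, 1], as in the question
print(popt)                              # should land near [0.1, 23.4, 0.56]
```

At the cleaned points both the data and the model are 0, so they contribute zero residual and the remaining points determine the fit. Whether that is acceptable for real data is a separate judgement.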


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM