SciPy Curve_fit() doesn't fit curve

I made a random graph, and tried to use SciPy curve_fit to fit the best curve to the plot, but it fails.

First, I generated a random exponential decay graph, where A, w, T2 are randomly generated using numpy:

def expDec(t, A, w, T2):
    return A * np.cos(w * t) * (2.718**(-t / T2))

Now I have SciPy guess the best fit curve:

t = x['Input'].values
hr = x['Output'].values
c, cov = curve_fit(bpm, t, hr)

Then I plot the curve

for i in range(n):
    y[i] = bpm(x['Input'][i], c[0], c[1], c[2])
plt.plot(x['Input'], x['Output'])
plt.plot(x['Input'], y)

That's it. Here's how bad the fit looks:

[Figure: the data with the fitted curve plotted over it]

If anyone can help, that would be great.

MWE:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

inputs = []
outputs = []

# THIS GIVES THE DOMAIN
dom = np.linspace(-5, 5, 100)

# FUNCTION & PARAMETERS (RANDOMLY SELECTED)
A = np.random.uniform(3, 6)
w = np.random.uniform(3, 6)
T2 = np.random.uniform(3, 6)
y = A * np.cos(w * dom) * (2.718**(-dom / T2))

# DEFINES EXPONENTIAL DECAY FUNCTION
def expDec(t, A, w, T2):
    return A * np.cos(w * t) * (2.718**(-t / T2))

# SETS UP FIGURE FOR PLOTTING
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# PLOTS THE FUNCTION
plt.plot(dom, y, 'r')

# SHOW THE PLOT
plt.show()

for i in range(-9, 10): 
    inputs.append(i)
    outputs.append(expDec(i, A, w, T2))
    
# PUT IT DIRECTLY IN A PANDAS DATAFRAME
points = {'Input': inputs, 'Output': outputs}

x = pd.DataFrame(points, columns = ['Input', 'Output'])
   
# FUNCTION WHOSE PARAMETERS PROGRAM SHOULD BE GUESSING
def bpm(t, A, w, T2):
    return A * np.cos(w * t) * (2.718**(-t / T2))

# INPUT & OUTPUTS
t = x['Input'].values
hr = x['Output'].values

# USE SCIPY CURVE FIT (NONLINEAR LEAST SQUARES) TO FIND BEST PARAMETERS. ALLOW UP TO 1000 FUNCTION EVALUATIONS BEFORE STOPPING.
constants = curve_fit(bpm, t, hr, maxfev=1000)

# GET CONSTANTS FROM CURVE_FIT
A_fit = constants[0][0]
w_fit = constants[0][1]
T2_fit = constants[0][2]

# CREATE ARRAY TO HOLD FITTED OUTPUT
fit = []

# APPEND OUTPUT TO FIT=[] ARRAY
for i in range(-9,10):
    fit.append(bpm(i, A_fit, w_fit, T2_fit))
    
# PLOTS BEST PARAMETERS
plt.plot(x['Input'], x['Output'])
plt.plot(x['Input'], fit, "ro-")

As a first step, I would like to rewrite your MCVE to use vectorized operations and only a single instance of the function computation. This will reduce everything to just a couple of lines. I recommend using a seed for repeatability when you do your testing as well:

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

def exp_dec(t, A, w, T2):
    return A * np.cos(w * t) * np.exp(-t / T2)

np.random.seed(42)
A, w, T2 = np.random.uniform(3, 6, size=3)
dom = np.linspace(-9, 9, 1000)

t = np.arange(-9., 10.)
hr = exp_dec(t, A, w, T2)

fit, _ = curve_fit(exp_dec, t, hr)

fig, ax = plt.subplots()
ax.plot(dom, exp_dec(dom, A, w, T2), 'g', label='target')
ax.scatter(t, hr, c='r', label='samples')
ax.plot(dom, exp_dec(dom, *fit), 'b', label='fit')
ax.plot(dom, exp_dec(dom, 1, 1, 1), 'k:', label='start')
ax.legend()

To explain the last plotted item, take a look at the docs for curve_fit. Notice the parameter p0, which defaults to all ones if you do not supply it. That is the initial guess from which the optimizer starts its search.
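If you want to confirm the default starting point yourself, one option is to wrap the model so it records every parameter set it is called with. This is only a sketch: recording_model and evaluated are hypothetical names, and it assumes curve_fit evaluates the model at the initial guess before anything else, which is how the default least-squares path behaves as far as I know.

```python
import numpy as np
from scipy.optimize import curve_fit

def exp_dec(t, A, w, T2):
    return A * np.cos(w * t) * np.exp(-t / T2)

np.random.seed(42)
A, w, T2 = np.random.uniform(3, 6, size=3)
t = np.arange(-9., 10.)
hr = exp_dec(t, A, w, T2)

evaluated = []  # parameter sets, in the order curve_fit tries them

def recording_model(t, A, w, T2):
    # Hypothetical wrapper: remember the parameters, then defer to exp_dec.
    evaluated.append((float(A), float(w), float(T2)))
    return exp_dec(t, A, w, T2)

try:
    curve_fit(recording_model, t, hr)  # no p0 supplied
except RuntimeError:
    pass  # convergence is not the point here, only the first evaluation

first_guess = evaluated[0]
print(first_guess)
```

The first recorded tuple should be all ones, which is the dotted "start" curve in the plot.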

[Figure: target curve (green), samples (red), fit (blue) and starting guess (black dotted)]

Looking at this picture, you can pretty much see what the problem is. The starting guess has a much lower frequency than your data. Because the sampling frequency is so close to the oscillation frequency, the fit hits a local minimum before it is able to increase the frequency sufficiently to get the right function. You can fix this in a couple of different ways.

One way is to give curve_fit a better initial guess. If you know bounds on the amplitude, frequency and decay rate, use them. The amplitude will generally be a straightforward linear fit. The toughest one is usually the frequency, and as you can see here, it is better to over-estimate it. But if you over-estimate it too much, you might end up with a harmonic of the original data.
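If all you know is a range for each parameter, a simple variation on this is to try several starting frequencies inside that range and keep the result with the smallest residual. The sketch below is an assumption-laden illustration, not a recipe: the bounds are taken from the random range used to generate the data, the grid of starting frequencies and the mid-range guesses for A and T2 are arbitrary, and the lowest residual among a handful of starts is not guaranteed to be the global minimum.

```python
import numpy as np
from scipy.optimize import curve_fit

def exp_dec(t, A, w, T2):
    return A * np.cos(w * t) * np.exp(-t / T2)

np.random.seed(42)
A, w, T2 = np.random.uniform(3, 6, size=3)
t = np.arange(-9., 10.)
hr = exp_dec(t, A, w, T2)

# Known range of all three parameters (the range of the random draw above).
lower, upper = [3, 3, 3], [6, 6, 6]

best_fit, best_cost = None, np.inf
for w0 in np.linspace(3.25, 5.75, 6):   # arbitrary grid of frequency guesses
    p0 = [4.5, w0, 4.5]                 # mid-range amplitude and decay time
    try:
        params, _ = curve_fit(exp_dec, t, hr, p0=p0, bounds=(lower, upper))
    except RuntimeError:
        continue                        # this start did not converge
    cost = np.sum((exp_dec(t, *params) - hr) ** 2)
    if cost < best_cost:
        best_fit, best_cost = params, cost

print(best_fit, best_cost)
```

Passing bounds makes curve_fit switch from its default method to a bounded one, so the returned parameters always stay inside the given range.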

Here are a couple of sample fits that show different local minima in the optimization. The second one shows a harmonic case from over-estimating the oscillation frequency:

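[Figure: sample fit stuck in a local minimum]

[Figure: sample fit showing a harmonic, from over-estimating the oscillation frequency]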
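One reading of the earlier remark that the sampling frequency is close to the oscillation frequency: with samples at integer t, the angular frequencies w and 2π - w produce identical values at every sample, so the data alone cannot tell them apart. The check below is a sketch of that idea using the seeded parameters from this answer; w_alias is a name introduced here for illustration.

```python
import numpy as np

def exp_dec(t, A, w, T2):
    return A * np.cos(w * t) * np.exp(-t / T2)

np.random.seed(42)
A, w, T2 = np.random.uniform(3, 6, size=3)
t = np.arange(-9., 10.)            # integer sample positions, spacing 1

# cos((2*pi - w) * t) equals cos(w * t) whenever t is an integer.
w_alias = 2 * np.pi - w            # a much lower frequency than w
same_at_samples = np.allclose(exp_dec(t, A, w, T2),
                              exp_dec(t, A, w_alias, T2))
print(w, w_alias, same_at_samples)
```

This is also why a starting guess with a low frequency can settle on a curve that passes through the samples but looks nothing like the target between them.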

A decent set of starting parameters is the upper bound of your random range:

fit, _ = curve_fit(exp_dec, t, hr, p0=[6, 6, 6])

[Figure: fit obtained with p0=[6, 6, 6]]

The green curve matches the blue so closely that you cannot see it:

>>> A, w, T2
(4.123620356542087, 5.852142919229749, 5.195981825434215)
>>> tuple(fit)
(4.123620356542086, 5.852142919229749, 5.195981825434215)

Another way to fix the problem is to sample the data more frequently. More data will generally mean a lower chance of hitting a false local minimum in the optimization. However, with sinusoidal functions this does not always help, because the least-squares cost of a sinusoidal model has many local minima in the frequency parameter. Here is an example with 10x the number of samples (a fit with just 2x and the default guess fails entirely):

...
t = np.arange(-9., 10., 0.1)
...

[Figure: fit with 10x the number of samples]
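As a rough guide when choosing the sample spacing, the Nyquist criterion says a spacing dt can only represent angular frequencies below π / dt. The sketch below applies that rule to the seeded frequency from this answer; nyquist_limit is a helper name made up for this example.

```python
import math

def nyquist_limit(dt):
    """Highest angular frequency representable at sample spacing dt."""
    return math.pi / dt

w = 5.852142919229749  # true angular frequency from the seeded example

# Spacing 1 is the original sampling, 0.5 is 2x, 0.1 is 10x.
resolved = {dt: w < nyquist_limit(dt) for dt in (1.0, 0.5, 0.1)}
print(resolved)  # {1.0: False, 0.5: True, 0.1: True}
```

Note that a spacing of 0.5 passes this check even though the 2x fit with the default guess still fails. The criterion only rules out exact aliases in the data. It does not guarantee that the optimizer converges from a poor starting point.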
