简体   繁体   English

Python:使用通用 scipy.optimize.curve_fit function

[英]Python: Use general scipy.optimize.curve_fit function

I want to curve fit some data in python.我想曲线拟合 python 中的一些数据。 My program looks like this:我的程序如下所示:

from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

def lin(x, a, b,c):
    return a*x+b

def exp(x, a, b, c):
    return a*np.exp(b*x)+c

def ln(x, a, b, c):
    return a*np.log(b+x)+c

x_dummy = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y_dummy = np.array([9.2, 9.9, 10.0, 11.2, 10.2, 12.6, 10.0, 11.6, 12.2])



popt, _ = curve_fit(lin, x_dummy[:-2], y_dummy[:-1])

y_approx = lin(x_dummy, popt[0], popt[1], popt[2])

print(y_approx[-1])



print(popt)
print(mean_squared_error(y_dummy[:-1], y_approx[:-2]))


plt.plot(x_dummy[:-1], y_dummy, color='blue')
plt.plot(x_dummy, y_approx, color='green')
plt.show()

My aim now is a general function call it fn which can have some parameter eg as string in the sense, that the call我现在的目标是一个通用的 function 调用它 fn 它可以有一些参数,例如在某种意义上作为字符串,调用

popt, _ = curve_fit(fn('lin' or 'exp' or 'ln'), x_dummy[:-2], y_dummy[:-1])

means the same as意思是一样的

popt, _ = curve_fit(lin or exp or ln, x_dummy[:-2], y_dummy[:-1])

Background: I want to generate some array = ['lin', 'exp', 'ln'] and loop through all three kinds of possible curve fits and calculate the minimum of the reproduced squared errors.背景:我想生成一些数组 = ['lin', 'exp', 'ln'] 并遍历所有三种可能的曲线拟合并计算再现平方误差的最小值。

found some method, but maybe its an easier way:找到了一些方法,但也许它是一种更简单的方法:

from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

class FunctionCollector():
    def __init__(self):
        self.name = 'lin'

    def setFunc(self, name):
        self.name = name

    def lin(self, x, a, b, c):
        return a*x+b

    def exp(self, x, a, b, c):
        return a*np.exp(b*x)+c

    def ln(self, x, a, b, c):
        return a*np.log(b+x)+c

    def fn(self, x, a, b, c):
        if self.name == 'lin':
            return self.lin(x, a,b,c)
        elif self.name == 'exp':
            return self.exp(x,a,b,c)
        elif self.name == 'ln':
            return self.ln(x,a,b,c)
        return 0



def l(x,a,b,c):
    return a * x + b
x_dummy = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y_dummy = np.array([9.2, 9.9, 10.0, 11.2, 10.2, 12.6, 10.0, 11.6, 12.2])


#noise = 5*np.random.normal(size=y_dummy.size)
#y_dummy = y_dummy + noise

f = FunctionCollector()

popt, _ = curve_fit(f.fn, x_dummy[:-2], y_dummy[:-1])
y_approx = f.fn(x_dummy, popt[0], popt[1], popt[2])

print(y_approx[-1])



print(popt)
print(mean_squared_error(y_dummy[:-1], y_approx[:-2]))


plt.plot(x_dummy[:-1], y_dummy, color='blue')
plt.plot(x_dummy, y_approx, color='green')
plt.show()

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM