繁体   English   中英

曲线拟合 scipy

[英]curve fitting with scipy

我正在尝试使用来自 SciPy 的 curve_fit 拟合曲线。但它没有按预期工作,我不知道为什么。 这是我的代码:


xdata = np.asarray(std_ex_90degree[5050:5150,0])
ydata = np.asarray(std_ex_90degree[5050:5150,1])
print(xdata,ydata)
  

def Gauss(x, A, B):
    y = A*np.exp(-1*B*x**2)
    return y

popt, covariance = curve_fit(Gauss, xdata, ydata)

fit_A, fit_B = popt
  
fit_y = Gauss(xdata, fit_A, fit_B)

plt.scatter(xdata, ydata, label='data',s=5)
plt.plot(xdata, fit_y, '-', label='fit')
plt.legend()

这是数据图和拟合

如您所见,Gaussian Fit 不起作用,我只得到一条直线。

这是数据:

[2834.486 2834.968 2835.45  2835.932 2836.414 2836.896 2837.378 2837.861
 2838.343 2838.825 2839.307 2839.789 2840.271 2840.753 2841.235 2841.718
 2842.2   2842.682 2843.164 2843.646 2844.128 2844.61  2845.093 2845.575
 2846.057 2846.539 2847.021 2847.503 2847.985 2848.468 2848.95  2849.432
 2849.914 2850.396 2850.878 2851.36  2851.843 2852.325 2852.807 2853.289
 2853.771 2854.253 2854.735 2855.218 2855.699 2856.182 2856.664 2857.146
 2857.628 2858.11  2858.592 2859.074 2859.557 2860.039 2860.521 2861.003
 2861.485 2861.967 2862.449 2862.932 2863.414 2863.896 2864.378 2864.86
 2865.342 2865.824 2866.307 2866.789 2867.271 2867.753 2868.235 2868.717
 2869.199 2869.682 2870.164 2870.646 2871.128 2871.61  2872.092 2872.574
 2873.056 2873.539 2874.021 2874.503 2874.985 2875.467 2875.949 2876.431
 2876.914 2877.396 2877.878 2878.36  2878.842 2879.324 2879.806 2880.289
 2880.771 2881.253 2881.735 2882.217] 
[0.5027119 0.5155925 0.5296563 0.5450429 0.5619112 0.5804411 0.6008373
 0.6233361 0.6482099 0.67577   0.7063611 0.7403504 0.7781109 0.8200049
 0.8663718 0.9175249 0.9737514 1.035319  1.102472  1.175419  1.254304
 1.339163  1.429889  1.526202  1.627649  1.733603  1.84322   1.955248
 2.067605  2.176702  2.276757  2.359875  2.417753  2.445059  2.441798
 2.41245   2.362954  2.298523  2.223243  2.14052   2.05336   1.964326
 1.87539   1.787885  1.702644  1.620191  1.540921  1.465193  1.393333
 1.325607  1.262171  1.203057  1.148185  1.097403  1.050529  1.007382
 0.9678    0.9316369 0.8987471 0.8689752 0.8421496 0.8180863 0.7965991
 0.7775094 0.76065   0.7458642 0.732995  0.7218768 0.7123291 0.7041584
 0.6971676 0.6911709 0.6860058 0.6815417 0.6776828 0.674363  0.6715436
 0.6692089 0.6673671 0.6660498 0.6653103 0.6652156 0.6658351 0.6672268
 0.6694273 0.6724483 0.676279  0.6808962 0.686272  0.6923797 0.699192
 0.7066767 0.7147906 0.7234787 0.7326793 0.7423348 0.7524015 0.7628553
 0.7736901 0.7849081]

不合适的 model。对 model 进行微调会产生粗略的拟合:

import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

xdata = np.array((
    2834.486, 2834.968, 2835.45 , 2835.932, 2836.414, 2836.896, 2837.378, 2837.861,
    2838.343, 2838.825, 2839.307, 2839.789, 2840.271, 2840.753, 2841.235, 2841.718,
    2842.2  , 2842.682, 2843.164, 2843.646, 2844.128, 2844.61 , 2845.093, 2845.575,
    2846.057, 2846.539, 2847.021, 2847.503, 2847.985, 2848.468, 2848.95 , 2849.432,
    2849.914, 2850.396, 2850.878, 2851.36 , 2851.843, 2852.325, 2852.807, 2853.289,
    2853.771, 2854.253, 2854.735, 2855.218, 2855.699, 2856.182, 2856.664, 2857.146,
    2857.628, 2858.11 , 2858.592, 2859.074, 2859.557, 2860.039, 2860.521, 2861.003,
    2861.485, 2861.967, 2862.449, 2862.932, 2863.414, 2863.896, 2864.378, 2864.86 ,
    2865.342, 2865.824, 2866.307, 2866.789, 2867.271, 2867.753, 2868.235, 2868.717,
    2869.199, 2869.682, 2870.164, 2870.646, 2871.128, 2871.61 , 2872.092, 2872.574,
    2873.056, 2873.539, 2874.021, 2874.503, 2874.985, 2875.467, 2875.949, 2876.431,
    2876.914, 2877.396, 2877.878, 2878.36 , 2878.842, 2879.324, 2879.806, 2880.289,
    2880.771, 2881.253, 2881.735, 2882.217,
))
ydata = np.array((
    0.5027119, 0.5155925, 0.5296563, 0.5450429, 0.5619112, 0.5804411, 0.6008373,
    0.6233361, 0.6482099, 0.67577  , 0.7063611, 0.7403504, 0.7781109, 0.8200049,
    0.8663718, 0.9175249, 0.9737514, 1.035319 , 1.102472 , 1.175419 , 1.254304 ,
    1.339163 , 1.429889 , 1.526202 , 1.627649 , 1.733603 , 1.84322  , 1.955248 ,
    2.067605 , 2.176702 , 2.276757 , 2.359875 , 2.417753 , 2.445059 , 2.441798 ,
    2.41245  , 2.362954 , 2.298523 , 2.223243 , 2.14052  , 2.05336  , 1.964326 ,
    1.87539  , 1.787885 , 1.702644 , 1.620191 , 1.540921 , 1.465193 , 1.393333 ,
    1.325607 , 1.262171 , 1.203057 , 1.148185 , 1.097403 , 1.050529 , 1.007382 ,
    0.9678   , 0.9316369, 0.8987471, 0.8689752, 0.8421496, 0.8180863, 0.7965991,
    0.7775094, 0.76065  , 0.7458642, 0.732995 , 0.7218768, 0.7123291, 0.7041584,
    0.6971676, 0.6911709, 0.6860058, 0.6815417, 0.6776828, 0.674363 , 0.6715436,
    0.6692089, 0.6673671, 0.6660498, 0.6653103, 0.6652156, 0.6658351, 0.6672268,
    0.6694273, 0.6724483, 0.676279 , 0.6808962, 0.686272 , 0.6923797, 0.699192 ,
    0.7066767, 0.7147906, 0.7234787, 0.7326793, 0.7423348, 0.7524015, 0.7628553,
    0.7736901, 0.7849081,
))


def gauss(x: np.ndarray, *args: float) -> np.ndarray:
    a, b, c, d = args
    return a*np.exp(-b*(x - c)**2) + d


popt, _ = curve_fit(
    gauss, xdata, ydata,
    p0=(1.7, 0.02, 2851, 0.7),
    maxfev=100_000,
)
print(popt)
fit_y = gauss(xdata, *popt)

plt.scatter(xdata, ydata, label='data', s=5)
plt.plot(xdata, fit_y, '-', label='fit')
plt.legend()
plt.show()
[1.68927347e+00 2.10977276e-02 2.85117456e+03 6.81806648e-01]

为了做得更好,您的 model 需要进行更多更改。

合身

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM