Using scipy curve_fit to fit exponential function to data, but all y data is near 0 in plot
I've been using this tutorial ( https://swharden.com/blog/2020-09-24-python-exponential-fit/ ) to fit an exponential curve to my data (see pastebin for data: https://pastebin.com/DrEvJcRC ). I adapted the input to use my own data and changed the function from a negative exponential, (m * np.exp(-t * x) + b), to a positive exponential, (m * np.exp(t * x) + b).
This is my code (adapted from the tutorial):
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize  # import the submodule explicitly so scipy.optimize.curve_fit resolves
# read data
df = pd.read_csv(r'C:/Users/HP/OneDrive/Bachelor/Segmentierung/Excel/1_new.csv', sep=";")
print(df)
x = df["x"]
y = df["y"]
# plot original
plt.plot(x, y, '.')
plt.title("Original Data")
plt.show()
def monoexp(x, m, t, b):
    return m * np.exp(t * x) + b
# perform the fit
# start with values near those we expect
params, cv = scipy.optimize.curve_fit(monoexp, x, y)
m, t, b = params
sampleRate = 20_000 # Hz
tauSec = (1 / t) / sampleRate
# determine quality of the fit
squaredDiffs = np.square(y - monoexp(x, m, t, b))
squaredDiffsFromMean = np.square(y - np.mean(y))
rSquared = 1 - np.sum(squaredDiffs) / np.sum(squaredDiffsFromMean)
print(f"R² = {rSquared}")
# plot the results
plt.plot(x, y, '.', label="data")
plt.plot(x, monoexp(x, m, t, b), '--', label="fitted")
plt.title("Fitted Exponential Curve")
plt.show()
# inspect the parameters
print(f"Y = {m} * e^({t} * x) + {b}")
print(f"Tau = {tauSec * 1e6} µs")
However, when I plot the fitted function, all the y data is concentrated near y = 0, even though the raw data is not centered near zero.
[Figure: original data]
[Figure: fitted curve + original data]
I'm not sure if this is a problem with matplotlib or my data. Any help would be appreciated.
Edit: Here are the original arrays:
y = np.array([ 50.265654 50.481338 55.598281 57.875762 54.32182 58.760685
62.847534 64.28961 57.020572 72.1828 57.467019 62.230506
67.540995 64.496286 54.763321 70.058298 70.710515 58.604998
63.856038 71.711142 61.436699 69.918246 71.709434 72.019284
71.667271 64.667837 73.15604 72.78848 75.194899 73.362767
65.633833 76.527694 79.975514 66.31111 83.238201 86.12114
61.883045 83.874603 68.157062 91.044078 92.729386 71.74222
86.349847 82.8457 94.123932 89.738375 47.495947 78.255301
98.873243 94.74139 93.019812 100.313886 97.69019 57.750244
68.613443 37.156953 104.759545 85.397733 102.651581 99.332523
66.748191 100.283648 84.807827 83.60035 109.183196 85.638829
97.114549 99.870095 95.865177 95.974271 77.56365 95.365718
79.186443 85.08403 76.981884 92.026004 90.356532 97.741741
90.561301 95.034609 86.457017 93.39508 90.173374 92.576365
96.131347 97.231791 95.589212 96.561979 94.905649 96.565916
93.393454 99.992579 98.07305 93.475501 101.344676 98.577551
97.309664 99.832354 96.305865 107.564353 100.852842 102.705253
99.991087 100.500039 92.741113 105.655507 112.564399 113.391128
112.92131 108.758485 116.600566 119.832142 116.127415 96.076771
118.041184 121.801249 104.471811 119.362042 125.783554 99.582637
121.588023 127.001352 142.354073 129.378792 122.716748 123.295855
125.646691 117.353543 129.07801 133.94258 124.015839 125.661787
130.53109 144.816005 157.213145 135.520705 149.981018 145.649621
152.043158 151.030789 172.135762 157.441351 156.543511 136.303031
166.656986 160.05244 161.973895 163.219232 170.589712 167.537767
169.779851 194.229855 170.166431 174.551329 205.846669 188.975028
206.772085 212.799242 193.089462 220.674931 225.05487 222.929436
235.62132 257.330955 249.044577 231.147388 277.336486 257.839554
269.695485 263.06116 271.412341 265.298246 271.612072 275.421131
290.819824 267.474925 268.129235 261.022281 268.555814 269.902072
274.178234 278.662213 291.951716 274.587357 279.547121 273.842799
287.096126 301.781704 318.724333 320.13565 319.823382 305.842588
323.785279 312.432472 360.068566 340.44523 357.468329 397.867687
372.629412 373.711258 388.787463 401.839697 403.886104 407.850298
414.755803 410.143675 407.997144 418.170715 446.555237 436.598994
443.822001 453.202132 457.683222 473.140737 486.566587 507.351791
521.557097 491.379094 523.612526 526.795454 553.596441 556.339232
560.420248 564.853146 560.541646 574.501924 575.880492 580.008003
606.287027 593.064523 575.563008 630.483781 614.198263 650.491796
624.158124 666.845517 679.263829 663.266799 671.248458 655.239931
669.626968 695.610041 667.202116 706.391511 711.907568 717.474155
726.638 727.172115 758.956431 759.93291 764.690474 774.672464
757.418492 822.346932 788.690474 794.585579 815.187357 828.10081
810.168531 806.578016 833.373597 825.524955 895.056386 868.329793
865.217297 892.069431 901.155111 841.722724 912.566196 906.663263
911.161195 926.260885 891.846802 952.404458 968.076496 949.08048
983.608895 975.126401 1020.251247 991.357126 1030.692856 1076.355363
1057.721679 1100.199617 1068.971153 1109.204842 1123.111699 1136.594802
1110.202018 1186.288809 1171.616713 1174.268427 1172.328883 1209.430644
1213.274975 1226.917626 1243.912662 1258.744995 1284.971544 1286.015899
1327.096185 1331.784306 1371.597679 1370.215724 1431.747674 1468.234523
1476.391553 1466.01931 1444.764882 1472.96347 1511.162425 1551.551958
1593.354289 1554.588089 1542.742428 1586.901403 1573.936749 1726.079181
1619.650554 1657.632404 1761.770143 1719.361989 1703.11686 1747.136463
1773.522743 1778.767987 1783.531536 1741.582718 1850.804924 1808.78412
1790.434519])
x = np.array([ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324
325 326 327 328 329 330 331])
Yours sincerely, Cornelius
The documentation says that if no initial value p0 is provided, all parameters are initialized as 1. This is orders of magnitude off from the actual values in your example. Also note that because exponential functions span a huge range even on small intervals, this is generally only going to work with a good initialization of your parameters.
From the plot it looks like we could assume that m is roughly 1 and b is roughly 0. What is left is t, and if you consider that the exponential function should be around y=2000 at x=300, we could start out with t=0.1 or t=0.01 (as an approximation of log(2000)/300). If we plug this into your curve_fit as p0=[1, 0.01, 0] we get a more reasonable result:
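As a minimal sketch of this suggestion, the snippet below passes p0=[1, 0.01, 0] to curve_fit. Since the full dataset is not reproduced here, it uses synthetic data; the generating parameters (true_m, true_t, true_b), the noise level, and the random seed are assumptions chosen only to resemble the shape of the data in the question.

```python
import numpy as np
from scipy.optimize import curve_fit

def monoexp(x, m, t, b):
    return m * np.exp(t * x) + b

# Synthetic stand-in for the question's data (assumed values, not the real dataset)
true_m, true_t, true_b = 35.0, 0.012, -3.0
rng = np.random.default_rng(0)
x = np.arange(1, 332, dtype=float)
y = monoexp(x, true_m, true_t, true_b) + rng.normal(scale=20.0, size=x.size)

# Initial guess from the answer: m ~ 1, b ~ 0, t on the order of log(2000)/300
p0 = [1, 0.01, 0]
params, cov = curve_fit(monoexp, x, y, p0=p0)
m, t, b = params

# Coefficient of determination of the fit
residuals = y - monoexp(x, m, t, b)
r_squared = 1 - np.sum(residuals**2) / np.sum((y - np.mean(y))**2)
print(f"m = {m:.3f}, t = {t:.5f}, b = {b:.3f}, R² = {r_squared:.4f}")
```

With the real data, x and y would simply be the columns read from the CSV file; only the p0 argument differs from the call in the question.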
Finally, I should add that I'd give some more thought to whether a least squares fit (which is what curve_fit uses) is really what you want for an exponential function, or whether there might be other alternatives to find a working approximation.
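One commonly used alternative, offered here only as an illustration and not something the answer prescribes, is to fit a straight line to log(y): if b is negligible and all y values are positive, then log(y) = log(m) + t*x, which needs no initial guess. The synthetic data and the variable names below are assumptions. Note that this minimizes errors in log(y) rather than in y, so it weights the points differently from curve_fit.

```python
import numpy as np

# Synthetic data with multiplicative noise (assumed values for illustration)
rng = np.random.default_rng(1)
x = np.arange(1, 332, dtype=float)
y = 35.0 * np.exp(0.012 * x) * np.exp(rng.normal(scale=0.05, size=x.size))

# Linear least squares on log(y): the slope is t, the intercept is log(m)
t_est, log_m_est = np.polyfit(x, np.log(y), 1)
m_est = np.exp(log_m_est)
print(f"m ≈ {m_est:.3f}, t ≈ {t_est:.5f}")

# These estimates could also serve as a starting point for curve_fit,
# e.g. p0=[m_est, t_est, 0]
```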
flawr's answer is very good. In fact, there is no need for more answers.
Nevertheless, I would like to add a comment (too long to be edited in the comments section).
The method of nonlinear regression used in your software is iterative. It is well known that starting an iterative calculation requires setting some initial values for the parameters. That is judiciously mentioned in flawr's answer, with an empirical approach to guess initial values.
This is very important because if the initial values are too far from the unknown correct values, the iterative calculation might fail.
Of course, all would be easier with a non-iterative method which doesn't require guessing initial values. Such a method exists, as shown below.
Numerical example with your data:
For comparison, nonlinear regression (iterative method) gives:
a=-2.91212; b=35.547425; c=0.012030; RMSE=30.050141 (Root Mean Square Error).
With this data the fit is the same for both methods. One cannot distinguish any difference between the two respective blue curves on the above graph.
FOR INFORMATION:
The above method is a linear regression with respect to a linear integral equation of which the exponential equation is a solution:
In the numerical calculation, the Sk are the values of the integral computed by numerical integration.
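A minimal sketch of how such a non-iterative fit could look for the model y = a + b*exp(c*x) follows. It assumes the integral equation takes the form y - y1 = A*(x - x1) + c*S, where S is the running integral of y starting at x1 (so that A = -a*c). This form follows from integrating the model and is a reconstruction, not a quotation from the referenced paper. The function name fit_exp_integral and the noise-free test data are hypothetical.

```python
import numpy as np

def fit_exp_integral(x, y):
    """Fit y = a + b*exp(c*x) without initial guesses (x must be sorted ascending)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # S_k: cumulative trapezoidal integral of y from x[0] to x[k], with S_1 = 0
    S = np.concatenate(([0.0], np.cumsum(0.5 * (y[1:] + y[:-1]) * np.diff(x))))
    # Step 1: linear least squares on y - y1 = A*(x - x1) + c*S gives c
    M = np.column_stack((x - x[0], S))
    (A, c), *_ = np.linalg.lstsq(M, y - y[0], rcond=None)
    # Step 2: with c fixed, y = a + b*exp(c*x) is linear in a and b
    E = np.column_stack((np.ones_like(x), np.exp(c * x)))
    (a, b), *_ = np.linalg.lstsq(E, y, rcond=None)
    return a, b, c

# Noise-free synthetic data (assumed parameters, close to those reported above)
x = np.arange(1, 332, dtype=float)
y = -3.0 + 35.0 * np.exp(0.012 * x)

a, b, c = fit_exp_integral(x, y)
print(f"a = {a:.4f}, b = {b:.4f}, c = {c:.6f}")
```

Both steps are ordinary linear least squares problems, so no starting values are needed. With noisy data the numerical integration introduces some error, so the result can also be used as p0 for a subsequent curve_fit call.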
Ref.: https://fr.scribd.com/doc/14674814/Regressions-et-equations-integrales