How to fit log(a-x) type functions with scipy's curve_fit?
I am trying to fit a function of the form log(y) = a*log(b - x) + c, where a, b and c are the parameters that need to be fitted. The relevant bit of code is
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
def logfunc(T, a, b, c):
    v = (a*np.log(b-T)) + c
    return v
popt, pcov=curve_fit(logfunc, T, np.log(Energy), check_finite=False, bounds=([0.1, 1.8, 0.1], [1.0, 2.6, 1.0]))
plt.plot(T, logfunc(T, *popt))
plt.show()
where T and Energy are data that were generated earlier (I use them to plot other things, so the data should be fine). T is between 0.3 and 3.2. I am pretty sure the problem is that there is a point where b = T, because I keep getting the error ValueError: Residuals are not finite in the initial point, but I am not sure how to solve this.
The message "Residuals are not finite in the initial point" means the initial point is bad: some logarithms there are infinite or undefined. You need a better initial point.
By the nature of the model, b has to be greater than every point in T. The bounds on b that you have at present do not guarantee that. Tighten them up.
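To see why the current bounds are a problem, here is a minimal sketch that evaluates the model for one b inside the original bounds and one b above max(T). The T array is stand-in data on the range given in the question, and the particular values of a, b and c are arbitrary picks for illustration.

```python
import numpy as np

def logfunc(T, a, b, c):
    return a * np.log(b - T) + c

# Stand-in for the question's data: T spans 0.3 to 3.2
T = np.linspace(0.3, 3.2, 6)

# Silence the NumPy warnings so only the finiteness check is shown
with np.errstate(invalid="ignore", divide="ignore"):
    bad = logfunc(T, 0.55, 2.2, 0.55)   # b inside the original bounds [1.8, 2.6]
    good = logfunc(T, 0.55, 3.4, 0.55)  # b above max(T)

print("b = 2.2 all finite:", np.isfinite(bad).all())
print("b = 3.4 all finite:", np.isfinite(good).all())
```

Any b at or below the largest T makes b - T zero or negative for at least one point, so the residuals cannot be finite there.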
When you do not provide the p0 parameter, SciPy will take a guess within the provided bounds. So if the bounds guarantee finiteness, the error will not occur. Still, it is generally better to prescribe p0 yourself, because you have a better a priori understanding of the problem than SciPy does.
A working example with adjusted bounds:
popt, pcov=curve_fit(logfunc, np.linspace(0.3, 3.2, 6), [8, 7, 6, 5, 4, 3], bounds=([0.1, 3.2, 0.1], [1.0, 3.6, 1.0]))
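A variant of the same idea, as a sketch only: it passes p0 explicitly and derives the lower bound on b from the data rather than hard-coding it. The synthetic log_energy values, the "true" parameters used to generate them, and the eps margin are all assumptions made for illustration, not part of the original question.

```python
import numpy as np
from scipy.optimize import curve_fit

def logfunc(T, a, b, c):
    return a * np.log(b - T) + c

# Synthetic stand-in for log(Energy), generated from hypothetical parameters
T = np.linspace(0.3, 3.2, 30)
log_energy = logfunc(T, 0.5, 3.4, 0.5)

eps = 1e-6  # small margin so that b - T stays strictly positive
lower = [0.1, T.max() + eps, 0.1]
upper = [1.0, T.max() + 0.4, 1.0]
p0 = [0.4, T.max() + 0.1, 0.4]  # explicit starting point inside the bounds

popt, pcov = curve_fit(logfunc, T, log_energy, p0=p0, bounds=(lower, upper))
print(popt)
```

Tying the lower bound on b to T.max() means the bounds stay valid if the range of the data changes.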
You may find the lmfit package ( http://lmfit.github.io/lmfit-py/ ) useful for this sort of problem. It provides a higher-level approach to curve fitting problems and a better abstraction of Parameters and Models than the scipy.optimize package or the curve_fit() function.
For the problem here, two features of lmfit matter: bounds on parameter values, and a nan_policy setting that controls what happens when the model produces NaNs. curve_fit() can set bounds as well, but only by working with ordered lists of min/max bounds. With lmfit, the bounds belong to Parameter objects. With lmfit, your script would be written approximately as
import numpy as np
import matplotlib.pyplot as plt
from lmfit import Model
def logfunc(T, a, b, c):
    return (a*np.log(b-T))+c
log_model = Model(logfunc, nan_policy='raise') # raise error on NaNs
params = log_model.make_params(a=0.5, b=2.0, c=0.5) # initial values
params['b'].min = 1.8 # set min/max values (b must stay above max(T) for log(b-T) to be finite)
params['b'].max = 2.6
params['c'].min = 0.1 # and so forth
result = log_model.fit(np.log(Energy), params, T=T)
print(result.fit_report())
plt.plot(T, Energy, 'bo', label='data')
plt.plot(T, np.exp(result.best_fit), 'r--', label='fit')
plt.legend()
plt.xlabel('T')
plt.ylabel('Energy')
plt.gca().set_yscale('log') # base 10 is the default; newer Matplotlib no longer accepts basey
plt.show()
This is slightly more verbose than your starting script because it gives a labeled plot and because using Parameter objects instead of scalars gives more flexibility and clarity.
For your fit, you might consider setting the nan_policy to 'omit', which will omit NaNs as they occur. That is never a great idea, but it is sometimes helpful to get you started on finding where log(b-T) is valid. You could also alter your model function to do something like
def logfunc(T, a, b, c):
    arg = b - T
    arg[np.where(arg < 1.e-16)] = 1.e-16
    return a*np.log(arg) + c
to explicitly prevent one obvious cause of NaNs.
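The same guard can also be written with np.maximum, which avoids in-place indexing and so also accepts a scalar T. This is only a sketch: the name logfunc_clipped and the floor keyword are made up for illustration.

```python
import numpy as np

def logfunc_clipped(T, a, b, c, floor=1e-16):
    # Clamp b - T from below so the logarithm never sees zero or a negative value
    arg = np.maximum(b - np.asarray(T, dtype=float), floor)
    return a * np.log(arg) + c

T = np.linspace(0.3, 3.2, 6)
values = logfunc_clipped(T, 0.5, 2.0, 0.5)  # b below max(T) on purpose
print(np.isfinite(values).all())
```

Keep in mind that for the clamped points the model no longer depends on b at all, so the fit gets no information from them about which way to move b. The clamp prevents NaNs, but sensible bounds on b are still what make the fit meaningful.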