SciPy.optimize.least_squares() 5PL Curve Optimization problems

I am trying to write a script that will take an input array of x and y values and fit them to a 5-PL curve (defined by the equation F(x) = D+(A-D)/((1+(x/C)^B)^E)). I then want to be able to use the predicted curve to take a given y value and extrapolate an x value from the curve, given by the equation F(y) = C(((A-D)/(-D+y))^(1/E)-1)^(1/B).
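For reference, the inverse can be derived from the forward equation by isolating x one step at a time. This is only a sketch of the algebra; it assumes y lies strictly between the two asymptotes A and D, so that every base being raised to a power is positive:

```latex
\begin{aligned}
y &= D + \frac{A - D}{\left(1 + (x/C)^{B}\right)^{E}} \\
\left(1 + (x/C)^{B}\right)^{E} &= \frac{A - D}{y - D} \\
(x/C)^{B} &= \left(\frac{A - D}{y - D}\right)^{1/E} - 1 \\
x &= C\left(\left(\frac{A - D}{y - D}\right)^{1/E} - 1\right)^{1/B}
\end{aligned}
```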

The answer below fixed the previous error, but the fit is still really bad. I've introduced a print function with a handful of y values across the range fed into curve_fit, and it yields almost the exact same x value across the range. Any ideas what may be going on here?

Edit: For anyone looking now, the problem appears to have been my estimate for B. The hill slope should be between -1 and 1 in most cases, not in the thousands. That made it too far to estimate.

import numpy as np
import scipy.optimize as sp


def logistic5(x, A, B, C, D, E):
    '''5PL logistic equation'''
    log = D + (A-D)/(np.power((1 + np.power((x/C), B)), E))
    return log


def residuals(p, y, x):
    '''Deviations of data from fitted 5PL curve'''
    A, B, C, D, E = p
    err = y - logistic5(x, A, B, C, D, E)
    print(err)
    return err


def log_solve_for_x(curve, y):
    '''Returns the estimated x value for the provided y value'''
    A, B, C, D, E = curve
    return C*(np.power((np.power(((A-D)/(-D+y)), (1/E))-1), (1/B)))


# Toy data set
x = np.array([130, 38, 15, 4.63, 1.41])
y = np.array([9121, 1987, 1017, 343, 117])

# Set initial guess for parameters
A = np.amin(y)  # Min asymptote
D = np.amax(y)  # Max asymptote
B = (D-A)/(np.amax(x)-np.amin(x))  # Steepness
C = (np.amax(x)-np.amin(x))/2  # inflection point
E = 1  # Asymmetry factor

# Optimize curve for initial parameters
p0 = [A, B, C, D, E]
# set bounds for each parameter
pu = []
pl = []
for p in p0:
    pu.append(p*1.5)
    pl.append(p*0.5)
print(pu)
print(pl)
print("Initial guess of parameters is: ", p0)
curve = sp.least_squares(fun=residuals, x0=p0, args=(y, x), bounds=(pl, pu))
curve = curve.x.tolist()
print("Optimized curve parameters are: ", curve)

# Predict x values based on given y
y = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000]
for sample in y:
    solve = log_solve_for_x(curve, sample)
    print("Predicted X value for y =", sample, " is: ", solve)
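One way to tell a fitting problem apart from a formula problem is a round-trip check: generate y values from known parameters, feed them through the inverse, and confirm the original x values come back. The sketch below does that; the parameter values in `known` are made up for illustration and are not a fit to the toy data.

```python
import numpy as np


def logistic5(x, A, B, C, D, E):
    """5PL logistic equation."""
    return D + (A - D) / np.power(1 + np.power(x / C, B), E)


def log_solve_for_x(curve, y):
    """Inverse of logistic5: estimated x for a given y."""
    A, B, C, D, E = curve
    return C * np.power(np.power((A - D) / (y - D), 1 / E) - 1, 1 / B)


# Hypothetical parameters, chosen only so that A < y < D for every x below
known = [100.0, 1.2, 50.0, 10000.0, 0.8]

x_known = np.array([1.41, 4.63, 15.0, 38.0, 130.0])
y_known = logistic5(x_known, *known)      # forward: x -> y
x_back = log_solve_for_x(known, y_known)  # inverse: y -> x

print("x recovered from y:", x_back)
print("round trip OK:", np.allclose(x_back, x_known))
```

If the round trip succeeds with known parameters but the predictions from a fitted curve are all nearly identical, the fitted parameters are the likely culprit rather than the inverse formula.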

Your curve is not defined for every combination of parameter values, but you did not give least_squares that information. At some point the solver wanders into an inadmissible region and gets stuck there: the residuals come back as NaN and you get warnings about an invalid value in power. The exponents are easy to handle: you can simply require E >= 0 and B >= 0. The base of the power is harder. You either need to switch to a solver that supports general constraints (e.g. scipy.optimize.minimize) and add a constraint that the base is >= 0, or restrict the search to the admissible domain some other way, e.g.:

pu = []
pl = []
for p in p0:
    pu.append(p*1.5)
    pl.append(p*.5)

curve = sp.least_squares(fun=residuals, x0=p0, args=(y, x), bounds=(pl, pu))
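As an alternative to bounds that scale with the initial guess, the bounds can be written per parameter so that the power base stays positive. The following is a minimal sketch under a few assumptions of mine: all x values are positive, so keeping C > 0 keeps x/C positive; the starting values for B, C and E (1, the median of x, 1) and the numeric limits are illustrative choices, not values from the question.

```python
import numpy as np
from scipy.optimize import least_squares


def logistic5(x, A, B, C, D, E):
    """5PL logistic equation."""
    return D + (A - D) / np.power(1 + np.power(x / C, B), E)


def residuals(p, y, x):
    """Deviations of data from the 5PL curve."""
    return y - logistic5(x, *p)


# Toy data set from the question
x = np.array([130, 38, 15, 4.63, 1.41])
y = np.array([9121, 1987, 1017, 343, 117])

# Illustrative starting point: [A, B, C, D, E]
p0 = np.array([y.min(), 1.0, np.median(x), y.max(), 1.0])

# A and D are left free; B, C and E are kept strictly positive so that
# (x / C) ** B and (1 + ...) ** E never see a negative base.
lower = [-np.inf, 1e-3, 1e-3, -np.inf, 1e-3]
upper = [np.inf, 10.0, np.inf, np.inf, 10.0]

cost0 = 0.5 * np.sum(residuals(p0, y, x) ** 2)  # cost at the starting point
fit = least_squares(residuals, p0, args=(y, x), bounds=(lower, upper))

print("Initial cost:", cost0)
print("Final cost:  ", fit.cost)
print("Parameters:  ", fit.x)
```

With only five data points and five parameters the fit is not well determined, so the resulting parameters should be treated with caution; the point here is only that the solver never leaves the region where the model is defined.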

You could also make your residual function work for any parameters, e.g. by replacing NaN with the distance to the initial guess. But that may be inefficient.
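A residual of that kind might look like the sketch below. The name `safe_residuals`, the `p_ref` argument and the `penalty_scale` value are assumptions introduced for illustration; the penalty grows with the distance from the reference point so that the solver is pushed back toward it.

```python
import numpy as np


def logistic5(x, A, B, C, D, E):
    """5PL logistic equation."""
    return D + (A - D) / np.power(1 + np.power(x / C, B), E)


def safe_residuals(p, y, x, p_ref, penalty_scale=1e3):
    """Residuals with non-finite entries replaced by a distance-based penalty."""
    # Silence the warnings that np.power emits for inadmissible parameters
    with np.errstate(invalid="ignore", divide="ignore", over="ignore"):
        err = y - logistic5(x, *p)
    bad = ~np.isfinite(err)
    if bad.any():
        # Distance from the reference (e.g. initial) guess
        distance = np.linalg.norm(np.asarray(p) - np.asarray(p_ref))
        err = np.where(bad, penalty_scale * (1.0 + distance), err)
    return err


x = np.array([130, 38, 15, 4.63, 1.41])
y = np.array([9121, 1987, 1017, 343, 117])

p_ref = [117.0, 1.0, 30.0, 9121.0, 1.0]   # hypothetical reference guess
p_bad = [117.0, 0.5, -30.0, 9121.0, 1.0]  # C < 0 makes (x / C) ** 0.5 NaN

print(safe_residuals(p_bad, y, x, p_ref))  # finite penalty values, no NaN
```

It would be passed to least_squares with `args=(y, x, p_ref)`. The penalty makes the residual discontinuous at the edge of the admissible region, which is one reason this approach can converge poorly.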


To improve the fit, you can try a better initial point, multistart, or both.

A = np.amin(y)  # Min asymptote
D = np.amax(y)  # Max asymptote
B = (D-A)/np.amax(x)*10  # Steepness
C = np.amax(x)/10  # inflection point
E = 0.001  # Asymmetry factor

p0 = [A, B, C, D, E]
print("Initial guess of parameters is: ", p0)
pu = []
pl = []
for p in p0:
    pu.append(p*1.5)
    pl.append(p*.5)

best_cost = np.inf
for _ in range(100):
    # Draw a random starting point inside the bounds
    for j in range(5):
        p0[j] = np.random.uniform(pl[j], pu[j])

    curve = sp.least_squares(fun=residuals, x0=p0, args=(y, x), bounds=(pl, pu))
    print(p0, curve.cost)
    if best_cost > curve.cost:
        best_cost = curve.cost
        curve_out = curve.x.tolist()
print("Optimized curve parameters are: ", curve_out)

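import matplotlib.pyplot as plt  # required for the plot below

# Plot the data points
plt.plot(x, y, '.')

# Plot the fitted curve; logistic5 is vectorized, so no loop is needed
xx = np.linspace(0, 150, 100)
yy = logistic5(xx, *curve_out)

plt.plot(xx, yy)
plt.show()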

