简体   繁体   English

具有 n 个断点的分段线性拟合

[英]Piecewise linear fit with n breakpoints

I have used some code found in the question How to apply piecewise linear fit in Python?我使用了问题如何在 Python 中应用分段线性拟合中的一些代码 , to perform segmented linear approximation with a single breakpoint. , 使用单个断点执行分段线性逼近。

The code is as follows:代码如下:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise_linear(x, x0, y0, k1, k2):
    return np.piecewise(x, 
                       [x < x0], 
                       [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 15, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

I am trying to figure out how I can extend this to handle n breakpoints.我想弄清楚如何扩展它以处理 n 个断点。

I tried the following code for the piecewise_linear() method to handle 2 breakpoints, but it does not alter the values of the breakpoints in any way.我为piecewise_linear() 方法尝试了以下代码来处理2 个断点,但它不会以任何方式改变断点的值。

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158])

def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3):
    return np.piecewise(x,
                       [x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ], 
                       [lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 20, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

Any input would be greatly appreciated任何投入将不胜感激

NumPy has a polyfit function which makes it very easy to find the best fit line through a set of points: NumPy 有一个polyfit函数,可以很容易地通过一组点找到最佳拟合线:

coefs = npoly.polyfit(xi, yi, 1)

So really the only difficulty is finding the breakpoints.所以真正唯一的困难是找到断点。 For a given set of breakpoints it's trivial to find the best fit lines through the given data.对于给定的断点集,通过给定数据找到最佳拟合线是微不足道的。

So instead of trying to find location of the breakpoints and the coefficients of the linear parts all at once, it suffices to minimize over a parameter space of breakpoints.因此,与其试图一次找到断点的位置线性部分的系数,不如在断点的参数空间上最小化就足够了。

Since the breakpoints can be specified by their integer index values into the x array, the parameter space can be thought of as points on an integer grid of N dimensions, where N is the number of breakpoints.由于断点可以通过它们的整数索引值指定到x数组中,因此可以将参数空间视为N维整数网格上的点,其中N是断点的数量。

optimize.curve_fit is not a good choice as the minimizer for this problem because the parameter space is integer-valued. optimize.curve_fit作为这个问题的最小化optimize.curve_fit不是一个好的选择,因为参数空间是整数值。 If you were to use curve_fit , the algorithm would tweak the parameters to determine in which direction to move.如果您要使用curve_fit ,算法将调整参数以确定移动的方向。 If the tweak is less than 1 unit, the x-values of the breakpoints do not change, so the error does not change, so the algorithm gains no information about the correct direction in which to shift the parameters.如果调整小于 1 个单位,则断点的 x 值不会改变,因此误差不会改变,因此算法不会获得有关移动参数的正确方向的信息。 Hence curve_fit tends to fail when the parameter space is essentially integer-valued.因此,当参数空间本质上是整数值时, curve_fit往往会失败。

A better, but not very fast, minimizer would be a brute-force grid search.一个更好但不是非常快的最小化器将是蛮力网格搜索。 If the number of breakpoints is small (and the parameter space of x -values is small) this might suffice.如果断点的数量很小(并且x的参数空间很小),这可能就足够了。 If the number of breakpoints is large and/or the parameter space is large, then perhaps set up a multi-stage coarse/fine (brute-force) grid search.如果断点的数量很大和/或参数空间很大,那么也许可以设置多级粗/精(蛮力)网格搜索。 Or, perhaps someone will suggest a smarter minimizer than brute-force...或者,也许有人会建议一个比蛮力更聪明的最小化器......


import numpy as np
import numpy.polynomial.polynomial as npoly
from scipy import optimize
import matplotlib.pyplot as plt
np.random.seed(2017)

def f(breakpoints, x, y, fcache):
    breakpoints = tuple(map(int, sorted(breakpoints)))
    if breakpoints not in fcache:
        total_error = 0
        for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
            total_error += ((f(xi) - yi)**2).sum()
        fcache[breakpoints] = total_error
    # print('{} --> {}'.format(breakpoints, fcache[breakpoints]))
    return fcache[breakpoints]

def find_best_piecewise_polynomial(breakpoints, x, y):
    breakpoints = tuple(map(int, sorted(breakpoints)))
    xs = np.split(x, breakpoints)
    ys = np.split(y, breakpoints)
    result = []
    for xi, yi in zip(xs, ys):
        if len(xi) < 2: continue
        coefs = npoly.polyfit(xi, yi, 1)
        f = npoly.Polynomial(coefs)
        result.append([f, xi, yi])
    return result

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
              18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 
              126.14, 140.03, 150, 152, 154, 156, 158])
# Add some noise to make it exciting :)
y += np.random.random(len(y))*10

num_breakpoints = 2
breakpoints = optimize.brute(
    f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None)

plt.scatter(x, y, c='blue', s=50)
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
    x_interval = np.array([xi.min(), xi.max()])
    print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval))
    plt.plot(x_interval, f(x_interval), 'ro-')


plt.show()

prints印刷

y = poly([ 4.58801083  2.94476604])    , if x in [1.0, 6.0]
y = poly([-70.36472935  14.37305793])  , if x in [7.0, 15.0]
y = poly([ 123.24565235    1.94982153]), if x in [16.0, 20.0]

and plots和情节

在此处输入图片说明

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM