
Sub-quadratic algorithm for fitting a curve with two lines

The problem is to find the best fit of a real-valued 2D curve (given by a set of points) with a polyline consisting of two lines.

A brute-force approach would be to compute the "left" and "right" linear fits for each point of the curve and pick the pair with minimum error. I can calculate the two linear fits incrementally while iterating through the points of the curve, but I can't find a way to calculate the error incrementally. Thus this approach has quadratic complexity.
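For reference, the quadratic baseline described above can be sketched as follows. This is only an illustration, assuming NumPy is available; the function name `brute_force_two_lines` is made up here, and `np.polyfit` stands in for whatever single-line fit is used.

```python
import numpy as np

def brute_force_two_lines(x, y):
    """Try every split and refit both sides from scratch: O(N) per split, O(N^2) total."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    best = None
    # k = number of points in the left part; keep at least 2 points on each side
    for k in range(2, len(x) - 1):
        err = 0.0
        for xs, ys in ((x[:k], y[:k]), (x[k:], y[k:])):
            m, b = np.polyfit(xs, ys, 1)            # least-squares line for this side
            err += float(np.sum((ys - (m * xs + b)) ** 2))
        if best is None or err < best[0]:
            best = (err, k)
    return best  # (total squared error, size of the left part)
```

It is slow, but handy as a correctness check for any faster variant.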

The question is whether there is an algorithm with sub-quadratic complexity.

The second question is whether there is a handy C++ library for such algorithms.


EDIT: For fitting with a single line, there are formulas:

$$m = \frac{\sum x_i y_i - \sum x_i \sum y_i / N}{\sum x_i^2 - \left(\sum x_i\right)^2 / N}$$
$$b = \frac{\sum y_i}{N} - m \, \frac{\sum x_i}{N}$$

where m is the slope and b is the offset of the line. Having a formula like this for the fit error would solve the problem in the best way.
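As a quick illustration of these formulas, here is a small sketch (assuming NumPy; the helper name `line_from_sums` and the sample data are invented for the example) that computes m and b from the running sums only:

```python
import numpy as np

def line_from_sums(n, sx, sy, sxx, sxy):
    """Slope and offset of the least-squares line, given only the sums."""
    m = (sxy - sx * sy / n) / (sxx - sx * sx / n)
    b = sy / n - m * sx / n
    return m, b

# made-up sample data, roughly y = 2x + 1
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([1.0, 2.9, 5.2, 7.1, 8.8])
m, b = line_from_sums(len(x), x.sum(), y.sum(), (x * x).sum(), (x * y).sum())
```

Because only sums are needed, adding or removing one point changes the inputs in O(1).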

Disclaimer: I don't feel like figuring out how to do this in C++, so I will use Python (numpy) notation. The concepts are completely transferable, so you should have no trouble translating back to the language of your choice.

Let's say that you have a pair of arrays, x and y, containing the data points, and that x is monotonically increasing. Let's also say that you will always select a partition point that leaves at least two elements in each partition, so the equations are solvable.

Now you can compute some relevant quantities:

N = len(x)

sum_x_left = x[0]
sum_x2_left = x[0] * x[0]
sum_y_left = y[0]
sum_y2_left = y[0] * y[0]
sum_xy_left = x[0] * y[0]

sum_x_right = x[1:].sum()
sum_x2_right = (x[1:] * x[1:]).sum()
sum_y_right = y[1:].sum()
sum_y2_right = (y[1:] * y[1:]).sum()
sum_xy_right = (x[1:] * y[1:]).sum()

The reason that we need these quantities (which are O(N) to initialize) is that you can use them directly to compute some well-known formulae for the parameters of a linear regression. For example, the optimal m and b for y = m * x + b are given by

$$\mu_x = \frac{1}{N} \sum x_i$$
$$\mu_y = \frac{1}{N} \sum y_i$$
$$m = \frac{\sum (x_i - \mu_x)(y_i - \mu_y)}{\sum (x_i - \mu_x)^2}$$
$$b = \mu_y - m \, \mu_x$$

The sum of squared errors is given by

$$e = \sum (y_i - m \, x_i - b)^2$$

These can be expanded using simple algebra into the following:

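$$m = \frac{\sum x_i y_i - \sum x_i \sum y_i / N}{\sum x_i^2 - \left(\sum x_i\right)^2 / N}$$
$$b = \frac{\sum y_i}{N} - m \, \frac{\sum x_i}{N}$$
$$e = \sum y_i^2 + m^2 \sum x_i^2 + N b^2 - 2 m \sum x_i y_i - 2 b \sum y_i + 2 m b \sum x_i$$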
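Expanding the square term by term gives three squared terms and three cross terms, each cross term carrying a factor of 2. A short sanity check of the closed form against the directly computed residuals might look like this (a sketch assuming NumPy; `sse_from_sums` and the random test data are illustrative only):

```python
import numpy as np

def sse_from_sums(n, sx, sy, sxx, syy, sxy):
    """Least-squares slope, offset and sum of squared errors from the sums alone."""
    m = (sxy - sx * sy / n) / (sxx - sx * sx / n)
    b = sy / n - m * sx / n
    # expansion of sum((y - m*x - b)^2); the cross terms carry a factor of 2
    e = syy + m * m * sxx + n * b * b - 2 * m * sxy - 2 * b * sy + 2 * m * b * sx
    return m, b, e

rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0.0, 10.0, size=20))
y = 1.5 * x - 2.0 + rng.normal(scale=0.5, size=20)

m, b, e = sse_from_sums(len(x), x.sum(), y.sum(),
                        (x * x).sum(), (y * y).sum(), (x * y).sum())
e_direct = float(np.sum((y - (m * x + b)) ** 2))  # residuals computed point by point
```

Note that the closed form subtracts quantities of similar size, so for data with large offsets it can lose precision compared to the direct sum.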

You can therefore loop over all the possibilities and record the minimal e:

for p in range(1, N - 2):  # p is the last index of the left part; the right part keeps >= 2 points
    # shift sums: O(1)
    sum_x_left += x[p]
    sum_x2_left += x[p] * x[p]
    sum_y_left += y[p]
    sum_y2_left += y[p] * y[p]
    sum_xy_left += x[p] * y[p]

    sum_x_right -= x[p]
    sum_x2_right -= x[p] * x[p]
    sum_y_right -= y[p]
    sum_y2_right -= y[p] * y[p]
    sum_xy_right -= x[p] * y[p]

    # compute err: O(1)
    n_left = p + 1
    slope_left = (sum_xy_left - sum_x_left * sum_y_left / n_left) / (sum_x2_left - sum_x_left * sum_x_left / n_left)
    intercept_left = sum_y_left / n_left - slope_left * sum_x_left / n_left
    err_left = sum_y2_left + slope_left * slope_left * sum_x2_left + n_left * intercept_left * intercept_left - 2 * slope_left * sum_xy_left - 2 * intercept_left * sum_y_left + 2 * slope_left * intercept_left * sum_x_left

    n_right = N - n_left
    slope_right = (sum_xy_right - sum_x_right * sum_y_right / n_right) / (sum_x2_right - sum_x_right * sum_x_right / n_right)
    intercept_right = sum_y_right / n_right - slope_right * sum_x_right / n_right
    err_right = sum_y2_right + slope_right * slope_right * sum_x2_right + n_right * intercept_right * intercept_right - 2 * slope_right * sum_xy_right - 2 * intercept_right * sum_y_right + 2 * slope_right * intercept_right * sum_x_right

    # keep the best split seen so far (the first iteration initializes the minimum)
    err = err_left + err_right
    if p == 1 or err < err_min:
        err_min = err
        n_min_left = n_left
        n_min_right = n_right
        slope_min_left = slope_left
        slope_min_right = slope_right
        intercept_min_left = intercept_left
        intercept_min_right = intercept_right

There are probably other simplifications you can make, but this is sufficient to have an O(n) algorithm.
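The same idea can also be written with prefix sums instead of an explicit loop. The following is only a sketch under the assumptions stated above (x strictly increasing, at least two points per side); the function name `fit_two_lines` and its return layout are choices made for this example, not part of the answer.

```python
import numpy as np

def fit_two_lines(x, y):
    """O(N) two-line fit using prefix sums of x, y, x^2, y^2 and x*y."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    if n < 4:
        raise ValueError("need at least 4 points (2 per side)")

    # column k-1 of `prefix` holds the five sums over the first k points
    prefix = np.cumsum(np.stack([x, y, x * x, y * y, x * y]), axis=1)
    total = prefix[:, -1:]

    k = np.arange(2, n - 1)   # candidate left-part sizes: 2 .. n-2
    left = prefix[:, k - 1]   # sums over the first k points
    right = total - left      # sums over the remaining n-k points

    def fit(cnt, sums):
        sx, sy, sxx, syy, sxy = sums
        m = (sxy - sx * sy / cnt) / (sxx - sx * sx / cnt)
        b = (sy - m * sx) / cnt
        e = syy + m * m * sxx + cnt * b * b - 2 * m * sxy - 2 * b * sy + 2 * m * b * sx
        return m, b, e

    ml, bl, el = fit(k, left)
    mr, br, er = fit(n - k, right)
    i = int(np.argmin(el + er))   # index of the split with the smallest total error
    return (int(k[i]), (float(ml[i]), float(bl[i])),
            (float(mr[i]), float(br[i])), float(el[i] + er[i]))
```

For example, `fit_two_lines(x, y)` returns the size of the left part, the (slope, intercept) pair of each line, and the total squared error. Because the error comes from differences of large sums, tiny negative values are possible for near-perfect fits.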

In case it helps, here's some C code that I've used for this sort of thing. It adds little to what Mad Physicist said.

First off, a formula. If you fit a line y^: x->a*x+b through some points, then the error is given by:

E = Sum{ sqr(y[i]-y^(x[i])) }/ N = Vy - Cxy*Cxy/Vx
where 
Vx is the variance of the xs
Vy that of the ys 
Cxy the covariance of the xs and the ys
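This identity is easy to check numerically. The snippet below is an illustrative sketch (assuming NumPy, with made-up random data) that compares the mean squared residual of the fitted line with Vy - Cxy*Cxy/Vx, using population (divide-by-N) variances and covariance:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=50)
y = 0.7 * x + 0.1 * rng.normal(size=50)

Vx = np.var(x)                                    # population variance (ddof=0)
Vy = np.var(y)
Cxy = np.mean((x - x.mean()) * (y - y.mean()))    # population covariance

a = Cxy / Vx                                      # slope of the fitted line
b = y.mean() - a * x.mean()                       # offset of the fitted line

mse_direct = np.mean((y - (a * x + b)) ** 2)      # Sum{ sqr(y[i] - y^(x[i])) } / N
mse_moments = Vy - Cxy * Cxy / Vx                 # same quantity from the moments
```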

The code below uses a structure that holds the means, the variances, the covariance and the count.

The function moms_acc_pt() updates these when you add a new point. The function moms_line() returns a and b for the line, and the error as above. The fmax(0,) on the return is for near-perfect fits, where rounding error could send the (mathematically non-negative) result negative.

While it would be possible to have a function that removes a point from a momentsT, I find it easier to decide which momentsT to add a point to by taking copies, accumulating the point in the copies, getting the lines, and then keeping the copy for the side the point fits best in and the original for the other.

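#include <math.h>   // for fmax

// zero-initialize a momentsT (e.g. momentsT M = {0};) before accumulating points
typedef struct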
{   int n;      // number points
    double  xbar,ybar;  // means of x,y
    double  Vx, Vy;     // variances of x,y
    double  Cxy;        // covariance of x,y
}   momentsT;

// update the moments to include the point x,y
void    moms_acc_pt( momentsT* M, double x, double y)
{   M->n += 1;
double  f = 1.0/M->n;
double  dx = x-M->xbar;
double  dy = y-M->ybar;
    M->xbar += f*dx;
    M->ybar += f*dy;
double  g = 1.0 - f;
    M->Vx   = g*(M->Vx  + f*dx*dx);
    M->Cxy  = g*(M->Cxy + f*dx*dy);
    M->Vy   = g*(M->Vy  + f*dy*dy);
}

// return the moments for the combination of A and B (assumed disjoint)
momentsT    moms_combine( const momentsT* A, const momentsT* B)
{
momentsT    C;
    C.n = A->n + B->n;
double  alpha = (double)A->n/(double)C.n;
double  beta = (double)B->n/(double)C.n;
    C.xbar = alpha*A->xbar + beta*B->xbar;
    C.ybar = alpha*A->ybar + beta*B->ybar;
double  dx = A->xbar - B->xbar;
double  dy = A->ybar - B->ybar;
    C.Vx = alpha*A->Vx + beta*B->Vx + alpha*beta*dx*dx;
    C.Cxy= alpha*A->Cxy+ beta*B->Cxy+ alpha*beta*dx*dy;
    C.Vy = alpha*A->Vy + beta*B->Vy + alpha*beta*dy*dy;
    return C;
}

// line is y^ : x -> a*x + b; return Sum{ sqr( y[i] - y^(x[i])) }/N
double  moms_line( momentsT* M, double* a, double *b)
{   *a = M->Cxy/M->Vx;
    *b = M->ybar - *a*M->xbar;
    return fmax( 0.0, M->Vy - *a*M->Cxy);
}
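One possible way to apply this accumulator to the two-line problem without a removal function is to run it once forward and once backward, storing the error of every prefix and every suffix. The sketch below is a Python transliteration for illustration only; the class `Moments` and the function `best_split` are names invented here, not part of the C code above.

```python
class Moments:
    """Running means, population variances and covariance (mirrors momentsT)."""
    def __init__(self):
        self.n = 0
        self.xbar = self.ybar = 0.0
        self.vx = self.vy = self.cxy = 0.0

    def add(self, x, y):
        # same update as moms_acc_pt()
        self.n += 1
        f = 1.0 / self.n
        dx = x - self.xbar
        dy = y - self.ybar
        self.xbar += f * dx
        self.ybar += f * dy
        g = 1.0 - f
        self.vx = g * (self.vx + f * dx * dx)
        self.cxy = g * (self.cxy + f * dx * dy)
        self.vy = g * (self.vy + f * dy * dy)

    def sse(self):
        # n times the mean squared error returned by moms_line()
        return self.n * max(0.0, self.vy - self.cxy * self.cxy / self.vx)

def best_split(x, y):
    """Return (k, error): k points go to the left line, the rest to the right."""
    n = len(x)
    left_sse = [0.0] * (n + 1)    # left_sse[k]: error of a line through the first k points
    right_sse = [0.0] * (n + 1)   # right_sse[k]: error of a line through points k..n-1
    m = Moments()
    for i in range(n):
        m.add(x[i], y[i])
        if m.n >= 2:
            left_sse[i + 1] = m.sse()
    m = Moments()
    for j in range(n - 1, -1, -1):
        m.add(x[j], y[j])
        if m.n >= 2:
            right_sse[j] = m.sse()
    # at least 2 points per side: k runs over 2 .. n-2
    k = min(range(2, n - 1), key=lambda s: left_sse[s] + right_sse[s])
    return k, left_sse[k] + right_sse[k]
```

Both passes are O(N), and the update avoids the large-sum cancellation of the raw-sums formulation. Like the C code, it assumes the x values in each part are not all equal, so the variance of x is non-zero.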
