
numpy best fit line with outliers

I have a scatter plot of data that mostly fits a line, but with some outliers. I've been using numpy's polyfit to fit a line to the data, but it picks up the outliers and gives me the wrong line:

[Image: line fit error]

Is there a function that will give me the line that best fits the bulk of the data, rather than a line fitted to every data point?

Code to reproduce:

from numpy.polynomial.polynomial import polyfit
import numpy as np
from matplotlib import pyplot as plt


y = np.array([72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 31, 31, 32, 32, 32, 32, 32, 39, 33, 33, 40, 41, 41, 41, 42, 42, 42, 42, 42, 43, 44, 44, 45, 46, 46, 46, 47, 47, 48, 48, 48, 49, 49, 49, 50, 51, 51, 52, 54, 54, 55, 55, 55, 56, 56, 56, 56, 56, 56, 56, 57, 56, 56, 56, 56, 58, 59, 59, 61, 64, 63, 64, 64, 64, 64, 64, 64, 65, 65, 65, 66, 73, 73, 69, 72, 72, 71, 71, 71, 72, 72, 72, 72, 72, 72, 72, 74, 74, 73, 77, 78, 78, 78, 78, 78, 79, 79, 79, 80, 80, 80, 80, 80, 80, 81, 81, 82, 84, 85, 85, 86, 86, 88, 88, 88, 88, 88, 88, 88, 88, 88, 89, 90, 90, 90, 90, 91, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 97, 97, 97, 97, 98, 99, 100, 103, 103, 104, 104, 104, 104, 104, 104, 104, 104, 104, 105, 105, 105, 106, 106, 106, 108, 107, 110, 111, 111, 111, 112, 112, 112, 112, 113, 113, 113, 113, 114, 114, 114, 115, 116, 119, 119, 119, 119, 119, 120, 119, 120, 120, 120, 120, 120, 120, 121, 122, 123, 124, 126, 126, 127, 127, 127, 127, 128, 128, 128, 129, 129, 129, 129, 129, 130, 130, 131, 133, 134, 135, 133, 135, 135, 136, 136, 136, 136, 136, 136, 136, 137, 136, 137, 138, 138, 138, 140, 141, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 145, 145, 146, 147, 147, 148, 150, 151, 150, 151, 151, 152, 152, 152, 152, 152, 152, 152, 153, 153, 153, 154, 155, 157, 158, 158, 159, 159, 159, 159])

x = np.array([25, 26, 28, 29, 35, 36, 38, 39, 42, 43, 44, 45, 46, 50, 79, 223, 224, 226, 227, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507])

# Fit with polyfit
b, m = polyfit(x, y, 1)


_ = plt.plot(x, y, 'o', label='Original data', markersize=2)
_ = plt.plot(x, m*x + b, 'r', label='Fitted line')
_ = plt.legend()
plt.show()

For the curious, I'm attempting ground plane estimation with disparity maps.

You can fit a linear model with the Huber loss, which is robust to outliers.

Full example using scikit-learn:

import numpy as np
from matplotlib import pyplot as plt
from sklearn.linear_model import HuberRegressor
from sklearn.preprocessing import StandardScaler

y = np.array([72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 31, 31, 32, 32, 32, 32, 32, 39, 33, 33, 40, 41, 41, 41, 42, 42, 42, 42, 42, 43, 44, 44, 45, 46, 46, 46, 47, 47, 48, 48, 48, 49, 49, 49, 50, 51, 51, 52, 54, 54, 55, 55, 55, 56, 56, 56, 56, 56, 56, 56, 57, 56, 56, 56, 56, 58, 59, 59, 61, 64, 63, 64, 64, 64, 64, 64, 64, 65, 65, 65, 66, 73, 73, 69, 72, 72, 71, 71, 71, 72, 72, 72, 72, 72, 72, 72, 74, 74, 73, 77, 78, 78, 78, 78, 78, 79, 79, 79, 80, 80, 80, 80, 80, 80, 81, 81, 82, 84, 85, 85, 86, 86, 88, 88, 88, 88, 88, 88, 88, 88, 88, 89, 90, 90, 90, 90, 91, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 97, 97, 97, 97, 98, 99, 100, 103, 103, 104, 104, 104, 104, 104, 104, 104, 104, 104, 105, 105, 105, 106, 106, 106, 108, 107, 110, 111, 111, 111, 112, 112, 112, 112, 113, 113, 113, 113, 114, 114, 114, 115, 116, 119, 119, 119, 119, 119, 120, 119, 120, 120, 120, 120, 120, 120, 121, 122, 123, 124, 126, 126, 127, 127, 127, 127, 128, 128, 128, 129, 129, 129, 129, 129, 130, 130, 131, 133, 134, 135, 133, 135, 135, 136, 136, 136, 136, 136, 136, 136, 137, 136, 137, 138, 138, 138, 140, 141, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 145, 145, 146, 147, 147, 148, 150, 151, 150, 151, 151, 152, 152, 152, 152, 152, 152, 152, 153, 153, 153, 154, 155, 157, 158, 158, 159, 159, 159, 159])
x = np.array([25, 26, 28, 29, 35, 36, 38, 39, 42, 43, 44, 45, 46, 50, 79, 223, 224, 226, 227, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507])

# standardize    
x_scaler, y_scaler = StandardScaler(), StandardScaler()
x_train = x_scaler.fit_transform(x[..., None])
y_train = y_scaler.fit_transform(y[..., None])

# fit model
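# epsilon controls robustness: 1.35 is the default, and values closer to the
# minimum of 1.0 down-weight outliers more aggressively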
model = HuberRegressor(epsilon=1)
model.fit(x_train, y_train.ravel())

# do some predictions
test_x = np.array([25, 600])
predictions = y_scaler.inverse_transform(
    model.predict(x_scaler.transform(test_x[..., None]))[..., None]
).ravel()

# plot
plt.scatter(x, y)
plt.plot(test_x, predictions, 'r')
plt.ylim(0, 200)
plt.xlim(0, 550)
plt.savefig('aa.png')
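
As a side note (my addition, not part of the original answer), if you also want the slope and intercept in the original data units, comparable to the polyfit output, you can undo the standardization by hand. A minimal sketch, assuming the model, x_scaler and y_scaler fitted above:

# recover y ≈ slope * x + intercept in the original (unscaled) units
m_s, b_s = model.coef_[0], model.intercept_   # line in standardized space
sx, mx = x_scaler.scale_[0], x_scaler.mean_[0]
sy, my = y_scaler.scale_[0], y_scaler.mean_[0]

slope = m_s * sy / sx
intercept = my + sy * b_s - slope * mx
print(slope, intercept)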

Result:

[Image: line fitted with the Huber regressor]

I also suggest that you not follow the other answer, as it does not always work. In the following example it would not remove any points, resulting in the green line. The solution above returns the red line, as expected.

[Image: comparison plot; the other answer gives the green line, the Huber fit gives the red line]

If the residuals are approximately normally distributed, you can filter outliers based on the Z-score, which is defined as:

z = (x - mean)/std

For example:
Convert your data to a DataFrame:

import pandas as pd
from scipy import stats
df = pd.DataFrame(zip(y, x))
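# column 0 holds y and column 1 holds x; mind this order when fitting below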

Then filter the outliers based on the column mean and standard deviation:

df = df[(np.abs(stats.zscore(df)) < 2.5).all(axis=1)]

Usually a point is considered an outlier when the absolute value of its Z-score is > 3, but here you keep only the points with abs(Z-score) < 2.5.

# Fit with polyfit
b, m = polyfit(df[1], df[0], 1)


_ = plt.plot(df[1], df[0], 'o', label='Original data', markersize=2)
_ = plt.plot(df[1], m*df[1] + b, 'r', label='Fitted line')
_ = plt.legend()
plt.show()

Result:
[Image: line fitted after Z-score filtering]

I found this Z-score filtering method here: Detect and exclude outliers in Pandas data frame
Edit: Please note that this approach has limitations, since it is a univariate outlier-detection method, that is, it only considers one variable at a time. Besides, it is very sensitive to extreme outliers, because they shift the mean of the sample and, consequently, the Z-score. A workaround could be the Robust Z-score method, which incorporates the Median Absolute Deviation (MAD) Z-score.
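
For reference, a minimal sketch of that robust variant (my own addition, not part of the original answer): it replaces the mean and standard deviation with the median and the MAD, using the usual 0.6745 consistency constant, and then applies the same 2.5 cutoff. It assumes the df built above, with column 0 holding y and column 1 holding x.

# modified (robust) Z-score based on the median and the MAD
def robust_zscore(data):
    med = data.median()
    mad = (data - med).abs().median()
    # 0.6745 makes the MAD comparable to the std for normally distributed data
    return 0.6745 * (data - med) / mad

df_robust = df[(robust_zscore(df).abs() < 2.5).all(axis=1)]

# Fit with polyfit on the filtered data
b, m = polyfit(df_robust[1], df_robust[0], 1)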
Articles:
https://medium.com/james-blogs/outliers-make-us-go-mad-univariate-outlier-detection-b3a72f1ea8c7
https://www.itl.nist.gov/div898/handbook/eda/section3/eda35h.htm
