繁体   English   中英

sklearn 中的自定义指标

[英]Custom metric in sklearn

我应该设计一个自定义指标,当应用于具有不同算法的 MNIST 时,其性能优于 L2。

from sklearn import neighbors

import utils
import math


# Extraction du dataset
x_train, y_train = utils.get_train_data()
x_test,  y_test  = utils.get_test_data()

def EuclideanDistance(x, y):
    return math.sqrt((y[0] - x[0]) ** 2 + (y[1] - x[1]) ** 2)

test_range = 10
test_results = []  # tableau d'enregistrements {nn: [uniform, distance]}

for k in range(test_range):  # will test all 'k' values from 2 to 'test_range + 1'
    n_neighbors = k+2
    print("\nTesting  k =", n_neighbors)
    error_rate = []
    for weights in ['uniform', 'distance']:
        knn_clf = neighbors.KNeighborsClassifier(n_neighbors,
                                                 metric=EuclideanDistance,
                                                 weights=weights)
        knn_clf.fit(x_train, y_train)
        predictions = knn_clf.predict(x_test)

        error_rate.append(utils.count_error_rate(predictions, y_test))

    test_results.append({n_neighbors: error_rate})

print("\nResults:", test_results)

这样做,我得到以下结果:

Testing  k = 2
Error rate =      91.58316633266533 %
Error rate =      91.58316633266533 %

Testing  k = 3
Error rate =      91.58316633266533 %
Error rate =      91.58316633266533 %

Testing  k = 4
Error rate =      91.58316633266533 %
Error rate =      91.58316633266533 %

...

这显然是错误的。 为什么我的自定义指标应用于不同的上下文会得到相同的输出?

尝试使您的EuclideanDistance函数与输入数据的长度无关(您的函数仅查看 2 个组件,而不是 MNIST 中的 784 个维度):

def EuclideanDistance(x, y):
    if len(x) != len(y):
        raise ValueError("x and y need to have the same length")
    return math.sqrt(sum([(y[i] - x[i]) ** 2 for i in range(len(x))]))

*编辑您对效率的评论

如果你去 pythons lib 文件夹 (/site-packages/sklearn/metrics/pairwise.py) 你可以看到自己是如何编写函数的。 然而,函数内部的注释指出:

出于效率原因,一对行之间的欧几里德距离
向量 x 和 y 计算如​​下:

 dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))

与其他计算距离的方法相比,此公式有两个优点。 首先,它在处理稀疏数据时计算效率高。 其次,如果一个参数发生变化而另一个参数保持不变,则可以预先计算dot(x, x)和/或dot(y, y)

然而,这不是进行此计算的最精确方法,并且此函数返回的距离矩阵可能不完全符合例如scipy.spatial.distance函数的要求。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM