
sklearn LogisticRegression and changing the default threshold for classification

I am using LogisticRegression from the sklearn package, and have a quick question about classification. I built a ROC curve for my classifier, and it turns out that the optimal threshold for my training data is around 0.25. I'm assuming that the default threshold when creating predictions is 0.5. How can I change this default setting to find out what the accuracy is in my model when doing a 10-fold cross-validation? Basically, I want my model to predict a '1' for anyone greater than 0.25, not 0.5. I've been looking through all the documentation, and I can't seem to get anywhere.

I would like to give a practical answer:

import numpy as np
import pandas as pd

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, roc_auc_score, precision_score

X, y = make_classification(
    n_classes=2, class_sep=1.5, weights=[0.9, 0.1],
    n_features=20, n_samples=1000, random_state=10
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

clf = LogisticRegression(class_weight="balanced")
clf.fit(X_train, y_train)
THRESHOLD = 0.25
preds = np.where(clf.predict_proba(X_test)[:,1] > THRESHOLD, 1, 0)

pd.DataFrame(data=[accuracy_score(y_test, preds), recall_score(y_test, preds),
                   precision_score(y_test, preds), roc_auc_score(y_test, preds)], 
             index=["accuracy", "recall", "precision", "roc_auc_score"])

By changing the THRESHOLD to 0.25, one can see that the recall and precision scores decrease. However, removing the class_weight argument increases the accuracy but lowers the recall score. Refer to the accepted answer for more background.

That is not a built-in feature. You can "add" it by wrapping the LogisticRegression class in your own class, and adding a threshold attribute which you use inside a custom predict() method.
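
A minimal sketch of that wrapping idea, assuming you only need fit/predict; the class name ThresholdedLogisticRegression and the 0.25 default are illustrative, not part of scikit-learn:

from sklearn.linear_model import LogisticRegression

class ThresholdedLogisticRegression:
    """Wraps LogisticRegression and applies a custom probability threshold in predict()."""

    def __init__(self, threshold=0.25, **lr_kwargs):
        self.threshold = threshold
        self.model = LogisticRegression(**lr_kwargs)

    def fit(self, X, y):
        self.model.fit(X, y)
        return self

    def predict(self, X):
        # Predict the positive class whenever P(y=1) exceeds the custom threshold,
        # instead of the 0.5 implied by LogisticRegression.predict()
        return (self.model.predict_proba(X)[:, 1] >= self.threshold).astype(int)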

However, some cautions:

  1. The default threshold is actually 0. LogisticRegression.decision_function() returns a signed distance to the selected separation hyperplane. If you are looking at predict_proba(), then you are looking at the logistic sigmoid of the hyperplane distance, with a threshold of 0.5. But that's more expensive to compute.
  2. By selecting the "optimal" threshold like this, you are utilizing information post-learning, which spoils your test set (i.e., your test or validation set no longer provides an unbiased estimate of out-of-sample error). You may therefore be inducing additional over-fitting unless you choose the threshold inside a cross-validation loop on your training set only, then use it and the trained classifier with your test set (see the sketch after this list).
  3. Consider using class_weight if you have an unbalanced problem rather than manually setting the threshold. This should force the classifier to choose a hyperplane farther away from the class of serious interest.
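
A hedged sketch of point 2, reusing the X_train/X_test split from the first answer: the threshold is chosen from out-of-fold probabilities on the training data only (10-fold, as in the question) and then applied once to the test set. The candidate grid and the F1 criterion are illustrative choices, not a prescription.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import f1_score

clf = LogisticRegression(class_weight="balanced")

# Out-of-fold probabilities on the training set only
oof_proba = cross_val_predict(clf, X_train, y_train, cv=10, method="predict_proba")[:, 1]

# Pick the candidate threshold that maximizes F1 on those out-of-fold predictions
grid = np.linspace(0.05, 0.95, 19)
best_t = max(grid, key=lambda t: f1_score(y_train, (oof_proba >= t).astype(int)))

# Refit on the full training set and evaluate the chosen threshold once on the test set
clf.fit(X_train, y_train)
test_preds = (clf.predict_proba(X_test)[:, 1] >= best_t).astype(int)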

You can change the threshold, but it is set at 0.5 so that the probability calculations are consistent. If you have an unbalanced set, the classification looks like the figure below.

[figure: classification of the unbalanced set with the default 0.5 threshold]

You can see that class 1 was predicted very poorly. Class 1 accounted for 2% of the population. After balancing the target variable 50/50 (using oversampling), the 0.5 threshold moved to the center of the chart.

[figure: the same chart after balancing the classes 50/50]
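
The answer does not show its oversampling code; a minimal sketch with sklearn.utils.resample, assuming a train split X_train/y_train where class 1 is the minority, might look like this:

import numpy as np
from sklearn.utils import resample
from sklearn.linear_model import LogisticRegression

minority = (y_train == 1)

# Resample the minority class with replacement until it matches the majority size
X_min_up, y_min_up = resample(
    X_train[minority], y_train[minority],
    replace=True, n_samples=int((~minority).sum()), random_state=0
)

X_bal = np.vstack([X_train[~minority], X_min_up])
y_bal = np.concatenate([y_train[~minority], y_min_up])

# With a roughly 50/50 training set, the default 0.5 cut-off is no longer
# biased towards the majority class
clf_bal = LogisticRegression().fit(X_bal, y_bal)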

Special case: one-dimensional logistic regression

The value separating the region where a sample X is labeled as 1 from the region where it is labeled 0 is calculated using the formula:

from scipy.special import logit
thresh = 0.1
val = (logit(thresh)-clf.intercept_)/clf.coef_[0]

Thus, the predictions can be calculated more directly with

preds = np.where(X>val, 1, 0)
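
A self-contained check of this formula on synthetic one-feature data (all names below are illustrative):

import numpy as np
from scipy.special import logit
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 1))                      # a single feature
y = (X[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)

clf = LogisticRegression().fit(X, y)

thresh = 0.1
val = (logit(thresh) - clf.intercept_) / clf.coef_[0]

# Thresholding the raw feature reproduces thresholding the predicted probability
preds_direct = np.where(X[:, 0] > val, 1, 0)
preds_proba = (clf.predict_proba(X)[:, 1] > thresh).astype(int)
assert (preds_direct == preds_proba).all()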

For the sake of completeness, I would like to mention another way to elegantly generate predictions based on scikit's probability computations, using binarize:

import numpy as np
from sklearn.preprocessing import binarize

THRESHOLD = 0.25

# These probabilities would come from logistic_regression.predict_proba()
y_logistic_prob = np.random.uniform(size=10)

predictions = binarize(y_logistic_prob.reshape(-1, 1), threshold=THRESHOLD).ravel()

Furthermore, I agree with the considerations that Andreus makes, especially 2 and 3. Be sure to keep an eye on them.

import numpy as np

def find_best_threshold(thresholds, fpr, tpr):
    # tpr*(1-fpr) is maximal when the false positive rate is very low and the true positive rate is very high
    t = thresholds[np.argmax(tpr * (1 - fpr))]
    print("the maximum value of tpr*(1-fpr)", max(tpr * (1 - fpr)), "for threshold", np.round(t, 3))
    return t

This function can be used if you want to find the threshold that gives the best trade-off between the true positive rate and the false positive rate, as shown below.
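
A usage sketch, assuming the clf, X_test and y_test from the first answer; roc_curve provides the fpr, tpr and thresholds arrays the function expects:

from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_test, clf.predict_proba(X_test)[:, 1])
best_t = find_best_threshold(thresholds, fpr, tpr)
preds = (clf.predict_proba(X_test)[:, 1] >= best_t).astype(int)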

OK, as far as my algorithm goes:

threshold = 0.1
LR_Grid_ytest_THR = ((model.predict_proba(Xtest)[:, 1])>= threshold).astype(int)

and:

from sklearn.metrics import classification_report

print('Valuation for test data only:')
print(classification_report(ytest, model.predict(Xtest)))
print("----------------------------------------------------------------------")
print('Valuation for test data only (new_threshold):')
print(classification_report(ytest, LR_Grid_ytest_THR))

