
How to change the threshold on a decision tree classifier model?

Is it possible to change the threshold of a DecisionTreeClassifier? I'm studying the precision/recall trade-off and want to change the threshold to favor recall. I'm following Hands-On ML, but there the book uses an SGDClassifier and at some point calls cross_val_predict() with method="decision_function", which does not exist for DecisionTreeClassifier. I'm using a pipeline and cross-validation. My study uses this dataset: https://www.kaggle.com/datasets/imnikhilanand/heart-attack-prediction

features = df_heart.drop(['output'], axis=1).copy()
labels = df_heart.output

#split
X_train, X_test, y_train, y_test= train_test_split(features, labels,
                                train_size=0.7,
                                random_state=42,
                                stratify=features["sex"]
                               )
# categorical features
cat = ['sex', 'tipo_de_dor', 'ang_indz_exerc', 'num_vasos', 'acuc_sang_jejum', 'eletrc_desc', 'pico_ST_exerc', 'talassemia']

# treatment of categorical variables
t = [('cat', OneHotEncoder(handle_unknown='ignore'), cat)]

preprocessor = ColumnTransformer(transformers=t, remainder='passthrough')

#pipeline
pipe = Pipeline(steps=[('preprocessor', preprocessor),
                       ('clf', DecisionTreeClassifier(min_samples_leaf=8, random_state=42),)
                       ]
                )

pipe.fit(X_train, y_train)

valid_cruz_strat = StratifiedKFold(n_splits=14, shuffle=True, random_state=42)

y_train_pred = cross_val_predict(pipe['clf'], X_train, y_train, cv=valid_cruz_strat)

conf_mat = confusion_matrix(y_train, y_train_pred)


ConfusionMatrixDisplay(confusion_matrix=conf_mat, 
                       display_labels=pipe['clf'].classes_).plot()
plt.grid(False)
plt.show()

[image: confusion matrix]

threshold = 0  # this is only to support the graph
idx = (thresholds >= threshold).argmax()  # first index ≥ threshold

plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')
plt.plot(thresholds, recalls[:-1], 'g-', label='Recall')
plt.vlines(threshold, 0, 1.0, "k", "dotted", label="threshold")
plt.title('Precision x Recall', fontsize=14)

plt.plot(thresholds[idx], precisions[idx], "bo")
plt.plot(thresholds[idx], recalls[idx], "go")
plt.axis([-.5, 1.5, 0, 1.1])
plt.grid()
plt.xlabel("Threshold")
plt.legend(loc="lower left")

plt.show()

[image: precision and recall versus threshold]

valid_cruz_strat = StratifiedKFold(n_splits=14, shuffle=True, random_state=42)

y_score = cross_val_predict(pipe['clf'], X_train, y_train, cv=valid_cruz_strat)

precisions, recalls, thresholds = precision_recall_curve(y_train, y_score)

threshold = 0.75  # this is only to support the graph
idx = (thresholds >= threshold).argmax() 

plt.figure(figsize=(6, 5))  

plt.plot(recalls, precisions, linewidth=2, label="Precision/Recall curve")

plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], "k:")
plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], "k:")
plt.plot([recalls[idx]], [precisions[idx]], "ko",
         label="Point at threshold "+str(threshold))

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.axis([0, 1, 0, 1])
plt.grid()
plt.legend(loc="lower left")

plt.show()

[image: precision x recall curve]

When I check the arrays generated by the precision_recall_curve() function, I see that they contain only 3 elements. Is this correct behavior? When I run cross_val_predict() for an SGDClassifier as in the book, for example, but without the method='decision_function' argument, and feed the output to precision_recall_curve(), it also generates arrays with 3 elements; with method='decision_function' it generates arrays with many elements.
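
For concreteness, this is roughly the comparison I am describing (a sketch; preprocessor, X_train, y_train and valid_cruz_strat are the objects defined above, and the SGDClassifier runs on the encoded features):

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve

# SGDClassifier needs numeric input, so reuse the preprocessor from the pipeline
X_train_prep = preprocessor.fit_transform(X_train)
sgd = SGDClassifier(random_state=42)

# default method='predict' -> hard 0/1 labels -> only a handful of elements
y_pred = cross_val_predict(sgd, X_train_prep, y_train, cv=valid_cruz_strat)
print(precision_recall_curve(y_train, y_pred)[0].shape)

# method='decision_function' -> continuous scores -> many elements
y_scores = cross_val_predict(sgd, X_train_prep, y_train, cv=valid_cruz_strat,
                             method='decision_function')
print(precision_recall_curve(y_train, y_scores)[0].shape)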

My main question is how to choose the threshold for the DecisionTreeClassifier, and whether there is a way to generate the precision x recall curve with more points; I only get these three points and I can't work out how to improve the recall.

Move the threshold to improve recall, and understand how to do it with a decision tree classifier.

This topic usually falls under the name "model calibration." scikit-learn supports a few kinds of probability calibration, which could be informative to read about as well.
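
For reference, a minimal sketch of what that calibration API looks like, using the same toy data as the example below (the estimator and parameter values here are only illustrative):

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=10000, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit the tree on training folds and learn a mapping from its raw scores
# to calibrated probabilities (sigmoid/Platt or isotonic regression).
calibrated = CalibratedClassifierCV(
    DecisionTreeClassifier(max_depth=5),
    method="isotonic",  # or "sigmoid"
    cv=5,
)
calibrated.fit(X_train, y_train)
calibrated_proba = calibrated.predict_proba(X_test)[:, 1]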

One way to "change the threshold" in a DecisionTreeClassifier would involve invoking .predict_proba(X) and observing one or more metrics over possible thresholds:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score
import numpy as np
import matplotlib.pyplot as plt

X, y = make_classification(n_samples=10000, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

clf = DecisionTreeClassifier(max_depth=5)
clf.fit(X_train, y_train)

prob_pred = clf.predict_proba(X_test)[:, 1]

thresholds = np.arange(0.0, 1.0, step=0.01)
recall_scores = [recall_score(y_test, prob_pred > t) for t in thresholds]
precis_scores = [precision_score(y_test, prob_pred > t) for t in thresholds]

Now we have a set of thresholds between 0.0 and 1.0, and we've computed precision and recall over each threshold. (Side note: this problem is less well-defined for multilabel or multiclass prediction; usually these metrics are averaged over each class or similar.)

Then we'll plot:

fig, ax = plt.subplots(1, 1)
ax.plot(thresholds, recall_scores, label="Recall @ t")
ax.plot(thresholds, precis_scores, label="Precision @ t")
ax.axvline(0.5, c="gray", linestyle="--", label="Default Threshold")
ax.set_xlabel("Threshold")
ax.set_ylabel("Metric @ Threshold")
ax.set_box_aspect(1)
ax.legend()
plt.show()

Which results in a figure like this:

[image: line plot with thresholds from 0 to 1 on the x-axis and the metrics on the y-axis; low thresholds give high recall, high thresholds give low recall]

This shows us that the default threshold of 0.5 used by .predict() may not be the best in all circumstances. In fact, there is a range of thresholds where precision and recall are fairly close, but one is favored over the other. In this case, lowering the threshold slightly will tend to favor recall, while increasing the threshold will tend to favor precision.
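
Once a threshold has been chosen, applying it is just a comparison against the positive-class probability instead of calling .predict() (which for this classifier is effectively a 0.5 threshold). A short sketch, continuing with the clf, X_test and y_test from the snippet above and an illustrative threshold of 0.35:

custom_threshold = 0.35  # illustrative value chosen to favor recall

y_pred_default = clf.predict(X_test)  # implicit 0.5 threshold
y_pred_custom = (clf.predict_proba(X_test)[:, 1] >= custom_threshold).astype(int)

print("recall    @ 0.50:", recall_score(y_test, y_pred_default))
print("recall    @ 0.35:", recall_score(y_test, y_pred_custom))
print("precision @ 0.35:", precision_score(y_test, y_pred_custom))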

In practice, the threshold appropriate for the problem comes down to domain knowledge, since there is always a trade-off between precision and recall.
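
To tie this back to the cross-validated pipeline in the question: cross_val_predict() also accepts method='predict_proba', and feeding the positive-class column of its output to precision_recall_curve() gives one point per distinct predicted probability rather than only three. A sketch, reusing the pipe, X_train, y_train and valid_cruz_strat objects from the question:

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve

# out-of-fold probabilities for the positive class instead of hard labels
y_scores = cross_val_predict(pipe, X_train, y_train, cv=valid_cruz_strat,
                             method='predict_proba')[:, 1]

precisions, recalls, thresholds = precision_recall_curve(y_train, y_scores)
# thresholds now has one entry per distinct probability, so the
# precision/recall-versus-threshold plots above get many points.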
