![](/img/trans.png)
[英]How to get the same thresholds values for both functions precision_recall_curve and roc_curve in sklearn.metrics
[英]Plotting Threshold (precision_recall curve) matplotlib/sklearn.metrics
我正在嘗試 plot 我的精度/召回曲線的閾值。 我只是使用 MNSIT 數據,其中示例來自使用 scikit-learn、keras 和 TensorFlow 進行機器學習一書中的示例。 嘗試訓練 model 來檢測 5 的圖像。 我不知道你需要看多少代碼。 我已經為訓練集制作了混淆矩陣,並計算了精度和召回值以及閾值。 我已經繪制了 pre/rec 曲線,書中的示例說要添加軸 label、壁架、網格並突出顯示閾值,但代碼在我在下面放置星號的書中被切斷。 除了如何讓閾值顯示在 plot 上之外,我能夠弄清楚所有問題。 我已經附上了一張書中圖表與我所擁有的圖表的圖片。 這就是這本書所展示的:
我無法顯示帶有兩個閾值點的紅色虛線。 有誰知道我會怎么做? 下面是我的代碼:
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
def plot_precision_recall_vs_thresholds(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
plt.xlabel("Threshold")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.grid(b=True, which="both", axis="both", color='gray', linestyle='-', linewidth=1)
plot_precision_recall_vs_thresholds(precisions, recalls, thresholds)
plt.show()
我知道這里有很多關於 sklearn 的問題,但似乎沒有一個涵蓋讓紅線出現的問題。 我將非常感謝您的幫助!
您可以使用以下代碼繪制水平線和垂直線:
plt.axhline(y_value, c='r', ls=':')
plt.axvline(x_value, c='r', ls=':')
這應該以確切的方式工作:
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
recall_80_precision = recalls[np.argmax(precisions >= 0.80)]
threshold_80_precision = thresholds[np.argmax(precisions >= 0.80)]
plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
plt.xlabel("Threshold")
plt.plot([threshold_80_precision, threshold_80_precision], [0., 0.8], "r:")
plt.axis([-4, 4, 0, 1])
plt.plot([-4, threshold_80_precision], [0.8, 0.8], "r:")
plt.plot([-4, threshold_80_precision], [recall_80_precision, recall_80_precision], "r:")
plt.plot([threshold_80_precision], [0.8], "ro")
plt.plot([threshold_80_precision], [recall_80_precision], "ro")
plt.grid(True)
plt.legend()
plt.show()
我在嘗試復制本書中的代碼時遇到了這段代碼。 原來@ageron將所有資源都放在了他的 github 頁面上。 你可以在這里查看
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.