簡體   English   中英

繪制閾值(precision_recall 曲線)matplotlib/sklearn.metrics

[英]Plotting Threshold (precision_recall curve) matplotlib/sklearn.metrics

我正在嘗試 plot 我的精度/召回曲線的閾值。 我只是使用 MNSIT 數據,其中示例來自使用 scikit-learn、keras 和 TensorFlow 進行機器學習一書中的示例。 嘗試訓練 model 來檢測 5 的圖像。 我不知道你需要看多少代碼。 我已經為訓練集制作了混淆矩陣,並計算了精度和召回值以及閾值。 我已經繪制了 pre/rec 曲線,書中的示例說要添加軸 label、壁架、網格並突出顯示閾值,但代碼在我在下面放置星號的書中被切斷。 除了如何讓閾值顯示在 plot 上之外,我能夠弄清楚所有問題。 我已經附上了一張書中圖表與我所擁有的圖表的圖片。 這就是這本書所展示的:

在此處輸入圖像描述 與我的圖表:

在此處輸入圖像描述

我無法顯示帶有兩個閾值點的紅色虛線。 有誰知道我會怎么做? 下面是我的代碼:

from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_thresholds(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.grid(b=True, which="both", axis="both", color='gray', linestyle='-', linewidth=1)

plot_precision_recall_vs_thresholds(precisions, recalls, thresholds)
plt.show()

我知道這里有很多關於 sklearn 的問題,但似乎沒有一個涵蓋讓紅線出現的問題。 我將非常感謝您的幫助!

您可以使用以下代碼繪制水平線和垂直線:

plt.axhline(y_value, c='r', ls=':')
plt.axvline(x_value, c='r', ls=':')

這應該以確切的方式工作:

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    recall_80_precision = recalls[np.argmax(precisions >= 0.80)]
    threshold_80_precision = thresholds[np.argmax(precisions >= 0.80)]
    
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.xlabel("Threshold")
    plt.plot([threshold_80_precision, threshold_80_precision], [0., 0.8], "r:")
    plt.axis([-4, 4, 0, 1])
    plt.plot([-4, threshold_80_precision], [0.8, 0.8], "r:")
    plt.plot([-4, threshold_80_precision], [recall_80_precision, recall_80_precision], "r:")
    plt.plot([threshold_80_precision], [0.8], "ro") 
    plt.plot([threshold_80_precision], [recall_80_precision], "ro")
    plt.grid(True)
    plt.legend()
    plt.show()

我在嘗試復制本書中的代碼時遇到了這段代碼。 原來@ageron將所有資源都放在了他的 github 頁面上。 你可以在這里查看

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM