簡體   English   中英

在scikit-learn中使用交叉驗證時繪制Precision-Recall曲線

[英]Plotting Precision-Recall curve when using cross-validation in scikit-learn

我正在使用交叉驗證來評估帶有scikit-learn的分類器的性能,我想繪制Precision-Recall曲線。 我在scikit-learn網站上找到了一個繪制PR曲線的示例 ,但它沒有使用交叉驗證進行評估。

在使用交叉驗證時,如何繪制scikit中的Precision-Recall曲線?

我做了以下但我不確定這是否是正確的方法(psudo代碼):

for each k-fold:

   precision, recall, _ =  precision_recall_curve(y_test, probs)
   mean_precision += precision
   mean_recall += recall

mean_precision /= num_folds
mean_recall /= num_folds

plt.plot(recall, precision)

你怎么看?

編輯:

它不起作用,因為每次折疊后precisionrecall陣列的大小不同。

任何人?

而不是在每次折疊后記錄精確度和召回值,而是在每次折疊后將預測存儲在測試樣本上。 接下來, 收集所有測試(即袋外)預測並計算精度和召回率。

 ## let test_samples[k] = test samples for the kth fold (list of list)
 ## let train_samples[k] = test samples for the kth fold (list of list)

 for k in range(0, k):
      model = train(parameters, train_samples[k])
      predictions_fold[k] = predict(model, test_samples[k])

 # collect predictions
 predictions_combined = [p for preds in predictions_fold for p in preds]

 ## let predictions = rearranged predictions s.t. they are in the original order

 ## use predictions and labels to compute lists of TP, FP, FN
 ## use TP, FP, FN to compute precisions and recalls for one run of k-fold cross-validation

在單次,完整的k-fold交叉驗證運行中,預測器對每個樣本進行一次且僅一次預測。 給定n個樣本,您應該有n個測試預測。

(注意:這些預測與訓練預測不同,因為預測器會對每個樣本進行預測,而不會事先看到它。)

除非您使用留一法交叉驗證 ,否則k折交叉驗證通常需要對數據進行隨機分區。 理想情況下,您將進行重復 (和分層 )k折交叉驗證。 然而,組合來自不同輪次的精確回憶曲線並不是直截了當的,因為與ROC不同,您不能在精確回憶點之間使用簡單的線性插值(參見Davis和Goadrich 2006 )。

我親自使用Davis-Goadrich方法計算AUC-PR用於PR空間中的插值(隨后進行數值積分),並使用來自重復分層10倍交叉驗證的AUC-PR估計來比較分類器。

對於一個不錯的情節,我展示了一個交叉驗證輪次的代表性PR曲線。

當然,還有許多其他評估分類器性能的方法,具體取決於數據集的性質。

例如,如果數據集中(二進制)標簽的比例沒有偏差(即大約為50-50),則可以使用更簡單的ROC分析和交叉驗證:

收集每個折疊的預測並構建ROC曲線(如前所述),收集所有TPR-FPR點(即采用所有TPR-FPR元組的並集),然后繪制可能平滑的組合點集。 可選地,使用簡單線性插值和用於數值積分的復合梯形方法計算AUC-ROC。

這是使用交叉驗證繪制sklearn分類器的Precision Recall曲線的最佳方法。 最好的部分是,它繪制了所有類的PR曲線,因此您也可以獲得多條整齊的曲線

from scikitplot.classifiers import plot_precision_recall_curve
import matplotlib.pyplot as plt

clf = LogisticRegression()
plot_precision_recall_curve(clf, X, y)
plt.show()

該功能自動負責交叉驗證給定數據集,連接所有折疊預測,並計算每個類別的PR曲線+平均PR曲線。 它是一個單行功能,可以為您完成所有這些功能。

精確回憶曲線

免責聲明:請注意,這使用我構建的scikit-plot庫。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM