![](/img/trans.png)
[英]How can I use k-fold cross-validation in scikit-learn to get precision-recall per fold?
[英]Plotting Precision-Recall curve when using cross-validation in scikit-learn
我正在使用交叉驗證來評估帶有scikit-learn
的分類器的性能,我想繪制Precision-Recall曲線。 我在scikit-learn網站上找到了一個繪制PR曲線的示例 ,但它沒有使用交叉驗證進行評估。
在使用交叉驗證時,如何繪制scikit中的Precision-Recall曲線?
我做了以下但我不確定這是否是正確的方法(psudo代碼):
for each k-fold:
precision, recall, _ = precision_recall_curve(y_test, probs)
mean_precision += precision
mean_recall += recall
mean_precision /= num_folds
mean_recall /= num_folds
plt.plot(recall, precision)
你怎么看?
編輯:
它不起作用,因為每次折疊后precision
和recall
陣列的大小不同。
任何人?
而不是在每次折疊后記錄精確度和召回值,而是在每次折疊后將預測存儲在測試樣本上。 接下來, 收集所有測試(即袋外)預測並計算精度和召回率。
## let test_samples[k] = test samples for the kth fold (list of list)
## let train_samples[k] = test samples for the kth fold (list of list)
for k in range(0, k):
model = train(parameters, train_samples[k])
predictions_fold[k] = predict(model, test_samples[k])
# collect predictions
predictions_combined = [p for preds in predictions_fold for p in preds]
## let predictions = rearranged predictions s.t. they are in the original order
## use predictions and labels to compute lists of TP, FP, FN
## use TP, FP, FN to compute precisions and recalls for one run of k-fold cross-validation
在單次,完整的k-fold交叉驗證運行中,預測器對每個樣本進行一次且僅一次預測。 給定n個樣本,您應該有n個測試預測。
(注意:這些預測與訓練預測不同,因為預測器會對每個樣本進行預測,而不會事先看到它。)
除非您使用留一法交叉驗證 ,否則k折交叉驗證通常需要對數據進行隨機分區。 理想情況下,您將進行重復 (和分層 )k折交叉驗證。 然而,組合來自不同輪次的精確回憶曲線並不是直截了當的,因為與ROC不同,您不能在精確回憶點之間使用簡單的線性插值(參見Davis和Goadrich 2006 )。
我親自使用Davis-Goadrich方法計算AUC-PR用於PR空間中的插值(隨后進行數值積分),並使用來自重復分層10倍交叉驗證的AUC-PR估計來比較分類器。
對於一個不錯的情節,我展示了一個交叉驗證輪次的代表性PR曲線。
當然,還有許多其他評估分類器性能的方法,具體取決於數據集的性質。
例如,如果數據集中(二進制)標簽的比例沒有偏差(即大約為50-50),則可以使用更簡單的ROC分析和交叉驗證:
收集每個折疊的預測並構建ROC曲線(如前所述),收集所有TPR-FPR點(即采用所有TPR-FPR元組的並集),然后繪制可能平滑的組合點集。 可選地,使用簡單線性插值和用於數值積分的復合梯形方法計算AUC-ROC。
這是使用交叉驗證繪制sklearn分類器的Precision Recall曲線的最佳方法。 最好的部分是,它繪制了所有類的PR曲線,因此您也可以獲得多條整齊的曲線
from scikitplot.classifiers import plot_precision_recall_curve
import matplotlib.pyplot as plt
clf = LogisticRegression()
plot_precision_recall_curve(clf, X, y)
plt.show()
該功能自動負責交叉驗證給定數據集,連接所有折疊預測,並計算每個類別的PR曲線+平均PR曲線。 它是一個單行功能,可以為您完成所有這些功能。
免責聲明:請注意,這使用我構建的scikit-plot庫。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.