簡體   English   中英

如何在scikit-learn中獲取與predict_proba一起使用的cross_val_predict中的類標簽

[英]How to get classes labels from cross_val_predict used with predict_proba in scikit-learn

我需要使用3倍交叉驗證來訓練隨機森林分類器。 對於每個樣本,我需要在它恰好位於測試集中時檢索預測概率。

我正在使用scikit-learn版本0.18.dev0。

此新版本添加了使用方法cross_val_predict()和附加參數method來定義估算器需要哪種預測的功能。

在我的例子中,我想使用predict_proba()方法,該方法在多類場景中返回每個類的概率。

但是,當我運行該方法時,我得到預測概率矩陣,其中每行代表一個樣本,每列代表特定類的預測概率。

問題是該方法沒有指出哪個類對應於每列。

我需要的值與屬性classes_中返回的相同(在我的情況下使用RandomForestClassifier )定義為:

classes_:shape of array = [n_classes]或此類數組的列表類標簽(單輸出問題)或類標簽數組列表(多輸出問題)。

這是predict_proba()所需要的,因為在其文檔中寫道:

類的順序對應於屬性classes_中的順序。

最小的例子如下:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

clf = RandomForestClassifier()

X = np.random.randn(10, 10)
y = y = np.array([1] * 4 + [0] * 3 + [2] * 3)

# how to get classes from here?
proba = cross_val_predict(estimator=clf, X=X, y=y, method="predict_proba")

# using the classifier without cross-validation
# it is possible to get the classes in this way:
clf.fit(X, y)
proba = clf.predict_proba(X)
classes = clf.classes_

是的,它們將按順序排列; 這是因為DecisionTreeClassifier (這是默認base_estimatorRandomForestClassifier使用np.unique構建classes_屬性返回輸入數組的排序后的唯一值。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM