簡體   English   中英

通過樣本獲取特征重要性 - Python Scikit Learn

[英]Getting feature importance by sample - Python Scikit Learn

我有一個使用sklearn.ensemble.RandomForestClassifier的擬合模型( clf )。 我已經知道我可以使用clf.feature_importances_獲得特征重要性。 無論如何,我想知道,如果可能的話,如何通過每個樣本獲取特征重要性。

例子:

from sklearn.ensemble import RandomForestClassifier

X = {"f1":[0,1,1,0,1], "f2":[1,1,1,0,1]}
y = [0,1,0,1,0]

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

y_pred = clf.predict(X)

然后,我如何得到這樣的東西:

y_pred f1_importance f2_importance
   1         0.57          0.43          
   1         0.26          0.74
   1         0.31          0.69
   0         0.62          0.38
   1         0.16          0.84

* y_pred值不是真實的。 我實際上是在 Python 3.8 中將pandas用於實際項目。

您可以使用treeinterpreter得到您的個人預測的功能重要性RandomForestClassifier

您可以在Github上找到treeinterpreter並通過以下方式安裝它

pip install treeinterpreter

我使用了您的參考代碼,但不得不對其進行調整,因為您不能使用字典作為輸入來適應您的RandomForestClassifier

from sklearn.ensemble import RandomForestClassifier
from treeinterpreter import treeinterpreter as ti
import numpy as np
import pandas as pd

X = np.array([[0,1],[1,1],[1,1],[0,0],[1,1]])
y = [0,1,0,1,0]

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

y_pred = clf.predict(X)
y_pred_probas = clf.predict_proba(X)

然后我使用帶有分類器和數據的 treeinterpreter 來計算偏差、貢獻以及預測值:

prediction, bias, contributions = ti.predict(clf, X)

df = pd.DataFrame(data=np.matrix([y_pred, prediction.transpose()[0], prediction.transpose()[1], np.sum(contributions, axis=1).transpose()[0], bias.transpose()[0], np.sum(contributions, axis=1).transpose()[1], bias.transpose()[1]]).transpose(), columns=["Prediction", "Prediction value 0", "Prediction value 1", "f1_contribution", "f1_bias", "f2_contribution","f2_bias"])

df

輸出

在此處輸入圖片說明

您可以查看作者的這篇博文,以更好地了解它的工作原理。

在表中,0 和 1 的預測值是指兩個類的概率,您也可以使用RandomForestClassifier的現有predict_proba()方法RandomForestClassifier

您可以驗證,偏差和貢獻加起來像這樣的預測值/概率:

bias + np.sum(contributions, axis=1)

輸出

array([[0.744 , 0.256 ],
       [0.6565, 0.3435],
       [0.6565, 0.3435],
       [0.214 , 0.786 ],
       [0.6565, 0.3435]])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM