繁体   English   中英

通过样本获取特征重要性 - Python Scikit Learn

[英]Getting feature importance by sample - Python Scikit Learn

我有一个使用sklearn.ensemble.RandomForestClassifier的拟合模型( clf )。 我已经知道我可以使用clf.feature_importances_获得特征重要性。 无论如何,我想知道,如果可能的话,如何通过每个样本获取特征重要性。

例子:

from sklearn.ensemble import RandomForestClassifier

X = {"f1":[0,1,1,0,1], "f2":[1,1,1,0,1]}
y = [0,1,0,1,0]

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

y_pred = clf.predict(X)

然后,我如何得到这样的东西:

y_pred f1_importance f2_importance
   1         0.57          0.43          
   1         0.26          0.74
   1         0.31          0.69
   0         0.62          0.38
   1         0.16          0.84

* y_pred值不是真实的。 我实际上是在 Python 3.8 中将pandas用于实际项目。

您可以使用treeinterpreter得到您的个人预测的功能重要性RandomForestClassifier

您可以在Github上找到treeinterpreter并通过以下方式安装它

pip install treeinterpreter

我使用了您的参考代码,但不得不对其进行调整,因为您不能使用字典作为输入来适应您的RandomForestClassifier

from sklearn.ensemble import RandomForestClassifier
from treeinterpreter import treeinterpreter as ti
import numpy as np
import pandas as pd

X = np.array([[0,1],[1,1],[1,1],[0,0],[1,1]])
y = [0,1,0,1,0]

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

y_pred = clf.predict(X)
y_pred_probas = clf.predict_proba(X)

然后我使用带有分类器和数据的 treeinterpreter 来计算偏差、贡献以及预测值:

prediction, bias, contributions = ti.predict(clf, X)

df = pd.DataFrame(data=np.matrix([y_pred, prediction.transpose()[0], prediction.transpose()[1], np.sum(contributions, axis=1).transpose()[0], bias.transpose()[0], np.sum(contributions, axis=1).transpose()[1], bias.transpose()[1]]).transpose(), columns=["Prediction", "Prediction value 0", "Prediction value 1", "f1_contribution", "f1_bias", "f2_contribution","f2_bias"])

df

输出

在此处输入图片说明

您可以查看作者的这篇博文,以更好地了解它的工作原理。

在表中,0 和 1 的预测值是指两个类的概率,您也可以使用RandomForestClassifier的现有predict_proba()方法RandomForestClassifier

您可以验证,偏差和贡献加起来像这样的预测值/概率:

bias + np.sum(contributions, axis=1)

输出

array([[0.744 , 0.256 ],
       [0.6565, 0.3435],
       [0.6565, 0.3435],
       [0.214 , 0.786 ],
       [0.6565, 0.3435]])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM