How do I get all Gini indices in my decision tree?
I have made a decision tree using the scikit-learn package, namely sklearn.tree.DecisionTreeClassifier().fit(x,y).
How do I get the Gini indices for all possible nodes at each step?
graphviz only gives me the Gini index of the node with the lowest Gini index, i.e. the node used for the split.
For example, the image below (from graphviz) tells me the Gini score of the right child of the Pclass_lowVMid split, which is 0.408, but not the Gini index of Pclass_lower or Sex_male at that step. I just know that the Gini index of Pclass_lower and of Sex_male must be at least (0.408*0.7 + 0), but that's it.
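The Gini index of the pclass node is the sample-weighted average of its children's Gini indices: Gini(left) * n_left / (n_left + n_right) + Gini(right) * n_right / (n_left + n_right), where n_left and n_right are the numbers of samples in the left and right nodes. So here it will be: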
Gini index of pclass = 0 + .408 *(7/10) = 0.2856
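The same weighted formula can be applied to every candidate split at a node, which is what the question asks for. As far as I know, a fitted DecisionTreeClassifier only stores the split it actually chose at each node, so the Gini index of the other candidates has to be recomputed. The sketch below does that by brute force with NumPy, assuming numeric features, no sample weights, and midpoint thresholds between consecutive unique values (the way scikit-learn reports thresholds such as 2.45). weighted_gini, gini and candidate_split_ginis are hypothetical helper names, and constraints such as min_samples_leaf are ignored. For a deeper node, pass only the rows that reach that node.

```python
import numpy as np
from sklearn.datasets import load_iris


def weighted_gini(gini_left, n_left, gini_right, n_right):
    """Sample-weighted Gini of a split: children's Gini weighted by sample share."""
    n = n_left + n_right
    return gini_left * n_left / n + gini_right * n_right / n


def gini(labels):
    """Gini impurity 1 - sum_k p_k**2 of an array of class labels."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - float(np.sum(p ** 2))


def candidate_split_ginis(X, y):
    """For each feature column j, return (best threshold, weighted Gini)
    over all midpoint thresholds, i.e. an exhaustive search at one node.
    Hypothetical helper, not part of scikit-learn."""
    results = {}
    for j in range(X.shape[1]):
        values = np.unique(X[:, j])  # sorted unique values of feature j
        best = (None, np.inf)
        for t in (values[:-1] + values[1:]) / 2.0:  # midpoints between values
            left = X[:, j] <= t
            w = weighted_gini(gini(y[left]), int(left.sum()),
                              gini(y[~left]), int((~left).sum()))
            if w < best[1]:
                best = (float(t), w)
        results[j] = best
    return results


# The answer's arithmetic: only the 7/10 ratio is known, so 3 and 7 are assumed counts
print(weighted_gini(0.0, 3, 0.408, 7))  # about 0.2856

# Best split per feature at the root of an iris tree
data = load_iris()
scores = candidate_split_ginis(data.data, data.target)
for j, (t, w) in scores.items():
    print(f"{data.feature_names[j]:<18} threshold={t} weighted gini={w:.4f}")
```

For one-hot features such as Pclass_lower or Sex_male there is only one threshold (0.5), so each feature gets exactly one score, and the feature with the smallest score is the one the tree would pick.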
export_graphviz shows the impurity of all nodes, at least in version 0.20.1.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from graphviz import Source

# Fit a shallow tree on the iris data
data = load_iris()
X, y = data.data, data.target
clf = DecisionTreeClassifier(max_depth=2, random_state=42)
clf.fit(X, y)

# Every node in the exported graph includes a "gini = ..." line (impurity=True is the default)
graph = Source(export_graphviz(clf, out_file=None, feature_names=data.feature_names))
graph.format = 'png'
graph.render('dt', view=True)  # renders dt.png and opens it
The impurity values of all nodes are also accessible through the impurity attribute of clf.tree_:
clf.tree_.impurity
array([0.66666667, 0. , 0.5 , 0.16803841, 0.04253308])
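The other arrays on clf.tree_ also let you recover the weighted formula from the first answer for every split the tree actually made. A minimal sketch, assuming the same iris tree as above (refitted here so the block runs on its own); children_left, children_right, n_node_samples, impurity, feature and threshold are attributes of scikit-learn's Tree object, and leaves are marked with -1 in children_left. It only covers the chosen splits; the rejected candidates are not kept in the fitted tree. split_gini is a hypothetical name.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(data.data, data.target)
t = clf.tree_

split_gini = {}  # node id -> sample-weighted Gini of its two children
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: no split was made here
        continue
    n_l, n_r = t.n_node_samples[left], t.n_node_samples[right]
    split_gini[node] = (t.impurity[left] * n_l + t.impurity[right] * n_r) / (n_l + n_r)
    name = data.feature_names[t.feature[node]]
    print(f"node {node}: {name} <= {t.threshold[node]:.2f}, "
          f"node gini={t.impurity[node]:.4f}, weighted child gini={split_gini[node]:.4f}")
```

Without sample weights n_node_samples is the right count to use; with sample_weight, weighted_n_node_samples would be the matching choice.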