簡體   English   中英

如何為多類分類提取隨機森林樹規則?

[英]How to extract random forest tree rules for a Multiclass Classification?

嗨,我想在多類別分類的情況下從一棵樹中提取規則

from sklearn.tree import _tree 
from sklearn.tree import DecisionTreeClassifier

#creat a gaussian classifier
clf=RandomForestClassifier(n_estimators=100)

#train the model using the training sets y_pred=clf.predict(X_test)

clf.fit(X_train,y_train)

#extract one tree from the forest
model = clf.estimators_[0]


def find_rules(tree,features): 
    dt=tree.tree_
    def visitor(node,depth):
        indent= ' ' * depth
        if dt.feature[node] != _tree.TREE_UNDEFINED:
            print('{} if <{}> <= {}:'.format(indent,features[node],round(dt.threshold[node],100)))
            visitor(dt.children_left[node],depth+1)
            print('{}else:'.format(indent))
            visitor(dt.children_right[node],depth+1)
        else:
            print('{} return {}'.format(indent,dt.value[node]))
    visitor(0,1)


find_rules(model, iris.feature_names)


在此處輸入圖片說明

請檢查以下代碼。 它似乎有效。 只有一個小變化

def find_rules(tree,features): 
    dt=tree.tree_
    def visitor(node,depth):
        indent= ' ' * depth
        if dt.feature[node] != _tree.TREE_UNDEFINED:
            print('{} if <{}> <= {}:'.format(indent,features[dt.feature[node]],round(dt.threshold[node],100)))
            # in the previous line i added a backward-mapping
            # for the feature id
            visitor(dt.children_left[node],depth+1)
            print('{} else:'.format(indent))
            visitor(dt.children_right[node],depth+1)
        else:
            print('{} return {}'.format(indent,dt.value[node]))
    visitor(0,1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM