
Traversal of sklearn decision tree

How do I do the breadth first search traversal of the sklearn decision tree?

In my code I have used the tree_ attribute of the fitted classifier (clf.tree_) and its arrays, such as tree_.feature and tree_.threshold, to understand the structure of the tree. But iterating over these arrays visits the nodes in depth-first (DFS) order. If I want a breadth-first (BFS) traversal instead, how should I do it?

Suppose this is my classifier:

from sklearn.tree import DecisionTreeClassifier

clf1 = DecisionTreeClassifier(max_depth=2)
clf1 = clf1.fit(x_train, y_train)

and the decision tree it produces is:

[Image: the fitted decision tree]

Then I traversed the tree using the following function:

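import numpy as np

def encoding(clf, features):
    l1 = list()  # feature name per node (None for leaf nodes)
    l2 = list()  # threshold per split node, class index per leaf node

    # Node IDs 0 .. node_count - 1, in the order the arrays store them
    for i in range(len(clf.tree_.feature)):
        if clf.tree_.feature[i] >= 0:  # split node
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:  # leaf node (feature is negative)
            l1.append(None)
            print(np.max(clf.tree_.value))  # debug output
            # argmax gives the class index, i.e. a position in clf.classes_
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1, l2]

    return np.array(l, dtype=object)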

and the output produced is

array([['address', 'age', None, None, 'age', None, None],
       [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)

The first row holds the feature used at each node (None for leaf nodes), and the second row holds the threshold for split nodes and the class (the argmax of the node's value) for leaf nodes. But this is a depth-first (DFS) traversal of the tree. I want a breadth-first (BFS) traversal. What should I do?
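To see why the arrays come out in depth-first order: scikit-learn stores the fitted tree as parallel arrays indexed by node ID, and when the tree is grown depth-first (the default when `max_leaf_nodes` is not set) node IDs are handed out in pre-order. The sketch below uses hand-written `children_left`/`children_right` lists that are assumed to describe the 7-node tree from the output above (node 0 splits on `address`, nodes 1 and 4 split on `age`, the rest are leaves, and -1 marks "no child" as in `clf.tree_.children_left`). A recursive pre-order walk then visits the nodes in exactly their index order.

```python
# Hypothetical stand-ins for clf.tree_.children_left / children_right,
# assumed to describe the 7-node tree from the question (-1 = no child).
children_left = [1, 2, -1, -1, 5, -1, -1]
children_right = [4, 3, -1, -1, 6, -1, -1]


def preorder(node=0):
    """Depth-first, pre-order walk: node, then left subtree, then right subtree."""
    order = [node]
    if children_left[node] >= 0:
        order += preorder(children_left[node])
    if children_right[node] >= 0:
        order += preorder(children_right[node])
    return order


dfs_order = preorder()
print(dfs_order)  # [0, 1, 2, 3, 4, 5, 6], identical to plain index order
```

So looping over `range(len(clf.tree_.feature))` is already a pre-order DFS; a BFS needs an explicit queue.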

As I am new to Stack Overflow, kindly suggest how to improve the question description and what other information, if any, I should add to explain my problem further.

X_train (sample): [image]

y_train (sample): [image]

This should do it:

from collections import deque

tree = clf.tree_  # clf is the fitted DecisionTreeClassifier (clf1 in the question)

queue = deque()
queue.append(0)  # enqueue the tree root (node 0)

while queue:
    current_node = queue.popleft()  # take the oldest node first (FIFO)

    # do whatever you want with current node
    # ...

    left_child = tree.children_left[current_node]
    if left_child >= 0:  # -1 means there is no left child
        queue.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:  # -1 means there is no right child
        queue.append(right_child)

This uses a deque as a first-in, first-out queue of the nodes still to be processed. Since we remove elements from the left and add them on the right, the nodes are visited level by level, which is a breadth-first traversal.
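Running the same loop on hand-written arrays makes the resulting order visible. Here `children_left` and `children_right` are hypothetical stand-ins, assumed to match `clf.tree_.children_left`/`children_right` for the 7-node tree in the question; with a real classifier you would read them from `clf.tree_`.

```python
from collections import deque

# Hypothetical stand-ins for clf.tree_.children_left / children_right,
# assumed to describe the 7-node tree from the question (-1 = no child).
children_left = [1, 2, -1, -1, 5, -1, -1]
children_right = [4, 3, -1, -1, 6, -1, -1]

queue = deque([0])  # start at the root, node 0
bfs_order = []
while queue:
    node = queue.popleft()  # FIFO: oldest discovered node first
    bfs_order.append(node)
    for child in (children_left[node], children_right[node]):
        if child >= 0:  # -1 means "no child" (leaf)
            queue.append(child)

print(bfs_order)  # [0, 1, 4, 2, 3, 5, 6]
```

Replacing `popleft()` with `pop()` would turn the deque into a LIFO stack and give a depth-first order instead (one that visits the right subtree first, given the order the children are appended).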


For actual use, I suggest you turn this into a generator:

from collections import deque

def breadth_first_traversal(tree):
    """Yield the node IDs of a fitted sklearn tree_ in breadth-first order."""
    queue = deque()
    queue.append(0)  # start from the root

    while queue:
        current_node = queue.popleft()

        yield current_node

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            queue.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            queue.append(right_child)
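If you also need to know which level each node is on (for example, to group the output level by level), a small variant of the generator can carry the depth along in the queue. The name `breadth_first_with_depth` is hypothetical; it relies only on the `children_left`/`children_right` attributes that `clf.tree_` exposes, so here it is exercised on a `SimpleNamespace` stand-in assumed to match the 7-node tree from the question.

```python
from collections import deque
from types import SimpleNamespace


def breadth_first_with_depth(tree):
    """Yield (node_id, depth) pairs in breadth-first order, root at depth 0."""
    queue = deque([(0, 0)])
    while queue:
        node, depth = queue.popleft()
        yield node, depth
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:  # -1 marks a missing child
                queue.append((child, depth + 1))


# Hypothetical stand-in for clf.tree_ (only the two child arrays are used).
fake_tree = SimpleNamespace(
    children_left=[1, 2, -1, -1, 5, -1, -1],
    children_right=[4, 3, -1, -1, 6, -1, -1],
)

levels = {}
for node, depth in breadth_first_with_depth(fake_tree):
    levels.setdefault(depth, []).append(node)

print(levels)  # {0: [0], 1: [1, 4], 2: [2, 3, 5, 6]}
```

Storing `(node, depth)` tuples in the queue keeps the traversal a single pass; no separate level counter is needed.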

Then, you only need minimal changes to your original function:

import numpy as np

def encoding(clf, features):
    l1 = list()
    l2 = list()

    # Visit node IDs in breadth-first order instead of index (depth-first) order
    for i in breadth_first_traversal(clf.tree_):
        if clf.tree_.feature[i] >= 0:  # split node
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:  # leaf node
            l1.append(None)
            print(np.max(clf.tree_.value))  # debug output
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1, l2]

    return np.array(l, dtype=object)
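Because the original depth-first `encoding` output is already indexed by node ID, another option is to compute the breadth-first order once and reorder the columns with NumPy fancy indexing instead of rebuilding the lists. The sketch below reuses the array printed in the question; the `bfs_order` list is an assumption, namely what `list(breadth_first_traversal(clf.tree_))` should return for that tree.

```python
import numpy as np

# Depth-first encoding as printed in the question (column i = node i).
encoded = np.array([['address', 'age', None, None, 'age', None, None],
                    [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)

# Assumed breadth-first node order for this tree, i.e. what
# list(breadth_first_traversal(clf.tree_)) would return for it.
bfs_order = [0, 1, 4, 2, 3, 5, 6]

# Reorder the columns; both rows (features and thresholds/classes) move together
bfs_encoded = encoded[:, bfs_order]
print(bfs_encoded)
```

The first row becomes `['address', 'age', 'age', None, None, None, None]` and the second `[0.5, 17.5, 15.5, 2, 1, 1, 1]`, so both split nodes at depth 1 now come before any leaf.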
