How do I do a breadth-first traversal of an sklearn decision tree?
In my code I have used the tree_ attribute of the fitted sklearn classifier, with arrays such as tree_.feature and tree_.threshold, to understand the structure of the tree. But iterating over these arrays visits the nodes in DFS order. If I want BFS order, how should I do it?
Suppose this is my classifier:
clf1 = DecisionTreeClassifier(max_depth=2)
clf1 = clf1.fit(x_train, y_train)
and this is the decision tree it produced: [figure: the fitted decision tree, not reproduced here]
Then I traversed the tree using the following function:
def encoding(clf, features):
    l1 = list()
    l2 = list()
    for i in range(len(clf.tree_.feature)):
        if clf.tree_.feature[i] >= 0:
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))
    l = [l1, l2]
    return np.array(l)
and the output produced is
array([['address', 'age', None, None, 'age', None, None],
       [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)
Here the first array holds the feature of each node, or None if the node is a leaf. The second array holds the threshold for a feature node, or the predicted class for a leaf node. But this is the DFS order of the tree. I want to traverse it in BFS order. What should I do?
As I am new to Stack Overflow, kindly suggest how to improve the question description and what other information I should add, if any, to explain my problem further.
This should do it:
from collections import deque

tree = clf.tree_
stack = deque()
stack.append(0)  # start with the tree root (node 0)
while stack:
    current_node = stack.popleft()
    # do whatever you want with current node
    # ...
    left_child = tree.children_left[current_node]
    if left_child >= 0:
        stack.append(left_child)
    right_child = tree.children_right[current_node]
    if right_child >= 0:
        stack.append(right_child)
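This uses a deque to hold the nodes to process next. Although the variable is called stack, it is used as a FIFO queue: we remove elements from the left and append them on the right, so every node of one level is visited before any node of the next level. That is a breadth-first traversal.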
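To illustrate why removing from the left matters, here is a minimal sketch that runs the same loop on a small hand-written tree, once with popleft() and once with pop(). The child arrays and the helper name visit_order are assumptions made up for this illustration; they are not taken from a fitted model, only shaped like sklearn's children_left and children_right arrays, where -1 means "no child".

```python
from collections import deque

# Hand-written child arrays (an assumption, not from a fitted model):
# node 0 is the root, nodes 1 and 4 are its children, -1 means "no child".
CHILDREN_LEFT = [1, 2, -1, -1, 5, -1, -1]
CHILDREN_RIGHT = [4, 3, -1, -1, 6, -1, -1]


def visit_order(children_left, children_right, breadth_first=True):
    """Return node indices in the order the loop visits them."""
    pending = deque([0])
    order = []
    while pending:
        # popleft() takes the oldest entry (FIFO, breadth-first);
        # pop() takes the newest entry (LIFO, depth-first).
        node = pending.popleft() if breadth_first else pending.pop()
        order.append(node)
        for child in (children_left[node], children_right[node]):
            if child >= 0:
                pending.append(child)
    return order


bfs_order = visit_order(CHILDREN_LEFT, CHILDREN_RIGHT, breadth_first=True)
dfs_order = visit_order(CHILDREN_LEFT, CHILDREN_RIGHT, breadth_first=False)
print(bfs_order)  # [0, 1, 4, 2, 3, 5, 6]
print(dfs_order)  # [0, 4, 6, 5, 1, 3, 2]
```

With pop() the right subtree comes first, because the right child is appended last and is therefore the newest entry. The only difference between the two traversals is which end of the deque the next node is taken from.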
For actual use, I suggest you turn this into a generator:
from collections import deque

def breadth_first_traversal(tree):
    stack = deque()
    stack.append(0)
    while stack:
        current_node = stack.popleft()
        yield current_node
        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append(left_child)
        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append(right_child)
Then, you only need minimal changes to your original function:
import numpy as np

def encoding(clf, features):
    l1 = list()
    l2 = list()
    for i in breadth_first_traversal(clf.tree_):
        if clf.tree_.feature[i] >= 0:
            # internal node: record the split feature and its threshold
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            # leaf node: no split feature, record the majority class instead
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))
    l = [l1, l2]
    return np.array(l)
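As a quick check that does not need a trained model, the following sketch reorders the output shown in the question into breadth-first order. The fake_tree object and its child arrays are assumptions: they are reconstructed by hand from the question's output, assuming the node indices follow the depth-first storage order described there. The generator is repeated so the block runs on its own.

```python
from collections import deque
from types import SimpleNamespace


def breadth_first_traversal(tree):
    """Yield node indices of `tree` level by level, left to right."""
    queue = deque([0])
    while queue:
        current_node = queue.popleft()
        yield current_node
        for child in (tree.children_left[current_node],
                      tree.children_right[current_node]):
            if child >= 0:
                queue.append(child)


# Stand-in for clf.tree_ (assumption: reconstructed by hand, not a real model).
fake_tree = SimpleNamespace(
    children_left=[1, 2, -1, -1, 5, -1, -1],
    children_right=[4, 3, -1, -1, 6, -1, -1],
)

# The two rows printed in the question, in storage (depth-first) order.
stored_features = ['address', 'age', None, None, 'age', None, None]
stored_values = [0.5, 17.5, 2, 1, 15.5, 1, 1]

order = list(breadth_first_traversal(fake_tree))
bfs_features = [stored_features[i] for i in order]
bfs_values = [stored_values[i] for i in order]

print(order)         # [0, 1, 4, 2, 3, 5, 6]
print(bfs_features)  # ['address', 'age', 'age', None, None, None, None]
print(bfs_values)    # [0.5, 17.5, 15.5, 2, 1, 1, 1]
```

Under these assumptions, the root comes first, then both 'age' nodes of the second level, then the four leaves, which is the level-by-level order the question asks for.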