I'm new to cuML and I have a decision tree classifier built with scikit-learn. I would like to run a hyperparameter search on the GPU, so I started looking at cuML. There is no DecisionTreeClassifier in cuML, but as far as I have read in other SO posts, it can be reproduced by using RandomForestClassifier with 1 tree and no bootstrap.
My problem is: how can I extract the tree and all of its rules (the leaves and nodes) from a cuML RandomForestClassifier? Or should I be looking at other algorithms like XGBoost?
You don't need access to the underlying decision trees or their internals to do hyperparameter optimization.
With that said, you can access summary information about the underlying trees and leaf predictions like this:
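from cuml.ensemble import RandomForestClassifier
from cuml.datasets import make_classification

N = 100  # number of samples
K = 10   # number of features

# Small synthetic dataset in which every feature is informative
X, y = make_classification(
    n_samples=N,
    n_features=K,
    n_informative=K,
    n_redundant=0
)

clf = RandomForestClassifier(n_estimators=2)
clf.fit(X, y)

print(clf.get_summary_text())   # per-tree depth, leaf count, fit time
print(clf.get_detailed_text())  # summary plus every split and leaf
print(clf.get_json())           # the same trees as nested JSON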
This prints the summary text, then the detailed text, then the JSON:
Forest has 2 trees, max_depth 16, and max_leaves -1
Tree #0
Decision Tree depth --> 9 and n_leaves --> 18
Tree Fitting - Overall time --> 1.216 milliseconds
Tree #1
Decision Tree depth --> 7 and n_leaves --> 16
Tree Fitting - Overall time --> 1.919 milliseconds
Forest has 2 trees, max_depth 16, and max_leaves -1
Tree #0
Decision Tree depth --> 9 and n_leaves --> 18
Tree Fitting - Overall time --> 1.216 milliseconds
└(colid: 7, quesval: 2.73323, best_metric_val: 0.0407427)
├(colid: 9, quesval: -0.233239, best_metric_val: 0.116631)
│ ├(colid: 2, quesval: -1.48028, best_metric_val: 0.045858)
│ │ ├(colid: 8, quesval: -1.14041, best_metric_val: 0.28125)
│ │ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ │ └(colid: 1, quesval: 0.720062, best_metric_val: 0.375)
│ │ │ ├(leaf, prediction: [1, 0], best_metric_val: 0)
│ │ │ └(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(leaf, prediction: [0, 1], best_metric_val: 0)
│ └(colid: 3, quesval: -1.01601, best_metric_val: 0.313368)
│ ├(colid: 8, quesval: 1.68195, best_metric_val: 0.0131944)
│ │ ├(leaf, prediction: [1, 0], best_metric_val: 0)
│ │ └(colid: 6, quesval: -0.458985, best_metric_val: 0.32)
│ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(colid: 7, quesval: -2.86422, best_metric_val: 0.126263)
│ ├(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(colid: 8, quesval: 1.3618, best_metric_val: 0.0198347)
│ ├(colid: 9, quesval: 1.96266, best_metric_val: 0.142222)
│ │ ├(colid: 5, quesval: -0.427346, best_metric_val: 0.0308642)
│ │ │ ├(colid: 8, quesval: -0.295362, best_metric_val: 0.125)
│ │ │ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ │ │ └(colid: 6, quesval: 1.99819, best_metric_val: 0.5)
│ │ │ │ ├(leaf, prediction: [1, 0], best_metric_val: 0)
│ │ │ │ └(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ │ └(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(leaf, prediction: [0, 1], best_metric_val: 0)
└(colid: 3, quesval: 1.4614, best_metric_val: 0.239645)
├(leaf, prediction: [1, 0], best_metric_val: 0)
└(colid: 7, quesval: 3.80204, best_metric_val: 0.125)
├(leaf, prediction: [0, 1], best_metric_val: 0)
└(colid: 8, quesval: 0.637938, best_metric_val: 0.5)
├(leaf, prediction: [0, 1], best_metric_val: 0)
└(leaf, prediction: [1, 0], best_metric_val: 0)
Tree #1
Decision Tree depth --> 7 and n_leaves --> 16
Tree Fitting - Overall time --> 1.919 milliseconds
└(colid: 8, quesval: -1.19294, best_metric_val: 0.111478)
├(colid: 7, quesval: -2.32102, best_metric_val: 0.0867768)
│ ├(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(leaf, prediction: [0, 1], best_metric_val: 0)
└(colid: 3, quesval: 0.590359, best_metric_val: 0.180291)
├(colid: 6, quesval: -2.11692, best_metric_val: 0.126613)
│ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ └(colid: 5, quesval: -1.94796, best_metric_val: 0.0655193)
│ ├(colid: 6, quesval: 1.18255, best_metric_val: 0.489796)
│ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(colid: 8, quesval: 3.48108, best_metric_val: 0.0196773)
│ ├(colid: 5, quesval: 0.71779, best_metric_val: 0.00283446)
│ │ ├(colid: 4, quesval: 1.85633, best_metric_val: 1.19209e-07)
│ │ │ ├(leaf, prediction: [1, 0], best_metric_val: 0)
│ │ │ └(leaf, prediction: [1, 0], best_metric_val: 0)
│ │ └(colid: 5, quesval: 0.815552, best_metric_val: 0.152778)
│ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(colid: 9, quesval: 0.690919, best_metric_val: 0.5)
│ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ └(leaf, prediction: [1, 0], best_metric_val: 0)
└(colid: 6, quesval: 2.16413, best_metric_val: 0.071035)
├(colid: 7, quesval: 3.80204, best_metric_val: 0.0818594)
│ ├(colid: 9, quesval: 1.33454, best_metric_val: 0.02)
│ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(colid: 5, quesval: 0.0840077, best_metric_val: 0.375)
│ │ ├(leaf, prediction: [0, 1], best_metric_val: 0)
│ │ └(leaf, prediction: [1, 0], best_metric_val: 0)
│ └(leaf, prediction: [1, 0], best_metric_val: 0)
└(leaf, prediction: [1, 0], best_metric_val: 0)
[
{"nodeid": 0, "split_feature": 7, "split_threshold": 2.7332263, "gain": 0.0407427214, "instance_count": 100, "yes": 1, "no": 2, "children": [
{"nodeid": 1, "split_feature": 9, "split_threshold": -0.233238578, "gain": 0.116630867, "instance_count": 87, "yes": 3, "no": 4, "children": [
{"nodeid": 3, "split_feature": 2, "split_threshold": -1.48028064, "gain": 0.0458579995, "instance_count": 39, "yes": 7, "no": 8, "children": [
{"nodeid": 7, "split_feature": 8, "split_threshold": -1.1404053, "gain": 0.28125, "instance_count": 8, "yes": 13, "no": 14, "children": [
{"nodeid": 13, "leaf_value": [0, 1], "instance_count": 4},
{"nodeid": 14, "split_feature": 1, "split_threshold": 0.720061541, "gain": 0.375, "instance_count": 4, "yes": 21, "no": 22, "children": [
{"nodeid": 21, "leaf_value": [1, 0], "instance_count": 3},
{"nodeid": 22, "leaf_value": [0, 1], "instance_count": 1}
]}
]},
{"nodeid": 8, "leaf_value": [0, 1], "instance_count": 31}
]},
{"nodeid": 4, "split_feature": 3, "split_threshold": -1.01600909, "gain": 0.313368142, "instance_count": 48, "yes": 9, "no": 10, "children": [
{"nodeid": 9, "split_feature": 8, "split_threshold": 1.68195295, "gain": 0.0131943803, "instance_count": 24, "yes": 15, "no": 16, "children": [
{"nodeid": 15, "leaf_value": [1, 0], "instance_count": 19},
{"nodeid": 16, "split_feature": 6, "split_threshold": -0.458984971, "gain": 0.320000023, "instance_count": 5, "yes": 23, "no": 24, "children": [
{"nodeid": 23, "leaf_value": [0, 1], "instance_count": 1},
{"nodeid": 24, "leaf_value": [1, 0], "instance_count": 4}
]}
]},
{"nodeid": 10, "split_feature": 7, "split_threshold": -2.86421776, "gain": 0.126262575, "instance_count": 24, "yes": 17, "no": 18, "children": [
{"nodeid": 17, "leaf_value": [1, 0], "instance_count": 2},
{"nodeid": 18, "split_feature": 8, "split_threshold": 1.36179876, "gain": 0.0198347215, "instance_count": 22, "yes": 25, "no": 26, "children": [
{"nodeid": 25, "split_feature": 9, "split_threshold": 1.96266103, "gain": 0.142222196, "instance_count": 10, "yes": 27, "no": 28, "children": [
{"nodeid": 27, "split_feature": 5, "split_threshold": -0.427345634, "gain": 0.0308641735, "instance_count": 9, "yes": 29, "no": 30, "children": [
{"nodeid": 29, "split_feature": 8, "split_threshold": -0.295361876, "gain": 0.125, "instance_count": 4, "yes": 31, "no": 32, "children": [
{"nodeid": 31, "leaf_value": [0, 1], "instance_count": 2},
{"nodeid": 32, "split_feature": 6, "split_threshold": 1.99819326, "gain": 0.5, "instance_count": 2, "yes": 33, "no": 34, "children": [
{"nodeid": 33, "leaf_value": [1, 0], "instance_count": 1},
{"nodeid": 34, "leaf_value": [0, 1], "instance_count": 1}
]}
]},
{"nodeid": 30, "leaf_value": [0, 1], "instance_count": 5}
]},
{"nodeid": 28, "leaf_value": [1, 0], "instance_count": 1}
]},
{"nodeid": 26, "leaf_value": [0, 1], "instance_count": 12}
]}
]}
]}
]},
{"nodeid": 2, "split_feature": 3, "split_threshold": 1.46139979, "gain": 0.239644989, "instance_count": 13, "yes": 5, "no": 6, "children": [
{"nodeid": 5, "leaf_value": [1, 0], "instance_count": 9},
{"nodeid": 6, "split_feature": 7, "split_threshold": 3.8020432, "gain": 0.125, "instance_count": 4, "yes": 11, "no": 12, "children": [
{"nodeid": 11, "leaf_value": [0, 1], "instance_count": 2},
{"nodeid": 12, "split_feature": 8, "split_threshold": 0.637937546, "gain": 0.5, "instance_count": 2, "yes": 19, "no": 20, "children": [
{"nodeid": 19, "leaf_value": [0, 1], "instance_count": 1},
{"nodeid": 20, "leaf_value": [1, 0], "instance_count": 1}
]}
]}
]}
]},
{"nodeid": 0, "split_feature": 8, "split_threshold": -1.19294095, "gain": 0.111478344, "instance_count": 100, "yes": 1, "no": 2, "children": [
{"nodeid": 1, "split_feature": 7, "split_threshold": -2.3210218, "gain": 0.0867768154, "instance_count": 22, "yes": 3, "no": 4, "children": [
{"nodeid": 3, "leaf_value": [1, 0], "instance_count": 1},
{"nodeid": 4, "leaf_value": [0, 1], "instance_count": 21}
]},
{"nodeid": 2, "split_feature": 3, "split_threshold": 0.590358853, "gain": 0.180290893, "instance_count": 78, "yes": 5, "no": 6, "children": [
{"nodeid": 5, "split_feature": 6, "split_threshold": -2.1169188, "gain": 0.12661314, "instance_count": 56, "yes": 7, "no": 8, "children": [
{"nodeid": 7, "leaf_value": [0, 1], "instance_count": 5},
{"nodeid": 8, "split_feature": 5, "split_threshold": -1.94796324, "gain": 0.065519318, "instance_count": 51, "yes": 11, "no": 12, "children": [
{"nodeid": 11, "split_feature": 6, "split_threshold": 1.18254995, "gain": 0.489795923, "instance_count": 7, "yes": 15, "no": 16, "children": [
{"nodeid": 15, "leaf_value": [0, 1], "instance_count": 4},
{"nodeid": 16, "leaf_value": [1, 0], "instance_count": 3}
]},
{"nodeid": 12, "split_feature": 8, "split_threshold": 3.48108315, "gain": 0.0196772516, "instance_count": 44, "yes": 17, "no": 18, "children": [
{"nodeid": 17, "split_feature": 5, "split_threshold": 0.717789888, "gain": 0.00283446093, "instance_count": 42, "yes": 21, "no": 22, "children": [
{"nodeid": 21, "split_feature": 4, "split_threshold": 1.85632861, "gain": 1.1920929e-07, "instance_count": 30, "yes": 27, "no": 28, "children": [
{"nodeid": 27, "leaf_value": [1, 0], "instance_count": 19},
{"nodeid": 28, "leaf_value": [1, 0], "instance_count": 11}
]},
{"nodeid": 22, "split_feature": 5, "split_threshold": 0.815551639, "gain": 0.152777761, "instance_count": 12, "yes": 29, "no": 30, "children": [
{"nodeid": 29, "leaf_value": [0, 1], "instance_count": 1},
{"nodeid": 30, "leaf_value": [1, 0], "instance_count": 11}
]}
]},
{"nodeid": 18, "split_feature": 9, "split_threshold": 0.690918803, "gain": 0.5, "instance_count": 2, "yes": 23, "no": 24, "children": [
{"nodeid": 23, "leaf_value": [0, 1], "instance_count": 1},
{"nodeid": 24, "leaf_value": [1, 0], "instance_count": 1}
]}
]}
]}
]},
{"nodeid": 6, "split_feature": 6, "split_threshold": 2.1641295, "gain": 0.0710349679, "instance_count": 22, "yes": 9, "no": 10, "children": [
{"nodeid": 9, "split_feature": 7, "split_threshold": 3.8020432, "gain": 0.0818593949, "instance_count": 21, "yes": 13, "no": 14, "children": [
{"nodeid": 13, "split_feature": 9, "split_threshold": 1.33453584, "gain": 0.0200000368, "instance_count": 20, "yes": 19, "no": 20, "children": [
{"nodeid": 19, "leaf_value": [0, 1], "instance_count": 16},
{"nodeid": 20, "split_feature": 5, "split_threshold": 0.08400774, "gain": 0.375, "instance_count": 4, "yes": 25, "no": 26, "children": [
{"nodeid": 25, "leaf_value": [0, 1], "instance_count": 3},
{"nodeid": 26, "leaf_value": [1, 0], "instance_count": 1}
]}
]},
{"nodeid": 14, "leaf_value": [1, 0], "instance_count": 1}
]},
{"nodeid": 10, "leaf_value": [1, 0], "instance_count": 1}
]}
]}
]}
]
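To turn the JSON above into explicit rules, one option is to walk each tree recursively and collect the split conditions on the path to every leaf. The following is only a sketch under stated assumptions: extract_rules is a hypothetical helper (not part of cuML), the sample JSON is the small subtree with nodeid 1 from the second tree above, and it assumes the "yes" child is taken when the feature value is <= split_threshold, which you should verify against your cuML version.

```python
import json


def extract_rules(node, path=()):
    """Yield (conditions, leaf_value, instance_count) for every leaf below node.

    Hypothetical helper, not a cuML API. It relies only on the keys visible
    in the get_json() output shown above.
    """
    if "leaf_value" in node:
        yield list(path), node["leaf_value"], node["instance_count"]
        return
    # Look children up by nodeid rather than trusting their order in the list.
    children = {child["nodeid"]: child for child in node["children"]}
    feature = node["split_feature"]
    threshold = node["split_threshold"]
    # Assumption: the "yes" branch means feature value <= threshold.
    yield from extract_rules(
        children[node["yes"]], path + (f"x[{feature}] <= {threshold}",)
    )
    yield from extract_rules(
        children[node["no"]], path + (f"x[{feature}] > {threshold}",)
    )


# Subtree copied from the output above; with a fitted model you would
# pass clf.get_json() to json.loads instead.
sample_json = """
[
  {"nodeid": 1, "split_feature": 7, "split_threshold": -2.3210218,
   "gain": 0.0867768154, "instance_count": 22, "yes": 3, "no": 4,
   "children": [
     {"nodeid": 3, "leaf_value": [1, 0], "instance_count": 1},
     {"nodeid": 4, "leaf_value": [0, 1], "instance_count": 21}
   ]}
]
"""

forest = json.loads(sample_json)
rules_per_tree = [list(extract_rules(root)) for root in forest]

for tree_index, rules in enumerate(rules_per_tree):
    for conditions, leaf_value, count in rules:
        print(
            f"tree {tree_index}: IF {' AND '.join(conditions)} "
            f"THEN {leaf_value} ({count} samples)"
        )
```

For a single-tree forest (n_estimators=1, as in the question), rules_per_tree would contain one list holding every root-to-leaf rule of that tree.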