
How can I get cuML RandomForestClassifier leaves?

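I'm new to cuML. I have a decision tree classifier built with scikit-learn, and I want to run a hyperparameter search on the GPU, so I started looking into cuML. cuML has no DecisionTreeClassifier. From what I've read in other SO posts, though, one can be reproduced with a RandomForestClassifier that has a single tree and no bootstrapping.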
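The single-tree setup described above can be captured as a set of constructor arguments. This is a minimal sketch. `single_tree_params` is a hypothetical helper, not part of cuML. The keyword names are, to my knowledge, existing `cuml.ensemble.RandomForestClassifier` parameters. The example also assumes that a plain decision tree should consider every feature at each split, so it sets `max_features=1.0`. cuML chooses split thresholds from quantile bins (its `n_bins` parameter), so results may still differ slightly from scikit-learn's exact splits.

```python
def single_tree_params(**overrides):
    """Keyword arguments intended to make cuML's RandomForestClassifier grow one plain tree.

    Extra hyperparameters (e.g. max_depth, n_bins) can be passed as overrides.
    """
    params = {
        "n_estimators": 1,    # a single tree instead of a forest
        "bootstrap": False,   # fit on the full training set, not a bootstrap sample
        "max_features": 1.0,  # consider all features at every split
    }
    params.update(overrides)
    return params


# Usage (requires cuML and a CUDA GPU, so it is left as a comment here):
#   from cuml.ensemble import RandomForestClassifier
#   clf = RandomForestClassifier(**single_tree_params(max_depth=8))
#   clf.fit(X, y)
```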

My question: how can I extract the tree and all of its rules (leaves and nodes) from a cuML RandomForestClassifier? Or should I look at other algorithms such as XGBoost?

Hyperparameter optimization does not require access to the underlying decision trees or their internal information.
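A search only needs to fit models and score them, so a plain grid search can treat the forest as a black box. This is a minimal sketch under that assumption. `grid_search` is a hypothetical helper (not a cuML API) that relies only on `fit` and `score`, both of which cuML's `RandomForestClassifier` provides. The data names and parameter values in the usage comment are illustrative.

```python
from itertools import product


def grid_search(make_model, param_grid, X_train, y_train, X_val, y_val):
    """Exhaustive search: fit one model per parameter combination, keep the best.

    make_model: callable taking keyword params and returning an object with
    fit(X, y) and score(X, y). No tree internals are ever inspected.
    """
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        model = make_model(**params)
        model.fit(X_train, y_train)
        score = model.score(X_val, y_val)  # higher is better (e.g. accuracy)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Usage sketch (hypothetical data names; requires cuML and a GPU):
#   from cuml.ensemble import RandomForestClassifier
#   best, acc = grid_search(
#       lambda **p: RandomForestClassifier(n_estimators=1, bootstrap=False, **p),
#       {"max_depth": [4, 8, 16], "n_bins": [32, 128]},
#       X_train, y_train, X_val, y_val,
#   )
```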

That said, you can access summary information about the underlying trees and their leaf predictions like this:

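from cuml.ensemble import RandomForestClassifier
from cuml.datasets import make_classification

N = 100  # number of samples
K = 10   # number of features

# Synthetic classification data (two classes by default)
X, y = make_classification(
    n_samples=N,
    n_features=K,
    n_informative=K,
    n_redundant=0
)

# A forest of only 2 trees keeps the printed output short
clf = RandomForestClassifier(n_estimators=2)
clf.fit(X, y)

print(clf.get_summary_text())   # per-tree depth, leaf count and fit time
print(clf.get_detailed_text())  # the summary plus every split and leaf
print(clf.get_json())           # the full tree structure as JSON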
Output:

Forest has 2 trees, max_depth 16, and max_leaves -1
Tree #0
 Decision Tree depth --> 9 and n_leaves --> 18
 Tree Fitting - Overall time --> 1.216 milliseconds

Tree #1
 Decision Tree depth --> 7 and n_leaves --> 16
 Tree Fitting - Overall time --> 1.919 milliseconds


Forest has 2 trees, max_depth 16, and max_leaves -1
Tree #0
 Decision Tree depth --> 9 and n_leaves --> 18
 Tree Fitting - Overall time --> 1.216 milliseconds

└(colid: 7, quesval: 2.73323, best_metric_val: 0.0407427)
    ├(colid: 9, quesval: -0.233239, best_metric_val: 0.116631)
    │   ├(colid: 2, quesval: -1.48028, best_metric_val: 0.045858)
    │   │   ├(colid: 8, quesval: -1.14041, best_metric_val: 0.28125)
    │   │   │   ├(leaf, prediction: [0, 1], best_metric_val: 0)
    │   │   │   └(colid: 1, quesval: 0.720062, best_metric_val: 0.375)
    │   │   │       ├(leaf, prediction: [1, 0], best_metric_val: 0)
    │   │   │       └(leaf, prediction: [0, 1], best_metric_val: 0)
    │   │   └(leaf, prediction: [0, 1], best_metric_val: 0)
    │   └(colid: 3, quesval: -1.01601, best_metric_val: 0.313368)
    │       ├(colid: 8, quesval: 1.68195, best_metric_val: 0.0131944)
    │       │   ├(leaf, prediction: [1, 0], best_metric_val: 0)
    │       │   └(colid: 6, quesval: -0.458985, best_metric_val: 0.32)
    │       │       ├(leaf, prediction: [0, 1], best_metric_val: 0)
    │       │       └(leaf, prediction: [1, 0], best_metric_val: 0)
    │       └(colid: 7, quesval: -2.86422, best_metric_val: 0.126263)
    │           ├(leaf, prediction: [1, 0], best_metric_val: 0)
    │           └(colid: 8, quesval: 1.3618, best_metric_val: 0.0198347)
    │               ├(colid: 9, quesval: 1.96266, best_metric_val: 0.142222)
    │               │   ├(colid: 5, quesval: -0.427346, best_metric_val: 0.0308642)
    │               │   │   ├(colid: 8, quesval: -0.295362, best_metric_val: 0.125)
    │               │   │   │   ├(leaf, prediction: [0, 1], best_metric_val: 0)
    │               │   │   │   └(colid: 6, quesval: 1.99819, best_metric_val: 0.5)
    │               │   │   │       ├(leaf, prediction: [1, 0], best_metric_val: 0)
    │               │   │   │       └(leaf, prediction: [0, 1], best_metric_val: 0)
    │               │   │   └(leaf, prediction: [0, 1], best_metric_val: 0)
    │               │   └(leaf, prediction: [1, 0], best_metric_val: 0)
    │               └(leaf, prediction: [0, 1], best_metric_val: 0)
    └(colid: 3, quesval: 1.4614, best_metric_val: 0.239645)
        ├(leaf, prediction: [1, 0], best_metric_val: 0)
        └(colid: 7, quesval: 3.80204, best_metric_val: 0.125)
            ├(leaf, prediction: [0, 1], best_metric_val: 0)
            └(colid: 8, quesval: 0.637938, best_metric_val: 0.5)
                ├(leaf, prediction: [0, 1], best_metric_val: 0)
                └(leaf, prediction: [1, 0], best_metric_val: 0)
Tree #1
 Decision Tree depth --> 7 and n_leaves --> 16
 Tree Fitting - Overall time --> 1.919 milliseconds

└(colid: 8, quesval: -1.19294, best_metric_val: 0.111478)
    ├(colid: 7, quesval: -2.32102, best_metric_val: 0.0867768)
    │   ├(leaf, prediction: [1, 0], best_metric_val: 0)
    │   └(leaf, prediction: [0, 1], best_metric_val: 0)
    └(colid: 3, quesval: 0.590359, best_metric_val: 0.180291)
        ├(colid: 6, quesval: -2.11692, best_metric_val: 0.126613)
        │   ├(leaf, prediction: [0, 1], best_metric_val: 0)
        │   └(colid: 5, quesval: -1.94796, best_metric_val: 0.0655193)
        │       ├(colid: 6, quesval: 1.18255, best_metric_val: 0.489796)
        │       │   ├(leaf, prediction: [0, 1], best_metric_val: 0)
        │       │   └(leaf, prediction: [1, 0], best_metric_val: 0)
        │       └(colid: 8, quesval: 3.48108, best_metric_val: 0.0196773)
        │           ├(colid: 5, quesval: 0.71779, best_metric_val: 0.00283446)
        │           │   ├(colid: 4, quesval: 1.85633, best_metric_val: 1.19209e-07)
        │           │   │   ├(leaf, prediction: [1, 0], best_metric_val: 0)
        │           │   │   └(leaf, prediction: [1, 0], best_metric_val: 0)
        │           │   └(colid: 5, quesval: 0.815552, best_metric_val: 0.152778)
        │           │       ├(leaf, prediction: [0, 1], best_metric_val: 0)
        │           │       └(leaf, prediction: [1, 0], best_metric_val: 0)
        │           └(colid: 9, quesval: 0.690919, best_metric_val: 0.5)
        │               ├(leaf, prediction: [0, 1], best_metric_val: 0)
        │               └(leaf, prediction: [1, 0], best_metric_val: 0)
        └(colid: 6, quesval: 2.16413, best_metric_val: 0.071035)
            ├(colid: 7, quesval: 3.80204, best_metric_val: 0.0818594)
            │   ├(colid: 9, quesval: 1.33454, best_metric_val: 0.02)
            │   │   ├(leaf, prediction: [0, 1], best_metric_val: 0)
            │   │   └(colid: 5, quesval: 0.0840077, best_metric_val: 0.375)
            │   │       ├(leaf, prediction: [0, 1], best_metric_val: 0)
            │   │       └(leaf, prediction: [1, 0], best_metric_val: 0)
            │   └(leaf, prediction: [1, 0], best_metric_val: 0)
            └(leaf, prediction: [1, 0], best_metric_val: 0)

[
{"nodeid": 0, "split_feature": 7, "split_threshold": 2.7332263, "gain": 0.0407427214, "instance_count": 100, "yes": 1, "no": 2, "children": [
  {"nodeid": 1, "split_feature": 9, "split_threshold": -0.233238578, "gain": 0.116630867, "instance_count": 87, "yes": 3, "no": 4, "children": [
    {"nodeid": 3, "split_feature": 2, "split_threshold": -1.48028064, "gain": 0.0458579995, "instance_count": 39, "yes": 7, "no": 8, "children": [
      {"nodeid": 7, "split_feature": 8, "split_threshold": -1.1404053, "gain": 0.28125, "instance_count": 8, "yes": 13, "no": 14, "children": [
        {"nodeid": 13, "leaf_value": [0, 1], "instance_count": 4},
        {"nodeid": 14, "split_feature": 1, "split_threshold": 0.720061541, "gain": 0.375, "instance_count": 4, "yes": 21, "no": 22, "children": [
          {"nodeid": 21, "leaf_value": [1, 0], "instance_count": 3},
          {"nodeid": 22, "leaf_value": [0, 1], "instance_count": 1}
        ]}
      ]},
      {"nodeid": 8, "leaf_value": [0, 1], "instance_count": 31}
    ]},
    {"nodeid": 4, "split_feature": 3, "split_threshold": -1.01600909, "gain": 0.313368142, "instance_count": 48, "yes": 9, "no": 10, "children": [
      {"nodeid": 9, "split_feature": 8, "split_threshold": 1.68195295, "gain": 0.0131943803, "instance_count": 24, "yes": 15, "no": 16, "children": [
        {"nodeid": 15, "leaf_value": [1, 0], "instance_count": 19},
        {"nodeid": 16, "split_feature": 6, "split_threshold": -0.458984971, "gain": 0.320000023, "instance_count": 5, "yes": 23, "no": 24, "children": [
          {"nodeid": 23, "leaf_value": [0, 1], "instance_count": 1},
          {"nodeid": 24, "leaf_value": [1, 0], "instance_count": 4}
        ]}
      ]},
      {"nodeid": 10, "split_feature": 7, "split_threshold": -2.86421776, "gain": 0.126262575, "instance_count": 24, "yes": 17, "no": 18, "children": [
        {"nodeid": 17, "leaf_value": [1, 0], "instance_count": 2},
        {"nodeid": 18, "split_feature": 8, "split_threshold": 1.36179876, "gain": 0.0198347215, "instance_count": 22, "yes": 25, "no": 26, "children": [
          {"nodeid": 25, "split_feature": 9, "split_threshold": 1.96266103, "gain": 0.142222196, "instance_count": 10, "yes": 27, "no": 28, "children": [
            {"nodeid": 27, "split_feature": 5, "split_threshold": -0.427345634, "gain": 0.0308641735, "instance_count": 9, "yes": 29, "no": 30, "children": [
              {"nodeid": 29, "split_feature": 8, "split_threshold": -0.295361876, "gain": 0.125, "instance_count": 4, "yes": 31, "no": 32, "children": [
                {"nodeid": 31, "leaf_value": [0, 1], "instance_count": 2},
                {"nodeid": 32, "split_feature": 6, "split_threshold": 1.99819326, "gain": 0.5, "instance_count": 2, "yes": 33, "no": 34, "children": [
                  {"nodeid": 33, "leaf_value": [1, 0], "instance_count": 1},
                  {"nodeid": 34, "leaf_value": [0, 1], "instance_count": 1}
                ]}
              ]},
              {"nodeid": 30, "leaf_value": [0, 1], "instance_count": 5}
            ]},
            {"nodeid": 28, "leaf_value": [1, 0], "instance_count": 1}
          ]},
          {"nodeid": 26, "leaf_value": [0, 1], "instance_count": 12}
        ]}
      ]}
    ]}
  ]},
  {"nodeid": 2, "split_feature": 3, "split_threshold": 1.46139979, "gain": 0.239644989, "instance_count": 13, "yes": 5, "no": 6, "children": [
    {"nodeid": 5, "leaf_value": [1, 0], "instance_count": 9},
    {"nodeid": 6, "split_feature": 7, "split_threshold": 3.8020432, "gain": 0.125, "instance_count": 4, "yes": 11, "no": 12, "children": [
      {"nodeid": 11, "leaf_value": [0, 1], "instance_count": 2},
      {"nodeid": 12, "split_feature": 8, "split_threshold": 0.637937546, "gain": 0.5, "instance_count": 2, "yes": 19, "no": 20, "children": [
        {"nodeid": 19, "leaf_value": [0, 1], "instance_count": 1},
        {"nodeid": 20, "leaf_value": [1, 0], "instance_count": 1}
      ]}
    ]}
  ]}
]},
{"nodeid": 0, "split_feature": 8, "split_threshold": -1.19294095, "gain": 0.111478344, "instance_count": 100, "yes": 1, "no": 2, "children": [
  {"nodeid": 1, "split_feature": 7, "split_threshold": -2.3210218, "gain": 0.0867768154, "instance_count": 22, "yes": 3, "no": 4, "children": [
    {"nodeid": 3, "leaf_value": [1, 0], "instance_count": 1},
    {"nodeid": 4, "leaf_value": [0, 1], "instance_count": 21}
  ]},
  {"nodeid": 2, "split_feature": 3, "split_threshold": 0.590358853, "gain": 0.180290893, "instance_count": 78, "yes": 5, "no": 6, "children": [
    {"nodeid": 5, "split_feature": 6, "split_threshold": -2.1169188, "gain": 0.12661314, "instance_count": 56, "yes": 7, "no": 8, "children": [
      {"nodeid": 7, "leaf_value": [0, 1], "instance_count": 5},
      {"nodeid": 8, "split_feature": 5, "split_threshold": -1.94796324, "gain": 0.065519318, "instance_count": 51, "yes": 11, "no": 12, "children": [
        {"nodeid": 11, "split_feature": 6, "split_threshold": 1.18254995, "gain": 0.489795923, "instance_count": 7, "yes": 15, "no": 16, "children": [
          {"nodeid": 15, "leaf_value": [0, 1], "instance_count": 4},
          {"nodeid": 16, "leaf_value": [1, 0], "instance_count": 3}
        ]},
        {"nodeid": 12, "split_feature": 8, "split_threshold": 3.48108315, "gain": 0.0196772516, "instance_count": 44, "yes": 17, "no": 18, "children": [
          {"nodeid": 17, "split_feature": 5, "split_threshold": 0.717789888, "gain": 0.00283446093, "instance_count": 42, "yes": 21, "no": 22, "children": [
            {"nodeid": 21, "split_feature": 4, "split_threshold": 1.85632861, "gain": 1.1920929e-07, "instance_count": 30, "yes": 27, "no": 28, "children": [
              {"nodeid": 27, "leaf_value": [1, 0], "instance_count": 19},
              {"nodeid": 28, "leaf_value": [1, 0], "instance_count": 11}
            ]},
            {"nodeid": 22, "split_feature": 5, "split_threshold": 0.815551639, "gain": 0.152777761, "instance_count": 12, "yes": 29, "no": 30, "children": [
              {"nodeid": 29, "leaf_value": [0, 1], "instance_count": 1},
              {"nodeid": 30, "leaf_value": [1, 0], "instance_count": 11}
            ]}
          ]},
          {"nodeid": 18, "split_feature": 9, "split_threshold": 0.690918803, "gain": 0.5, "instance_count": 2, "yes": 23, "no": 24, "children": [
            {"nodeid": 23, "leaf_value": [0, 1], "instance_count": 1},
            {"nodeid": 24, "leaf_value": [1, 0], "instance_count": 1}
          ]}
        ]}
      ]}
    ]},
    {"nodeid": 6, "split_feature": 6, "split_threshold": 2.1641295, "gain": 0.0710349679, "instance_count": 22, "yes": 9, "no": 10, "children": [
      {"nodeid": 9, "split_feature": 7, "split_threshold": 3.8020432, "gain": 0.0818593949, "instance_count": 21, "yes": 13, "no": 14, "children": [
        {"nodeid": 13, "split_feature": 9, "split_threshold": 1.33453584, "gain": 0.0200000368, "instance_count": 20, "yes": 19, "no": 20, "children": [
          {"nodeid": 19, "leaf_value": [0, 1], "instance_count": 16},
          {"nodeid": 20, "split_feature": 5, "split_threshold": 0.08400774, "gain": 0.375, "instance_count": 4, "yes": 25, "no": 26, "children": [
            {"nodeid": 25, "leaf_value": [0, 1], "instance_count": 3},
            {"nodeid": 26, "leaf_value": [1, 0], "instance_count": 1}
          ]}
        ]},
        {"nodeid": 14, "leaf_value": [1, 0], "instance_count": 1}
      ]},
      {"nodeid": 10, "leaf_value": [1, 0], "instance_count": 1}
    ]}
  ]}
]}
]
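To turn the `get_json()` output into explicit rules, you can parse the string with Python's `json` module and walk it from the root. This is a minimal sketch, and it rests on two assumptions. First, a split node's `yes` child is taken when `x[split_feature] <= split_threshold`; I believe this matches cuML's convention, but verify it on your version. Second, leaves are exactly the nodes that carry `leaf_value`. `extract_leaves` and `predict_one` are hypothetical helper names. The sample data is the subtree rooted at node 2 of Tree #0, copied from the output above and wrapped in a list so it has the same shape as `clf.get_json()`.

```python
import json
from typing import Any, Dict, List, Tuple

Node = Dict[str, Any]

# Subtree rooted at node 2 of Tree #0 above; replace with clf.get_json() in practice.
FOREST_JSON = """[
{"nodeid": 2, "split_feature": 3, "split_threshold": 1.46139979, "gain": 0.239644989, "instance_count": 13, "yes": 5, "no": 6, "children": [
  {"nodeid": 5, "leaf_value": [1, 0], "instance_count": 9},
  {"nodeid": 6, "split_feature": 7, "split_threshold": 3.8020432, "gain": 0.125, "instance_count": 4, "yes": 11, "no": 12, "children": [
    {"nodeid": 11, "leaf_value": [0, 1], "instance_count": 2},
    {"nodeid": 12, "split_feature": 8, "split_threshold": 0.637937546, "gain": 0.5, "instance_count": 2, "yes": 19, "no": 20, "children": [
      {"nodeid": 19, "leaf_value": [0, 1], "instance_count": 1},
      {"nodeid": 20, "leaf_value": [1, 0], "instance_count": 1}
    ]}
  ]}
]}
]"""


def _child(node: Node, key: str) -> Node:
    """Return the child whose nodeid equals node[key] (key is "yes" or "no")."""
    return next(c for c in node["children"] if c["nodeid"] == node[key])


def extract_leaves(tree: Node) -> List[Tuple[int, List[str], Any]]:
    """Return (nodeid, path conditions, leaf_value) for every leaf of one tree.

    ASSUMPTION: the "yes" branch means x[split_feature] <= split_threshold.
    """
    leaves = []
    stack = [(tree, [])]  # iterative depth-first walk: (node, conditions so far)
    while stack:
        node, conds = stack.pop()
        if "leaf_value" in node:
            leaves.append((node["nodeid"], conds, node["leaf_value"]))
            continue
        f, t = node["split_feature"], node["split_threshold"]
        # Push "no" first so that the "yes" branch is visited first.
        stack.append((_child(node, "no"), conds + [f"x[{f}] > {t}"]))
        stack.append((_child(node, "yes"), conds + [f"x[{f}] <= {t}"]))
    return leaves


def predict_one(tree: Node, x) -> Any:
    """Follow the (assumed) split rules for one sample x down to a leaf value."""
    node = tree
    while "leaf_value" not in node:
        go_yes = x[node["split_feature"]] <= node["split_threshold"]
        node = _child(node, "yes" if go_yes else "no")
    return node["leaf_value"]


forest = json.loads(FOREST_JSON)  # a list with one JSON object per tree
for i, tree in enumerate(forest):
    for nodeid, conds, value in extract_leaves(tree):
        print(f"tree {i} leaf {nodeid}: {' and '.join(conds)} -> {value}")
```

With a single-tree model (`n_estimators=1`, `bootstrap=False`), you can check the assumed split direction quickly. Compare the argmax of `predict_one` with `clf.predict` on a few rows.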


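Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original post's address. For any questions, contact: yoyou2525@163.com.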

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM