繁体   English   中英

R从partykit决策树中提取终端节点信息

[英]R extract terminal node info from partykit decision tree

我创建了一个constparty决策树(自定义拆分规则)并打印出树结果。 结果看起来像这样:

Fitted party:
[1] root
|   [2] value.a < 1651: 0.067 (n = 1419, err = 88.6)
|   [3] value.a >= 1651: 0.571 (n = 7, err = 1.7)

我正在尝试提取终端节点信息(yval:0.067和0.571;每个节点上的n:1419和7;以及err:88.6和1.7),并将它们放入列表中,同时具有相应的节点ID(节点ID 2和3),以便以后可以使用这些信息。

我一直在研究partykit函数有一段时间,但找不到可以帮助我提取刚才列出的信息的函数。

有人可以帮我吗? 谢谢!

与往常一样,有几种方法可以获取您要查找的信息。 提取存储在特定nodeinfo的技术方法是使用nodeapply(object, ids, info_node) ,其中info_node返回存储在相应节点中的信息列表。

但是,在constparty对象的终端节点中,没有存储任何内容。 取而代之的是,存储已拟合节点的响应的整个分布,并可以由fitted(object)提取。 它包含一个数据帧,其中包含观察到的(response) (fitted)节点和观察到的(weights) (如果有)。 然后,您可以轻松地使用tapply()aggregate()或类似的东西来计算节点均值等。

可替代地,可以将转换constparty对象到simpleparty其存储在节点中的印刷信息和对象提取它。

这两种策略的一个有效示例是cars数据的简单回归树:

library("partykit")
data("cars", package = "datasets")
ct <- ctree(dist ~ speed, data = cars)

然后,您可以通过以下方式轻松计算节点mean

with(fitted(ct), tapply(`(response)`, `(fitted)`, mean))
##        3        4        5 
## 18.20000 39.75000 65.26316 

当然,您可以将mean替换为您感兴趣的任何其他摘要统计量。

可以通过以下方式获取simplepartynodeapply()

nodeapply(as.simpleparty(ct), ids = nodeids(ct, terminal = TRUE), info_node)
## $`3`
## $`3`$prediction
## [1] 18.2
## 
## $`3`$n
##  n 
## 15 
## 
## $`3`$error
## [1] 1176.4
## 
## $`3`$distribution
## NULL
## 
## $`3`$p.value
## NULL
## 
## 
## $`4`
## $`4`$prediction
## [1] 39.75
## ...

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM