[英]R extract terminal node info from partykit decision tree
我创建了一个constparty决策树(自定义拆分规则)并打印出树结果。 结果看起来像这样:
Fitted party:
[1] root
| [2] value.a < 1651: 0.067 (n = 1419, err = 88.6)
| [3] value.a >= 1651: 0.571 (n = 7, err = 1.7)
我正在尝试提取终端节点信息(yval:0.067和0.571;每个节点上的n:1419和7;以及err:88.6和1.7),并将它们放入列表中,同时具有相应的节点ID(节点ID 2和3),以便以后可以使用这些信息。
我一直在研究partykit函数有一段时间,但找不到可以帮助我提取刚才列出的信息的函数。
有人可以帮我吗? 谢谢!
与往常一样,有几种方法可以获取您要查找的信息。 提取存储在特定node
的info
的技术方法是使用nodeapply(object, ids, info_node)
,其中info_node
返回存储在相应节点中的信息列表。
但是,在constparty
对象的终端节点中,没有存储任何内容。 取而代之的是,存储已拟合节点的响应的整个分布,并可以由fitted(object)
提取。 它包含一个数据帧,其中包含观察到的(response)
(fitted)
节点和观察到的(weights)
(如果有)。 然后,您可以轻松地使用tapply()
或aggregate()
或类似的东西来计算节点均值等。
可替代地,可以将转换constparty
对象到simpleparty
其存储在节点中的印刷信息和对象提取它。
这两种策略的一个有效示例是cars
数据的简单回归树:
library("partykit")
data("cars", package = "datasets")
ct <- ctree(dist ~ speed, data = cars)
然后,您可以通过以下方式轻松计算节点mean
with(fitted(ct), tapply(`(response)`, `(fitted)`, mean))
## 3 4 5
## 18.20000 39.75000 65.26316
当然,您可以将mean
替换为您感兴趣的任何其他摘要统计量。
可以通过以下方式获取simpleparty
的nodeapply()
:
nodeapply(as.simpleparty(ct), ids = nodeids(ct, terminal = TRUE), info_node)
## $`3`
## $`3`$prediction
## [1] 18.2
##
## $`3`$n
## n
## 15
##
## $`3`$error
## [1] 1176.4
##
## $`3`$distribution
## NULL
##
## $`3`$p.value
## NULL
##
##
## $`4`
## $`4`$prediction
## [1] 39.75
## ...
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.