繁体   English   中英

在partykit包中修改ctree()中的终端节点

[英]Modifying terminal node in ctree(), partykit package

我有一个因变量,可以根据决策树进行分类。 它由三类频率组成:738(19%),426(15%)和1800(66%)。 正如您所想象的那样,预测类别始终是第三类别,但是树的用途是描述性的,因此实际上并不重要。 问题是,当通过ctree()函数(package partykit )绘制树时,终端节点显示直方图,该直方图显示了这三种类别的出现概率。 我需要修改此输出:我想获得终端节点中每个类相对于类的绝对频率的出现比例。 例如,在class1的738名参与者中,哪个百分比属于某个终端节点? 每个终端节点将针对组成因变量的所有三个类显示此值。

下面是树的图,默认情况下,该图报告终端节点内每个类的流行程度。

您始终可以定义自己的面板功能以绘制进入每个终端面板窗口的内容。 如果您对grid图形有所了解,并了解如何定义当前的终端面板功能,您将了解其工作原理。

partykit程序包中的node_terminal()是应该执行所需操作的一个面板函数(对旧party程序包的重新实现有了很大的改进)。 但是,由于ctree()不会将其预测存储在每个终端节点中,因此node_terminal()函数目前无法立即执行此操作。 我将尝试在将来的版本中改进实现,以便于实现。 我希望下面是一个可以完成您想要的事的示例。

首先,我们使用iris数据拟合分类树(作为一个简单的可重现示例):

library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
## 
## Fitted party:
## [1] root
## |   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## |   [3] Petal.Length > 1.9
## |   |   [4] Petal.Width <= 1.7
## |   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## |   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## |   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
## 
## Number of inner nodes:    3
## Number of terminal nodes: 4

然后,我们为每个终端节点计算一个预测概率表:

(pred <- aggregate(predict(ct, type = "prob"),
  list(predict(ct, type = "node")), FUN = mean))
##   Group.1 setosa versicolor  virginica
## 1       2      1 0.00000000 0.00000000
## 2       5      0 0.97826087 0.02173913
## 3       6      0 0.50000000 0.50000000
## 4       7      0 0.02173913 0.97826087

接下来是不太明显的部分:我们希望将这些预测的概率包括在树本身的终端节点中。 为此,我们将递归节点结构强制为一个平面列表,插入预测(适当格式化),然后将列表转换回节点结构:

ct_node <- as.list(ct$node)
for(i in 1:nrow(pred)) {
  ct_node[[pred[i,1]]]$info$prediction <- paste(
    format(names(pred)[-1]),
    format(round(pred[i, -1], digits = 3), nsmall = 3)
  )
}
ct$node <- as.partynode(ct_node)

然后,我们可以使用node_terminal面板函数轻松绘制树的图片,并插入预先格式化的预测:

plot(ct, terminal_panel = node_terminal, tp_args = list(
  FUN = function(node) c("Predictions", node$prediction)))

定制树

编辑: list和参与party之间的来回强制实际上已经在软件包中实现了...我只是忘了它;-)如果您这样做

st <- as.simpleparty(ct)

然后,结果party在每个节点中具有有关预测等的更多详细信息。例如, $distribution然后包含每个响应级别的绝对频率。 可以像以前一样轻松格式化

pred <- function(i) {
  tab <- i$distribution
  tab <- round(prop.table(tab), 3)
  tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
  c("Predictions", tab)
}

可以将其传递给node_terminal以实质上创建上面的图。 如果希望所有终端节点都显示在底行中,则可能需要将drop = FALSE更改为drop = TRUE

plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM