繁体   English   中英

R partykit计算子街道上的分类概率

[英]R partykit calculate classification probabilities on sub stree

我已经训练了partykit包ctree分类决策树,并且我需要计算子树(不仅是叶节点)的分类概率。 因此,例如,如果一个子树由具有以下概率的3个叶节点组成:叶1(120个观测值):0.45叶2(160个观测值):0.49叶3(190个观测值):0.83

对于此假设子树,加权平均概率将为120 * 0.42 + 160 * 0.49 + 190 * 0.83 /(120 + 160 + 190)= 0.507

依此类推,我需要遍历ctree对象并递归计算每个节点的所有加权概率。

我有以下代码:

data(airquality)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
                 controls = ctree_control(maxsurrogate = 3))
traverse <- function(treenode){
    if(treenode$terminal){
      bas=paste("Current node is terminal node with",treenode$nodeID,'prediction',treenode$prediction)
      print(bas)
      return(0)
    } else {
      bas=paste("Current node",treenode$nodeID,"Split var. ID:",treenode$psplit$variableName,"split value:",treenode$psplit$splitpoint,'prediction',treenode$prediction)
      print(bas)
    }
    traverse(treenode$left)
    traverse(treenode$right)
  }

在tree上遍历的对象不适用于partykit对象。 另一方面,我有这段代码,其中仅列出了叶节点的所有可能性:

preds.ls <- list(predict(airct , type = "prob"))[1]
pred.probs.df <- unique(as.data.frame((preds.ls[[1]])))

任何建议将这两个代码段合并到将在PARTYKIT对象上遍历并计算该加权平均值的代码的建议

我对partykit并不熟悉,但是这个简单的函数走了一个ctree并提取了每个内部和终端节点的概率:

   library(party)

    set.seed(100)
    dt <- ctree(factor(mpg > 20)~., data = mtcars,
                control = ctree_control(minsplit=2, minbucket=1, mincriterion=0))

    traverse <- function(node) {
      if (node$terminal) {
        return(node$prediction[2])
      }
      return(c(node$prediction[2],
               traverse(node$left), traverse(node$right)))
    }

在此处输入图片说明

调用该函数将产生以下概率向量:

> traverse(dt@tree)
[1] 0.4375000 1.0000000 0.1428571 0.4285714 0.7500000 0.0000000 0.0000000

最左边的值是通过以下方式验证的人口值:

> mean(mtcars$mpg > 20)
[1] 0.4375

其余的值将按从左到右的顺序排列。 您会看到1和0排列在预期的位置。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM