
Extract variable labels from rpart decision tree

I've used rpart to build a decision tree on a dataset that has categorical variables with hundreds of levels. The tree splits these variables on selected values of the variable. I would like to examine the labels on which each split is made.

If I just print the decision tree result, the listing of splits in the console gets truncated and, either way, it is not in an easily interpretable format (the levels are separated by commas). Is there a way to access this as an R object? I'm open to using another package to build the tree.

One issue here is that some of the functions in the rpart package are not exported. It appears you're looking to capture the output of the function rpart:::print.rpart. So, beginning with a reproducible example:

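library(rpart)  # provides rpart()

set.seed(1)
# Simulated data: binary response y with two binary predictors
df1 <- data.frame(y=rbinom(n=100, size=1, prob=0.5),
                  x1=rbinom(n=100, size=1, prob=0.25),
                  x2=rbinom(n=100, size=1, prob=0.75))
# The outer parentheses print the fitted tree
(r1 <- rpart(y ~ ., data=df1))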

giving

n= 100 

node), split, n, deviance, yval
      * denotes terminal node

1) root 100 24.960000 0.4800000  
  2) x1< 0.5 78 19.179490 0.4358974  
    4) x2>=0.5 66 15.954550 0.4090909 *
    5) x2< 0.5 12  2.916667 0.5833333 *
  3) x1>=0.5 22  5.090909 0.6363636  
    6) x2< 0.5 7  1.714286 0.4285714 *
    7) x2>=0.5 15  2.933333 0.7333333 *

Now, looking at rpart:::print.rpart, we see a call to rpart:::labels.rpart, giving us the splits (or names of the 'rows' in the output above). The values of n, deviance, yval and more are stored in r1$frame, which can be seen by inspecting the output from unclass(r1).

Thus we could extract the above with

(df2 <- data.frame(split=rpart:::labels.rpart(r1), n=r1$frame$n))

giving

    split   n
1    root 100
2 x1< 0.5  78
3 x2>=0.5  66
4 x2< 0.5  12
5 x1>=0.5  22
6 x2< 0.5   7
7 x2>=0.5  15
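
The example above uses numeric predictors, while the question concerns factors with many levels. The following is a minimal sketch of the same idea for a factor split, not a definitive implementation. It assumes a hypothetical toy data frame df with a six-level factor g, and it relies on the minlength and collapse arguments of labels.rpart: minlength = 0 should keep the full level names instead of abbreviating them, and collapse = FALSE should return one column of labels for the left branch and one for the right. Splitting the strings on commas assumes that no level name itself contains a comma.

```r
library(rpart)

# Hypothetical toy data: the factor g separates y perfectly into two groups
df <- data.frame(
  g = factor(rep(c("alpha", "beta", "gamma", "delta", "epsilon", "zeta"),
                 each = 20)),
  y = rep(c(0, 0, 0, 10, 10, 10), each = 20)
)
fit <- rpart(y ~ g, data = df)

# One row per node in fit$frame; column 1 = left branch, column 2 = right.
# minlength = 0 keeps full level names, collapse = FALSE keeps sides apart.
lab <- labels(fit, minlength = 0, collapse = FALSE)

# Keep only the internal nodes (rows that actually split)
is_split <- fit$frame$var != "<leaf>"
split_levels <- data.frame(
  node     = as.integer(rownames(fit$frame))[is_split],
  variable = as.character(fit$frame$var[is_split]),
  left     = lab[is_split, 1],
  right    = lab[is_split, 2],
  stringsAsFactors = FALSE
)

# Turn the comma-separated strings into character vectors of levels
# (assumes no level name contains a comma)
left_levels  <- strsplit(split_levels$left,  ",", fixed = TRUE)
right_levels <- strsplit(split_levels$right, ",", fixed = TRUE)

split_levels
```

Here left_levels and right_levels are lists with one character vector per splitting node, so they can be searched or joined against other tables without going through the truncated console output. For numeric splits the same columns hold strings such as "< 0.5" rather than level names.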
