R Extracting inner node information and splits from ctree (partykit)

Hi, I'm currently trying to extract some of the inner node information stored in the constparty object that ctree in partykit returns, but I'm finding the object difficult to navigate. I can display the information in a plot, but I'm not sure how to extract it. I think it requires nodeapply or another function in partykit?

library(partykit)
irisct <- ctree(Species ~ .,data = iris)
plot(irisct, inner_panel = node_barplot(irisct))

[Image: plot with inner node details]

All the information is accessible to the plotting functions, but I'm after a text output similar to this: [Image: example output]

The main trick (as previously pointed out by @G5W) is to take the [id] subset of the party object and then extract the data (either via $data or using the data_party() function), which contains the response. I would recommend building a table of absolute frequencies first and then computing the relative and marginal frequencies from it. Using the irisct object, the plain table can be obtained by

tab <- sapply(1:length(irisct), function(id) {
  y <- data_party(irisct[id])
  y <- y[["(response)"]]
  table(y)
})
tab
##            [,1] [,2] [,3] [,4] [,5] [,6] [,7]
## setosa       50   50    0    0    0    0    0
## versicolor   50    0   50   49   45    4    1
## virginica    50    0   50    5    1    4   45
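Each column counts the observations that reach that node, so the columns are nested rather than disjoint: an inner node contains everything in its child nodes. A minimal sketch of that check in base R, assuming the counts shown above and the parent/child structure of this particular tree (1 -> 2, 3; 3 -> 4, 7; 4 -> 5, 6); the names kids and nested_ok are only illustrative:

```r
# Counts copied from the output above (rows: species, columns: node IDs 1..7).
# matrix() fills column by column, so each line below is one node.
tab <- matrix(
  c(50, 50, 50,   # node 1 (root)
    50,  0,  0,   # node 2
     0, 50, 50,   # node 3
     0, 49,  5,   # node 4
     0, 45,  1,   # node 5
     0,  4,  4,   # node 6
     0,  1, 45),  # node 7
  nrow = 3,
  dimnames = list(Species = c("setosa", "versicolor", "virginica"),
                  Node = as.character(1:7))
)

# Parent -> children, read off the tree for the iris example (an assumption
# specific to this fitted tree, not something computed here).
kids <- list("1" = c("2", "3"), "3" = c("4", "7"), "4" = c("5", "6"))

# For every inner node, the class counts should equal the sum over its children.
nested_ok <- vapply(names(kids), function(p) {
  all(tab[, p] == rowSums(tab[, kids[[p]], drop = FALSE]))
}, logical(1))
nested_ok
```

This nesting is worth keeping in mind when normalizing the table: sums taken across all nodes count the same observations more than once.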

Then we can add a little bit of formatting to a nice table object:

colnames(tab) <- 1:length(irisct)
tab <- as.table(tab)
names(dimnames(tab)) <- c("Species", "Node")

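We can then use prop.table() and margin.table() to compute the frequencies we are interested in. The as.data.frame() method transforms the table layout into a "long" data.frame. Note that margin 1 normalizes within each species (each row of tab sums to 1 across the nodes); for the class distribution within each node, use prop.table(tab, 2) instead: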

as.data.frame(prop.table(tab, 1))
##       Species Node        Freq
## 1      setosa    1 0.500000000
## 2  versicolor    1 0.251256281
## 3   virginica    1 0.322580645
## 4      setosa    2 0.500000000
## 5  versicolor    2 0.000000000
## 6   virginica    2 0.000000000
## 7      setosa    3 0.000000000
## 8  versicolor    3 0.251256281
## 9   virginica    3 0.322580645
## 10     setosa    4 0.000000000
## 11 versicolor    4 0.246231156
## 12  virginica    4 0.032258065
## 13     setosa    5 0.000000000
## 14 versicolor    5 0.226130653
## 15  virginica    5 0.006451613
## 16     setosa    6 0.000000000
## 17 versicolor    6 0.020100503
## 18  virginica    6 0.025806452
## 19     setosa    7 0.000000000
## 20 versicolor    7 0.005025126
## 21  virginica    7 0.290322581

as.data.frame(margin.table(tab, 2))
##   Node Freq
## 1    1  150
## 2    2   50
## 3    3  100
## 4    4   54
## 5    5   46
## 6    6    8
## 7    7   46
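If the goal is the class distribution inside each node (the quantity the barplots in the inner panels display), normalizing over columns is probably what is wanted. A minimal base R sketch, again assuming the counts printed above; node_prob and err are illustrative names:

```r
# Counts copied from the table above (rows: species, columns: node IDs 1..7).
tab <- as.table(matrix(
  c(50, 50, 50,  50, 0, 0,  0, 50, 50,  0, 49, 5,
    0, 45, 1,    0, 4, 4,   0, 1, 45),
  nrow = 3,
  dimnames = list(Species = c("setosa", "versicolor", "virginica"),
                  Node = as.character(1:7))
))

# Margin 2: every column (node) is scaled to sum to 1.
node_prob <- prop.table(tab, 2)

# Long format, one row per species/node combination.
node_prob_long <- as.data.frame(node_prob)

# Misclassification rate per node when predicting the majority class.
err <- 1 - apply(node_prob, 2, max)
round(100 * err, 1)
```

For the terminal nodes, these error rates can be compared against the err values that print(irisct) reports.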

And the split information can be obtained with the (still unexported) .list.rules.party() function. You just need to ask for all node IDs (the default is to use just the terminal node IDs):

partykit:::.list.rules.party(irisct, i = nodeids(irisct))
##                                                               1 
##                                                              "" 
##                                                               2 
##                                           "Petal.Length <= 1.9" 
##                                                               3 
##                                            "Petal.Length > 1.9" 
##                                                               4 
##                       "Petal.Length > 1.9 & Petal.Width <= 1.7" 
##                                                               5 
## "Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8" 
##                                                               6 
##  "Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length > 4.8" 
##                                                               7 
##                        "Petal.Length > 1.9 & Petal.Width > 1.7" 
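To line the rules up with the node sizes, the two pieces can be combined into one data frame. This is only a sketch: it relies on the same unexported function as above, whose name and arguments may change between partykit versions, and node_table is an illustrative name.

```r
library(partykit)

irisct <- ctree(Species ~ ., data = iris)
ids <- nodeids(irisct)                      # all node IDs, inner and terminal

# Unexported helper (see above); treat its interface as subject to change.
rules <- partykit:::.list.rules.party(irisct, i = ids)

# Number of observations reaching each node.
n_obs <- vapply(ids, function(id) nrow(data_party(irisct[id])), integer(1))

node_table <- data.frame(
  node     = ids,
  n        = n_obs,
  terminal = ids %in% nodeids(irisct, terminal = TRUE),
  rule     = unname(rules[as.character(ids)]),  # index by name to keep order
  stringsAsFactors = FALSE
)
node_table
```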

Most of the information that you want is accessible without much work. I will show how to get the information, but leave you to format the information into a pretty table.

Notice that your tree structure irisct can be treated like a list of nodes: length() gives the number of nodes, and irisct[id] returns the subtree rooted at node id.

length(irisct)
[1] 7

Each node has a field data that contains the points that have made it down this far in the tree, so you can get the number of observations at the node by counting the rows.

dim(irisct[4]$data)
[1] 54  5
nrow(irisct[4]$data)
[1] 54

Or do them all at once to get your table 2:

NObs <- sapply(nodeids(irisct), function(n) nrow(irisct[n]$data))
NObs
[1] 150  50 100  54  46   8  46

The first column of the data at a node is the class (Species), so you can get the count of each class and the probability of each class at a node

table(irisct[4]$data[1])
    setosa versicolor  virginica 
         0         49          5 
table(irisct[4]$data[1]) / NObs[4]
    setosa versicolor  virginica 
0.00000000 0.90740741 0.09259259 

The split information in your table 3 is a bit more awkward. Still, you can get a text version of what you need just by printing out the top level node

irisct[1]
Model formula:
Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
Fitted party:
[1] root
|   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
|   [3] Petal.Length > 1.9
|   |   [4] Petal.Width <= 1.7
|   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
|   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
|   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
Number of inner nodes:    3
Number of terminal nodes: 4

To save the output for parsing and display

TreeSplits = capture.output(print(irisct[1]))
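Since the question mentions nodeapply(), here is a hedged sketch of how the split variable and cut point of each inner node might be pulled out as data rather than text. It uses the partykit accessors split_node(), varid_split(), breaks_split() and id_node(). It assumes all splits are numeric with a single break, as in the iris tree, and that varid_split() indexes the columns of irisct$data.

```r
library(partykit)

irisct <- ctree(Species ~ ., data = iris)

ids   <- nodeids(irisct)
inner <- setdiff(ids, nodeids(irisct, terminal = TRUE))  # inner nodes only

# One small data frame per inner node; terminal nodes have no split.
split_list <- nodeapply(irisct, ids = inner, FUN = function(node) {
  sp <- split_node(node)
  data.frame(
    node     = id_node(node),
    variable = names(irisct$data)[varid_split(sp)],
    breaks   = breaks_split(sp),   # assumes exactly one numeric break
    stringsAsFactors = FALSE
  )
})

splits <- do.call(rbind, split_list)
splits
```

A tree with categorical splits would need extra handling, because breaks_split() returns NULL for those and the level assignment is stored in index_split() instead.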

