
Calculate number of observations in each node in a decision tree in R?

Similar questions have been asked, for example here and here, but none of the answers apply to my issue. I'm trying to determine and count which observations fall into each node of a decision tree. However, the tree structure comes from a data frame of trees that I'm building myself: I extract the tree information from the BART package and turn it into a data frame that resembles the one shown below (i.e., df ), and I need to work with that data frame structure. Aside: I believe the way the trees are ordered in my data frame is called 'depth first'.

For example, my data frame of trees looks like this:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

Visually, these trees would look like:

[Image: decision trees]

The trees are drawn left-first when traversing down df . Additionally, all splits are binary, so each split node has exactly two children.
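To make the layout concrete: under these two assumptions (left-first order, binary splits), the rows holding a node's children can be worked out from subtree sizes. Below is a minimal base-R sketch; child_rows is a hypothetical helper introduced only for illustration, and it assumes the tree is well formed.

```r
# Hypothetical helper: for one tree stored in depth-first (left-first) order,
# return the row index of the left and right child of every row.
# `is_split` is TRUE for split rows and FALSE for terminal rows.
child_rows <- function(is_split) {
  n <- length(is_split)
  size <- integer(n)            # number of rows in the subtree rooted at each row
  left <- rep(NA_integer_, n)
  right <- rep(NA_integer_, n)
  # Walk backwards so that both children are sized before their parent.
  for (i in rev(seq_len(n))) {
    if (is_split[i]) {
      left[i] <- i + 1L                     # left child directly follows its parent
      right[i] <- left[i] + size[left[i]]   # right child follows the whole left subtree
      size[i] <- 1L + size[left[i]] + size[right[i]]
    } else {
      size[i] <- 1L
    }
  }
  data.frame(row = seq_len(n), left = left, right = right, size = size)
}

# Tree 3 from the question: x5, x4, NA, NA, x3, NA, NA
tree3 <- c("x5", "x4", NA, NA, "x3", NA, NA)
res <- child_rows(!is.na(tree3))
res
```

For tree 3 this places the children of row 1 in rows 2 and 5, which are the rows hardcoded in the attempt further down.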

So, if we create some data that looks like this:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)

I'm trying to find which observations of dat fall into which node.

Attempt at an answer: This isn't really helpful, but for clarity (as I am still trying to solve this), hardcoding it for tree number three would look like this:

lists <- df %>% group_by(treeNo) %>% group_split()
tree<- lists[[3]]

 namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
    dataLeft <- dat[dat[, namesDf] <= tree[1,]$splitValue, ]
    dataRight <- dat[dat[, namesDf] > tree[1,]$splitValue, ]
    
    namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
    dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2,]$splitValue, ]
    dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2,]$splitValue, ]
    
    namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
    dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5,]$splitValue, ]
    dataRight2 <- dataRight[dataRight[, namesDf] > tree[5,]$splitValue, ]

I have been trying to turn this into a loop, but it's proving challenging to work out, and I (obviously) can't hardcode it for every tree. Any suggestions as to how I could solve this?

I was thinking about doing something like this:

lists <- df %>% group_by(treeNo) %>% group_split()

sideFunction <- function(df, data, index){
  if (!is.na(data[index, ]$variableName)){
    namesDf <- names(df[grepl(data[index, ]$variableName, names(df))])
    #get left side
    dataLeft <- df[df[, namesDf] <= data[index,]$splitValue, ]
    leftValues <- dataLeft[, namesDf]
    #get right side
    dataRight <- df[df[, namesDf] > data[index,]$splitValue, ]
    rightValues <- dataRight[, namesDf] 
  }else{
    leftValues <- "No values"
    rightValues <- "No values"
  }
  return(list(
    dataLeft = dataLeft, 
    leftValues = leftValues,
    dataRight = dataRight, 
    rightValues = rightValues
  ))
}


# checking for one list
data <- lists[[1]]
numbRow <- nrow(data)

dataLeft <- dat
dataRight <- dat

leftValues <- list()
rightValues <- list()

for (i in 1:numbRow) {
  Left <- sideFunction(df = dataLeft, data = data, index = i)
  Right <- sideFunction(df = dataRight, data = data,index = i)
  dataLeft <- Left$dataLeft
  dataRight <- Right$dataRight
  leftValues[[i]] <- list(Left$leftValues, Left$rightValues)
  rightValues[[i]] <- list(Right$leftValues, Right$rightValues)
  #print(Left)
  #print(Right)
}

leftValues
rightValues

This gives the values for one tree, but the output is not organized very well and includes unnecessary information. Additionally, it doesn't return the indices of the observations, but it's a start.

Here is a vectorized solution.

Solution

First load the dplyr and runner packages, and generate your df :

library(dplyr)
library(runner)

# ...
# Code to generate 'df'.
# ...

Then take an intermediate step that provides rich information about each node...

df_info <- df %>%
  
  # Work separately within each tree.
  group_by(treeNo) %>%
  
  mutate(
    # Uniquely identify each node.  
    node_id = row_number(),
         
    # Determine if each node is a terminus.
    is_terminus = is.na(variableName),
    
    # Track how the level shifts with each new node:
    termini_streak = streak_run(cumsum(!is_terminus)) - 1,
    level_shift = case_when(
      # Encountering a new node: the next node is its child; down one level.
      termini_streak  < 1 ~        1,
      
      # Encountering yet another terminus: the next node has a different parent; up one
      # level.
      termini_streak  > 1 ~       -1,
      
      # Encountering the first terminus: the next node has the same parent; same level.
      TRUE                ~        0
    ),
    
    # Determine the level of each node from the tracking; starting at 0 for the root.
    level_id = cumsum(lag(level_shift, default = 0)),
    level_id = as.integer(level_id)
  ) %>%
  
  mutate(
    # Find the parent of each node: the nearest preceding node that is one level up.
    parent_index = runner(level_id, function(x) {
      length(x) - match(last(x) - 1, rev(x)) + 1
    }),
    
    # Identify that parent by its ID.
    parent_id = node_id[parent_index]
  ) %>%
  
  # Discard the helper columns.
  select(!c(is_terminus, termini_streak, level_shift, parent_index))

...as illustrated below. Here node_id uniquely identifies a node within its tree, level_id indicates how many levels (i.e. "generations") a node is descended from the root ( 0 ), and parent_id gives the node_id of the node's parent.

#> df_info

# A tibble: 15 x 6
# Groups:   treeNo [3]
   variableName splitValue treeNo node_id level_id parent_id
   <chr>             <dbl>  <dbl>   <int>    <int>     <int>
 1 x2                0.542      1       1        0        NA
 2 x1                0.126      1       2        1         1
 3 NA               NA          1       3        2         2
 4 NA               NA          1       4        2         2
 5 NA               NA          1       5        1         1
 6 x2                0.655      2       1        0        NA
 7 NA               NA          2       2        1         1
 8 NA               NA          2       3        1         1
 9 x5                0.418      3       1        0        NA
10 x4                0.234      3       2        1         1
11 NA               NA          3       3        2         2
12 NA               NA          3       4        2         2
13 x3                0.747      3       5        1         1
14 NA               NA          3       6        2         5
15 NA               NA          3       7        2         5
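If you prefer not to depend on runner, the same level and parent information can also be derived with an explicit stack of open split nodes. This is a minimal base-R sketch and not part of the workflow above; tree_levels is a hypothetical helper that handles one tree at a time and assumes the depth-first, left-first, binary layout from the question.

```r
# Hypothetical helper: derive level and parent for one tree given in
# depth-first (left-first) order, using a stack of split nodes that are
# still waiting for children.
tree_levels <- function(variableName) {
  n <- length(variableName)
  level_id <- integer(n)
  parent_id <- rep(NA_integer_, n)
  open <- integer(0)       # stack of split nodes that are ancestors of the current row
  remaining <- integer(0)  # how many children each stacked node still expects
  for (i in seq_len(n)) {
    k <- length(open)
    if (k > 0L) {
      parent_id[i] <- open[k]   # the top of the stack is the parent
      level_id[i] <- k          # stack depth equals the level
      remaining[k] <- remaining[k] - 1L
    }
    if (!is.na(variableName[i])) {
      # A split node opens a new level and expects two children.
      open <- c(open, i)
      remaining <- c(remaining, 2L)
    } else {
      # A terminal node may complete one or more ancestors: pop all of them.
      while (length(open) > 0L && remaining[length(open)] == 0L) {
        open <- open[-length(open)]
        remaining <- remaining[-length(remaining)]
      }
    }
  }
  data.frame(node_id = seq_len(n), level_id, parent_id)
}

# Tree 3 from the question.
res <- tree_levels(c("x5", "x4", NA, NA, "x3", NA, NA))
res
```

For tree 3 this reproduces the level_id and parent_id columns in rows 9 to 15 of the table above. Because every completed ancestor is popped, a terminal node that closes several levels at once is handled as well.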

Next, put that df_info through this workflow to link each node to its children...

df_linked <- df_info %>%
  
  # Link each node to its children.
  left_join(
    df_info,
    by = c(
      "treeNo",
      node_id = "parent_id"
    ),
    suffix = c(".parent", ".child")
  ) %>%
  
  # Keep only the desired columns.
  transmute(
    treeNo,
    variableName = variableName.parent,
    splitValue = splitValue.parent,
    
    node_id = node_id,
    level_id = level_id.parent,
    
    child_id = node_id.child,
    child_var = variableName.child
    
    # , child_level = level_id.child
    # , child_val = splitValue.child
  ) %>%
  
  # OPTIONALLY ignore unnamed nodes.
  filter(!is.na(variableName))

...as illustrated below:

#> df_linked

# A tibble: 12 x 7
# Groups:   treeNo [3]
   treeNo variableName splitValue node_id level_id child_id child_var
    <dbl> <chr>             <dbl>   <int>    <int>    <int> <chr>    
 1      1 x2                0.542       1        0        2 x1       
 2      1 x2                0.542       1        0        5 NA       
 3      1 x1                0.126       2        1        3 NA       
 4      1 x1                0.126       2        1        4 NA       
 5      2 x2                0.655       1        0        2 NA       
 6      2 x2                0.655       1        0        3 NA       
 7      3 x5                0.418       1        0        2 x4       
 8      3 x5                0.418       1        0        5 x3       
 9      3 x4                0.234       2        1        3 NA       
10      3 x4                0.234       2        1        4 NA       
11      3 x3                0.747       5        1        6 NA       
12      3 x3                0.747       5        1        7 NA       

Finally, you can summarize the number of children for each node:

df_summary <- df_linked %>%

  # Summarize each node in each tree:
  group_by(node_id, .add = TRUE) %>%
  summarize(
    # Count the number of children (not missing)...
    total_children = sum(!is.na(child_id)),
    # ..and the number of named children.
    named_children = sum(!is.na(child_var)),
    
    
    # Preserve the other useful info.
    variableName = first(variableName),
    splitValue = first(splitValue)
    # , level_id = first(level_id)
  ) %>%
  
  # Reformat the dataset as before.
  select(
    variableName, splitValue, treeNo, named_children, total_children
    # , node_id, level_id
  )

Results

Given a df like the one reproduced here

df <- structure(
  list(
    variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
    splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
    treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3)
  ),
  class = c("tbl_df", "tbl", "data.frame"),
  row.names = c(NA, -15L)
)

this solution should yield the following result for df_summary :

#> df_summary

# A tibble: 6 x 5
# Groups:   treeNo [3]
  variableName splitValue treeNo named_children total_children
  <chr>             <dbl>  <dbl>          <int>          <int>
1 x2                0.542      1              1              2
2 x1                0.126      1              0              2
3 x2                0.655      2              0              2
4 x5                0.418      3              2              2
5 x4                0.234      3              0              2
6 x3                0.747      3              0              2

Warning

If you allow the format of df to deviate even slightly from your convention, or if you perform something other than a depth-first search (from the left) on something other than a binary tree, then this workflow will fail.
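Since everything depends on that layout, it may be worth checking it before running the workflow. Here is a minimal base-R sketch of such a check; is_valid_preorder is a hypothetical helper, applied to the variableName column of one tree at a time.

```r
# Hypothetical check: does `variableName` (one tree, depth-first order)
# describe a binary tree in which every split has exactly two children?
is_valid_preorder <- function(variableName) {
  n <- length(variableName)
  if (n == 0L) return(FALSE)
  is_split <- !is.na(variableName)
  # Nodes still expected after reading each row: start with 1 (the root);
  # every row fills one slot, and every split opens two new slots.
  open_slots <- 1L + cumsum(ifelse(is_split, 1L, -1L))
  # Slots must stay open until the last row and be used up exactly there.
  all(open_slots[-n] > 0L) && open_slots[n] == 0L
}

ok_tree1 <- is_valid_preorder(c("x2", "x1", NA, NA, NA))  # tree 1 from the question
ok_broken <- is_valid_preorder(c("x2", NA))               # a split with only one child
c(ok_tree1, ok_broken)
```

With dplyr this could be run per tree, for example inside summarize() after group_by(treeNo) .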

Note

If you choose to keep unnamed nodes in df_linked , rather than filter() them out, then you should include node_id in the select() for df_summary , so as to distinguish unnamed nodes from each other.

There is still much room for optimization; however, given your data:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

and

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)
dat
##>           x1        x2        x3        x4        x5
##>1  0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
##>2  0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
##>3  0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
##>4  0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
##>5  0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
##>6  0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
##>7  0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
##>8  0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
##>9  0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
##>10 0.17026205 0.6902905 0.2777238 0.1302889 0.3070859

Here's a function that, for a single tree, returns a function that maps a row of values to a node:

makeTree <- function(dat, r = 1) {
  stopifnot(r <= nrow(dat))
  vname <- pull(dat,variableName)[r]
  splitVal <- pull(dat, splitValue)[r]
  if (is.na(vname)) {
    ## terminal node
    ## print(sprintf("terminal node: %i", r))
    res <- list(size = 1, # offset to access right node
                fn = function(z) {
                  pull(dat, "id")[r]
                })
    return(res)
  } else {
    ##print(sprintf("node: %i, varName: %s, splitVal: %f", r, vname, splitVal ))
    ## calculate the left and right functions
    fnleft <- makeTree(dat, r + 1) # the left child is always positioned
                                   # directly after the caller
    fnright <- makeTree(dat, r + fnleft$size + 1 )
    return(list(size = fnleft$size + fnright$size + 1,
                fn = function(z) {
                  if (z[vname] <= splitVal)
                    fnleft$fn(z)
                  else
                    fnright$fn(z)
                }))
  }
}

Now this function is applied to each tree to produce a list of matching functions.

treefns <- df |>
  mutate(id = row_number()) %>%
  group_by(treeNo) |>
  group_split()    |>
  purrr::map(makeTree) |>
  purrr::map("fn")

Finally, each row of your dataframe is matched to a node of the tree.

apply(dat,1, function(z) sapply(treefns, function(fn) fn(z))) |>
  t() |>
  data.frame() |>
  rename_with(function(z) paste0("TREE", gsub("X", "", z))) |>
  cbind(dat) |>
  tidyr::pivot_longer(cols = starts_with("TREE"),  # pivot_longer() lives in tidyr, not dplyr
                      names_to = "TREE",
                      values_to = "NODE")  |>
  sample_n(10)

##> A tibble: 10 x 7
##>       x1    x2    x3    x4    x5 TREE   NODE
##>    <dbl> <dbl> <dbl> <dbl> <dbl> <chr> <int>
##> 1 0.170  0.690 0.278 0.130 0.307 TREE3    11
##> 2 0.170  0.690 0.278 0.130 0.307 TREE2     8
##> 3 0.370  0.358 0.882 0.629 0.884 TREE2     7
##> 4 0.308  0.625 0.536 0.488 0.331 TREE1     5
##> 5 0.370  0.358 0.882 0.629 0.884 TREE1     4
##> 6 0.552  0.280 0.538 0.349 0.778 TREE3    14
##> 7 0.547  0.359 0.549 0.990 0.208 TREE1     4
##> 8 0.370  0.358 0.882 0.629 0.884 TREE3    15
##> 9 0.547  0.359 0.549 0.990 0.208 TREE2     7
##>10 0.0564 0.398 0.749 0.954 0.827 TREE2     7
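To get from the node assignment to the counts asked for in the question, one option is to tabulate the terminal node per tree. The following is a minimal base-R sketch rather than a continuation of the code above; leaf_of is a hypothetical helper, and it numbers nodes by their row within each tree, not by the global id used above.

```r
# Trees and data as defined in the question (base R only).
df <- data.frame(
  variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
  splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
  treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3),
  stringsAsFactors = FALSE
)
set.seed(100)
dat <- data.frame(x1 = runif(10), x2 = runif(10), x3 = runif(10),
                  x4 = runif(10), x5 = runif(10))

# Hypothetical helper: for every row of `obs`, return the row of `tree`
# (1-based, within that tree) of the terminal node the observation lands in.
leaf_of <- function(tree, obs) {
  descend <- function(r, rows) {
    # r: current tree row; rows: indices of the observations reaching that row.
    if (is.na(tree$variableName[r])) {
      return(list(size = 1L, leaf = setNames(rep(r, length(rows)), rows)))
    }
    go_left <- obs[rows, tree$variableName[r]] <= tree$splitValue[r]
    left <- descend(r + 1L, rows[go_left])
    # The right child starts after the whole left subtree.
    right <- descend(r + 1L + left$size, rows[!go_left])
    list(size = 1L + left$size + right$size, leaf = c(left$leaf, right$leaf))
  }
  leaf <- descend(1L, seq_len(nrow(obs)))$leaf
  unname(leaf[as.character(seq_len(nrow(obs)))])   # back to the order of `obs`
}

# Observations per terminal node, one table per tree.
counts <- lapply(split(df, df$treeNo), function(tree) table(leaf_of(tree, dat)))
counts
```

For tree 3 this gives 1, 2, 4 and 3 observations in rows 3, 4, 6 and 7 of that tree. Terminal nodes that receive no observations do not appear in the table.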

It is possible to identify each node by a set of filters. We can use a stack to keep track of all the states. The logic is as follows.

  1. Initialize a stack stk to keep track of the nodes, a stack counter cnt to keep track of branches, and a position index pos to record the current position in stk and cnt .
  2. For each value in variableName , output the variables and split values that are currently in stk . If the cnt of a node on the stack is > 0 , we are on its right branch and thus use > for the split; otherwise use <= .
  3. If we observe an NA , increment the current cnt by 1 to show that we have just finished a branch; otherwise, push one more node onto stk to show that we are moving one level deeper down the branch.
  4. If the current cnt > 1 , then we have fully traversed both branches of a node. Thus, we pop that node from stk , reset its cnt , move back up one level, and then increment that level's cnt by 1. Repeat step 4 as needed, since finishing one branch may finish one or more enclosing branches as well.

Here is the code. Note that this kind of state-dependent computation is hard to vectorize. It's thus not what R is good at. If you have a lot of trees and the code performance becomes a serious concern, I'd suggest rewriting the code below using Rcpp .

read_node <- function(x, v) {
  stk <- integer(length(x))
  cnt <- integer(length(x))
  pos <- 1L
  out <- character(length(x))
  for (i in seq_along(x)) {
    out[[i]] <- paste0(
      x[stk], c("<=", ">")[(cnt[stk != 0L] > 0L) + 1L], v[stk], 
      collapse = "&"
    )
    if (!is.na(x[[i]]))
      stk[[pos <- pos + 1L]] <- i
    else
      cnt[[pos]] <- cnt[[pos]] + 1L
    while (cnt[[pos]] > 1L) {
      stk[[pos]] <- 0L
      cnt[[pos]] <- 0L
      pos <- pos - 1L
      cnt[[pos]] <- cnt[[pos]] + 1L
    }
  }
  out
}

Then you can apply the filters to your dat like this.

library(dplyr)

df %>% 
  group_by(treeNo) %>% 
  mutate(
    node = read_node(variableName, splitValue), 
    filter = if_else(node == "", "dat", sprintf("dat[%s, ]", node)), 
    obs = eval(parse(text = sprintf("list(%s)", paste0(filter, collapse = ","))), dat)
  )

Output

# A tibble: 15 x 6
# Groups:   treeNo [3]
   variableName splitValue treeNo node                  filter                     obs          
   <chr>             <dbl>  <dbl> <chr>                 <chr>                      <list>       
 1 x2                0.542      1 ""                    dat                        <df [10 x 5]>
 2 x1                0.126      1 "x2<=0.542"           dat[x2<=0.542, ]           <df [5 x 5]> 
 3 NA               NA          1 "x2<=0.542&x1<=0.126" dat[x2<=0.542&x1<=0.126, ] <df [1 x 5]> 
 4 NA               NA          1 "x2<=0.542&x1>0.126"  dat[x2<=0.542&x1>0.126, ]  <df [4 x 5]> 
 5 NA               NA          1 "x2>0.542"            dat[x2>0.542, ]            <df [5 x 5]> 
 6 x2                0.655      2 ""                    dat                        <df [10 x 5]>
 7 NA               NA          2 "x2<=0.6547"          dat[x2<=0.6547, ]          <df [6 x 5]> 
 8 NA               NA          2 "x2>0.6547"           dat[x2>0.6547, ]           <df [4 x 5]> 
 9 x5                0.418      3 ""                    dat                        <df [10 x 5]>
10 x4                0.234      3 "x5<=0.418"           dat[x5<=0.418, ]           <df [3 x 5]> 
11 NA               NA          3 "x5<=0.418&x4<=0.234" dat[x5<=0.418&x4<=0.234, ] <df [1 x 5]> 
12 NA               NA          3 "x5<=0.418&x4>0.234"  dat[x5<=0.418&x4>0.234, ]  <df [2 x 5]> 
13 x3                0.747      3 "x5>0.418"            dat[x5>0.418, ]            <df [7 x 5]> 
14 NA               NA          3 "x5>0.418&x3<=0.747"  dat[x5>0.418&x3<=0.747, ]  <df [4 x 5]> 
15 NA               NA          3 "x5>0.418&x3>0.747"   dat[x5>0.418&x3>0.747, ]   <df [3 x 5]> 
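If you would rather avoid eval(parse()) and want the row indices directly, the same stack idea can carry the observation indices instead of filter strings. This is a minimal base-R sketch under the same layout assumptions; node_members is a hypothetical helper that handles one tree at a time.

```r
# Trees and data as defined in the question (base R only).
df <- data.frame(
  variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
  splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
  treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3),
  stringsAsFactors = FALSE
)
set.seed(100)
dat <- data.frame(x1 = runif(10), x2 = runif(10), x3 = runif(10),
                  x4 = runif(10), x5 = runif(10))

# Hypothetical helper: for one tree in depth-first (left-first) order, return a
# list holding, for each node, the indices of the rows of `obs` that reach it.
node_members <- function(variableName, splitValue, obs) {
  n <- length(variableName)
  members <- vector("list", n)
  pending <- list(seq_len(nrow(obs)))   # stack of observation sets awaiting a node
  for (i in seq_len(n)) {
    # The next node in depth-first order takes the set on top of the stack.
    rows <- pending[[length(pending)]]
    pending[[length(pending)]] <- NULL
    members[[i]] <- rows
    if (!is.na(variableName[i])) {
      go_left <- obs[rows, variableName[i]] <= splitValue[i]
      # Push the right set first so that the left child is served next.
      pending <- c(pending, list(rows[!go_left]), list(rows[go_left]))
    }
  }
  members
}

tree3 <- df[df$treeNo == 3, ]
m <- node_members(tree3$variableName, tree3$splitValue, dat)
lengths(m)   # number of observations in each node of tree 3
m[[4]]       # row indices of dat that fall into node 4 of tree 3
```

For tree 3 the node sizes agree with the obs column in rows 9 to 15 of the output above.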
