Similar questions have been asked, for example here and here, but none of them can be applied to my issue. I'm trying to determine, and count, which observations are in each node of a decision tree. However, the tree structure comes from a data frame of trees that I create myself: I extract tree information from the BART package and turn it into a data frame that resembles the one shown below (i.e., df). I need to work with the data frame structure provided. Aside: I believe the way the trees are ordered in my data frame is called 'depth first'.
For example, my data frame of trees looks like this:
library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))
Visually, these trees would look like: [figure not available]
The trees are drawn left-first when traversing down df. Additionally, all splits are binary, so each internal node has exactly 2 children.
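One way to see the depth-first, left-first layout as text is to indent each row by its depth. The sketch below is only an illustration: `tree_lines()` is a hypothetical helper (not from BART or any package), and it assumes one tree at a time, binary splits, and `NA` marking a terminal node, as described above.

```r
# Hypothetical helper: render one depth-first tree as indented text lines.
tree_lines <- function(variableName, splitValue) {
  pending <- 0L  # stack of depths for nodes that have not been read yet
  out <- character(length(variableName))
  for (i in seq_along(variableName)) {
    # Pop the depth reserved for the current row.
    d <- pending[length(pending)]
    pending <- pending[-length(pending)]
    if (is.na(variableName[i])) {
      label <- "leaf"
    } else {
      label <- paste0(variableName[i], " <= ", splitValue[i])
      # A split reserves two child slots one level deeper.
      pending <- c(pending, d + 1L, d + 1L)
    }
    out[i] <- paste0(strrep("  ", d), label)
  }
  out
}

# Tree number 1 from the data frame above.
lines1 <- tree_lines(c("x2", "x1", NA, NA, NA),
                     c(0.542, 0.126, NA, NA, NA))
cat(lines1, sep = "\n")
```

In this rendering, the first child listed under a split is the left (`<=`) branch and the second is the right (`>`) branch.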
So, if we create some data that looks like this:
set.seed(100)
dat <- data.frame(x1 = runif(10),
                  x2 = runif(10),
                  x3 = runif(10),
                  x4 = runif(10),
                  x5 = runif(10))
I'm trying to find which observations of dat fall into which node.
Attempt at an answer: This isn't really helpful, but for clarity (as I am still trying to solve this), hardcoding it for tree number three would look like this:
lists <- df %>% group_by(treeNo) %>% group_split()
tree<- lists[[3]]
namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
dataLeft <- dat[dat[, namesDf] <= tree[1,]$splitValue, ]
dataRight <- dat[dat[, namesDf] > tree[1,]$splitValue, ]
namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2,]$splitValue, ]
dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2,]$splitValue, ]
namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5,]$splitValue, ]
dataRight2 <- dataRight[dataRight[, namesDf] > tree[5,]$splitValue, ]
I have been trying to turn this into a loop, but it's proving challenging to work out, and I (obviously) can't hardcode it for every tree. Any suggestions as to how I could solve this?
I was thinking about doing something like this:
lists <- df %>% group_by(treeNo) %>% group_split()
sideFunction <- function(df, data, index){
  if (!is.na(data[index, ]$variableName)){
    namesDf <- names(df[grepl(data[index, ]$variableName, names(df))])
    # get left side
    dataLeft <- df[df[, namesDf] <= data[index,]$splitValue, ]
    leftValues <- dataLeft[, namesDf]
    # get right side
    dataRight <- df[df[, namesDf] > data[index,]$splitValue, ]
    rightValues <- dataRight[, namesDf]
  }else{
    leftValues <- "No values"
    rightValues <- "No values"
  }
  return(list(
    dataLeft = dataLeft,
    leftValues = leftValues,
    dataRight = dataRight,
    rightValues = rightValues
  ))
}
# checking for one list
data <- lists[[1]]
numbRow <- nrow(data)
dataLeft <- dat
dataRight <- dat
leftValues <- list()
rightValues <- list()
for (i in 1:numbRow) {
  Left <- sideFunction(df = dataLeft, data = data, index = i)
  Right <- sideFunction(df = dataRight, data = data, index = i)
  dataLeft <- Left$dataLeft
  dataRight <- Right$dataRight
  leftValues[[i]] <- list(Left$leftValues, Left$rightValues)
  rightValues[[i]] <- list(Right$leftValues, Right$rightValues)
  #print(Left)
  #print(Right)
}
leftValues
rightValues
This gives the values for one tree, but the output is not organized very well and includes unnecessary information. Additionally, I'm not providing the indices, but it's a start.
Here is a vectorized solution.
First load the dplyr and runner packages, and also generate your df:
library(dplyr)
library(runner)
# ...
# Code to generate 'df'.
# ...
Then take an intermediate step that provides rich information about each node...
df_info <- df %>%
  # Work separately within each tree.
  group_by(treeNo) %>%
  mutate(
    # Uniquely identify each node.
    node_id = row_number(),
    # Determine if each node is a terminus.
    is_terminus = is.na(variableName),
    # Track how the level shifts with each new node:
    termini_streak = streak_run(cumsum(!is_terminus)) - 1,
    level_shift = case_when(
      # Encountering a new node: the next node is its child; down one level.
      termini_streak < 1 ~ 1,
      # Encountering yet another terminus: the next node has a different parent; up one
      # level.
      termini_streak > 1 ~ -1,
      # Encountering the first terminus: the next node has the same parent; same level.
      TRUE ~ 0
    ),
    # Determine the level of each node from the tracking; starting at 0 for the root.
    level_id = cumsum(lag(level_shift, default = 0)),
    level_id = as.integer(level_id)
  ) %>%
  mutate(
    # Find the parent of each node: the nearest preceding node that is one level up.
    parent_index = runner(level_id, function(x) {
      length(x) - match(last(x) - 1, rev(x)) + 1
    }),
    # Identify that parent by its ID.
    parent_id = node_id[parent_index]
  ) %>%
  # Discard the helper columns.
  select(!c(is_terminus, termini_streak, level_shift, parent_index))
...as illustrated below. Here node_id uniquely identifies a node within its tree, level_id indicates how many levels (i.e., "generations") a node is descended from the root (0), and parent_id gives the node_id of the node's parent.
#> df_info
# A tibble: 15 x 6
# Groups: treeNo [3]
variableName splitValue treeNo node_id level_id parent_id
<chr> <dbl> <dbl> <int> <int> <int>
1 x2 0.542 1 1 0 NA
2 x1 0.126 1 2 1 1
3 NA NA 1 3 2 2
4 NA NA 1 4 2 2
5 NA NA 1 5 1 1
6 x2 0.655 2 1 0 NA
7 NA NA 2 2 1 1
8 NA NA 2 3 1 1
9 x5 0.418 3 1 0 NA
10 x4 0.234 3 2 1 1
11 NA NA 3 3 2 2
12 NA NA 3 4 2 2
13 x3 0.747 3 5 1 1
14 NA NA 3 6 2 5
15 NA NA 3 7 2 5
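As an independent cross-check on the parent_id column, the same parent relationship can be recomputed with a plain loop and a stack. This is only a sketch under the same assumptions (one tree at a time, depth-first order, binary splits); `parent_of()` is a hypothetical helper and is not part of dplyr or runner.

```r
# Hypothetical helper: parent row of every node in one depth-first binary tree.
# `is_leaf` is TRUE for terminal nodes, e.g. is.na(variableName).
parent_of <- function(is_leaf) {
  n <- length(is_leaf)
  parent <- rep(NA_integer_, n)
  open <- integer(0)       # stack of splits still waiting for children
  left_done <- logical(0)  # has each open split already received its left child?
  for (i in seq_len(n)) {
    top <- length(open)
    if (top > 0) {
      parent[i] <- open[top]
      if (left_done[top]) {
        # Second child seen: this split is complete, so pop it.
        open <- open[-top]
        left_done <- left_done[-top]
      } else {
        left_done[top] <- TRUE
      }
    }
    # A split opens a new entry that its own children will attach to.
    if (!is_leaf[i]) {
      open <- c(open, i)
      left_done <- c(left_done, FALSE)
    }
  }
  parent
}

# Tree number 3: splits at rows 1, 2 and 5; all other rows are leaves.
parents3 <- parent_of(c(FALSE, FALSE, TRUE, TRUE, FALSE, TRUE, TRUE))
parents3
```

For tree 3 this should agree with the parent_id values in rows 9 to 15 of the table above.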
Next, put that df_info through this workflow to link each node to its children...
df_linked <- df_info %>%
  # Link each node to its children.
  left_join(
    df_info,
    by = c(
      "treeNo",
      node_id = "parent_id"
    ),
    suffix = c(".parent", ".child")
  ) %>%
  # Keep only the desired columns.
  transmute(
    treeNo,
    variableName = variableName.parent,
    splitValue = splitValue.parent,
    node_id = node_id,
    level_id = level_id.parent,
    child_id = node_id.child,
    child_var = variableName.child
    # , child_level = level_id.child
    # , child_val = splitValue.child
  ) %>%
  # OPTIONALLY ignore unnamed nodes.
  filter(!is.na(variableName))
...as illustrated below:
#> df_linked
# A tibble: 12 x 7
# Groups: treeNo [3]
treeNo variableName splitValue node_id level_id child_id child_var
<dbl> <chr> <dbl> <int> <int> <int> <chr>
1 1 x2 0.542 1 0 2 x1
2 1 x2 0.542 1 0 5 NA
3 1 x1 0.126 2 1 3 NA
4 1 x1 0.126 2 1 4 NA
5 2 x2 0.655 1 0 2 NA
6 2 x2 0.655 1 0 3 NA
7 3 x5 0.418 1 0 2 x4
8 3 x5 0.418 1 0 5 x3
9 3 x4 0.234 2 1 3 NA
10 3 x4 0.234 2 1 4 NA
11 3 x3 0.747 5 1 6 NA
12 3 x3 0.747 5 1 7 NA
Finally, you can summarize the number of children for each node:
df_summary <- df_linked %>%
  # Summarize each node in each tree:
  group_by(node_id, .add = TRUE) %>%
  summarize(
    # Count the number of children (not missing)...
    total_children = sum(!is.na(child_id)),
    # ...and the number of named children.
    named_children = sum(!is.na(child_var)),
    # Preserve the other useful info.
    variableName = first(variableName),
    splitValue = first(splitValue)
    # , level_id = first(level_id)
  ) %>%
  # Reformat the dataset as before.
  select(
    variableName, splitValue, treeNo, named_children, total_children
    # , node_id, level_id
  )
Given a df like the one reproduced here
df <- structure(
list(
variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3)
),
class = c("tbl_df", "tbl", "data.frame"),
row.names = c(NA, -15L)
)
this solution should yield the following result for df_summary:
#> df_summary
# A tibble: 6 x 5
# Groups: treeNo [3]
variableName splitValue treeNo named_children total_children
<chr> <dbl> <dbl> <int> <int>
1 x2 0.542 1 1 2
2 x1 0.126 1 0 2
3 x2 0.655 2 0 2
4 x5 0.418 3 2 2
5 x4 0.234 3 0 2
6 x3 0.747 3 0 2
If you allow the format of df to deviate even slightly from your convention, or if you perform something other than a depth-first search (from the left) on something other than a binary tree, then this workflow will fail.
If you choose to keep unnamed nodes in df_linked, rather than filter() them out, then you should include node_id in the select() for df_summary, so as to distinguish unnamed nodes from each other.
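Because the workflow depends so heavily on the depth-first, binary layout, it may be worth checking each tree before running it. The following is a minimal sketch of such a check; `is_valid_preorder()` is a hypothetical helper and relies on the same assumption as the rest of the thread, namely that an `NA` variableName marks a terminal node.

```r
# Hypothetical helper: TRUE if `is_leaf` describes exactly one complete binary
# tree written in depth-first (pre-order) order.
is_valid_preorder <- function(is_leaf) {
  slots <- 1L  # positions still to be filled; the root needs one
  for (leaf in is_leaf) {
    # Rows remain although the tree is already complete.
    if (slots < 1L) return(FALSE)
    slots <- slots - 1L              # the current row fills one slot
    if (!leaf) slots <- slots + 2L   # a split opens two slots for its children
  }
  slots == 0L
}

ok_tree1   <- is_valid_preorder(c(FALSE, FALSE, TRUE, TRUE, TRUE))  # tree 1
ok_missing <- is_valid_preorder(c(FALSE, TRUE))                     # right child missing
ok_extra   <- is_valid_preorder(c(TRUE, TRUE))                      # extra row after a lone leaf
```

With the df above, something like `tapply(is.na(df$variableName), df$treeNo, is_valid_preorder)` could be used to run the check per tree.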
There is still much room for optimization; however, given your data:
library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))
and
set.seed(100)
dat <- data.frame(x1 = runif(10),
                  x2 = runif(10),
                  x3 = runif(10),
                  x4 = runif(10),
                  x5 = runif(10))
dat
##> x1 x2 x3 x4 x5
##>1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
##>2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
##>3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
##>4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
##>5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
##>6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
##>7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
##>8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
##>9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
##>10 0.17026205 0.6902905 0.2777238 0.1302889 0.3070859
Here's a function that, for a single tree, returns a function that maps a row of values to a node:
makeTree <- function(dat, r = 1) {
  stopifnot(r <= nrow(dat))
  vname <- pull(dat, variableName)[r]
  splitVal <- pull(dat, splitValue)[r]
  if (is.na(vname)) {
    ## terminal node
    ## print(sprintf("terminal node: %i", r))
    res <- list(size = 1, # number of rows occupied by this subtree
                fn = function(z) {
                  pull(dat, "id")[r]
                })
    return(res)
  } else {
    ## print(sprintf("node: %i, varName: %s, splitVal: %f", r, vname, splitVal))
    ## calculate the left and right functions
    ## the left child is always positioned right after the caller
    fnleft <- makeTree(dat, r + 1)
    ## the right child starts after the whole left subtree
    fnright <- makeTree(dat, r + fnleft$size + 1)
    return(list(size = fnleft$size + fnright$size + 1,
                fn = function(z) {
                  if (z[vname] <= splitVal)
                    fnleft$fn(z)
                  else
                    fnright$fn(z)
                }))
  }
}
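The key idea in makeTree is the size offset: the left child sits in the row right after its parent, and the right child starts once the entire left subtree has been skipped. The sketch below isolates just that bookkeeping; `subtree_size()` is a hypothetical helper written for illustration and assumes a single, well-formed depth-first binary tree.

```r
# Hypothetical helper: number of rows occupied by the subtree rooted at row r.
# `is_leaf` is TRUE for terminal nodes, e.g. is.na(variableName).
subtree_size <- function(is_leaf, r = 1L) {
  if (is_leaf[r]) return(1L)
  left  <- subtree_size(is_leaf, r + 1L)         # left child follows its parent
  right <- subtree_size(is_leaf, r + left + 1L)  # right child follows the left subtree
  left + right + 1L
}

# Tree number 3: splits at rows 1, 2 and 5; all other rows are leaves.
leaf3 <- c(FALSE, FALSE, TRUE, TRUE, FALSE, TRUE, TRUE)
total_rows <- subtree_size(leaf3)                     # rows in the whole tree
right_child_of_root <- 1L + subtree_size(leaf3, 2L) + 1L
```

Here `right_child_of_root` points at the row holding the x3 split, which is the second child of the root in tree 3.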
Now this function is applied to each tree to produce a list of matching functions.
treefns <- df |>
  mutate(id = row_number()) |>
  group_by(treeNo) |>
  group_split() |>
  purrr::map(makeTree) |>
  purrr::map("fn")
Finally, each row of your dataframe is matched to a node of the tree.
apply(dat, 1, function(z) sapply(treefns, function(fn) fn(z))) |>
  t() |>
  data.frame() |>
  rename_with(function(z) paste0("TREE", gsub("X", "", z))) |>
  cbind(dat) |>
  # pivot_longer() comes from tidyr, which is not attached above
  tidyr::pivot_longer(cols = starts_with("TREE"),
                      names_to = "TREE",
                      values_to = "NODE") |>
  sample_n(10)
##> A tibble: 10 x 7
##> x1 x2 x3 x4 x5 TREE NODE
##> <dbl> <dbl> <dbl> <dbl> <dbl> <chr> <int>
##> 1 0.170 0.690 0.278 0.130 0.307 TREE3 11
##> 2 0.170 0.690 0.278 0.130 0.307 TREE2 8
##> 3 0.370 0.358 0.882 0.629 0.884 TREE2 7
##> 4 0.308 0.625 0.536 0.488 0.331 TREE1 5
##> 5 0.370 0.358 0.882 0.629 0.884 TREE1 4
##> 6 0.552 0.280 0.538 0.349 0.778 TREE3 14
##> 7 0.547 0.359 0.549 0.990 0.208 TREE1 4
##> 8 0.370 0.358 0.882 0.629 0.884 TREE3 15
##> 9 0.547 0.359 0.549 0.990 0.208 TREE2 7
##>10 0.0564 0.398 0.749 0.954 0.827 TREE2 7
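Since the question also asks for counts, the long table can be tallied per tree and node. The snippet below is a sketch on a hypothetical miniature of that table (the `assigned` data frame and its values are made up for illustration), using base R so that it runs without extra packages.

```r
# Hypothetical miniature of the long table: one row per observation and tree.
assigned <- data.frame(
  TREE = c("TREE1", "TREE1", "TREE1", "TREE2", "TREE2", "TREE2"),
  NODE = c(4L, 4L, 5L, 7L, 8L, 7L)
)

# Cross-tabulate, then keep only the tree/node combinations that occur.
node_counts <- as.data.frame(
  table(TREE = assigned$TREE, NODE = assigned$NODE),
  responseName = "n_obs"
)
node_counts <- node_counts[node_counts$n_obs > 0, ]
node_counts
```

With dplyr attached, `count(TREE, NODE)` on the full long table (before `sample_n()`) would be an equivalent way to get the same tally.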
It is possible to identify each node by a set of filters. We can use a stack to keep track of all the states. The logic is as follows.
1. Initialize a stack stk to keep track of the nodes, a stack counter cnt to keep track of branches, and a position index pos to record the current position in stk and cnt.
2. For each variableName, output the variables and split values that are currently in stk. If the current cnt > 0, we are on the right branch and thus use > for the split; otherwise use <=.
3. If the current variableName is NA, increment the current cnt by 1 to show that we have just finished a branch; otherwise, put one more node onto stk to show that we are moving one level deeper down the branch.
4. If the current cnt > 1, then we have fully traversed all branches of a node. Thus, we discharge that node from stk, reinitialize its cnt, move back one level, and then increment the branch cnt there by 1. Repeat step 4 if needed, since we may recursively finish one or more branches.
Here is the code. Note that this kind of state-dependent computation is hard to vectorize, so it is not what R is good at. If you have a lot of trees and performance becomes a serious concern, I'd suggest rewriting the code below using Rcpp.
read_node <- function(x, v) {
  stk <- integer(length(x))
  cnt <- integer(length(x))
  pos <- 1L
  out <- character(length(x))
  for (i in seq_along(x)) {
    out[[i]] <- paste0(
      x[stk], c("<=", ">")[(cnt[stk != 0L] > 0L) + 1L], v[stk],
      collapse = "&"
    )
    if (!is.na(x[[i]]))
      stk[[pos <- pos + 1L]] <- i
    else
      cnt[[pos]] <- cnt[[pos]] + 1L
    while (cnt[[pos]] > 1L) {
      stk[[pos]] <- 0L
      cnt[[pos]] <- 0L
      pos <- pos - 1L
      cnt[[pos]] <- cnt[[pos]] + 1L
    }
  }
  out
}
Then you can apply the filters to your dat like this.
library(dplyr)
df %>%
  group_by(treeNo) %>%
  mutate(
    node = read_node(variableName, splitValue),
    filter = if_else(node == "", "dat", sprintf("dat[%s, ]", node)),
    obs = eval(parse(text = sprintf("list(%s)", paste0(filter, collapse = ","))), dat)
  )
Output
# A tibble: 15 x 6
# Groups: treeNo [3]
variableName splitValue treeNo node filter obs
<chr> <dbl> <dbl> <chr> <chr> <list>
1 x2 0.542 1 "" dat <df [10 x 5]>
2 x1 0.126 1 "x2<=0.542" dat[x2<=0.542, ] <df [5 x 5]>
3 NA NA 1 "x2<=0.542&x1<=0.126" dat[x2<=0.542&x1<=0.126, ] <df [1 x 5]>
4 NA NA 1 "x2<=0.542&x1>0.126" dat[x2<=0.542&x1>0.126, ] <df [4 x 5]>
5 NA NA 1 "x2>0.542" dat[x2>0.542, ] <df [5 x 5]>
6 x2 0.655 2 "" dat <df [10 x 5]>
7 NA NA 2 "x2<=0.6547" dat[x2<=0.6547, ] <df [6 x 5]>
8 NA NA 2 "x2>0.6547" dat[x2>0.6547, ] <df [4 x 5]>
9 x5 0.418 3 "" dat <df [10 x 5]>
10 x4 0.234 3 "x5<=0.418" dat[x5<=0.418, ] <df [3 x 5]>
11 NA NA 3 "x5<=0.418&x4<=0.234" dat[x5<=0.418&x4<=0.234, ] <df [1 x 5]>
12 NA NA 3 "x5<=0.418&x4>0.234" dat[x5<=0.418&x4>0.234, ] <df [2 x 5]>
13 x3 0.747 3 "x5>0.418" dat[x5>0.418, ] <df [7 x 5]>
14 NA NA 3 "x5>0.418&x3<=0.747" dat[x5>0.418&x3<=0.747, ] <df [4 x 5]>
15 NA NA 3 "x5>0.418&x3>0.747" dat[x5>0.418&x3>0.747, ] <df [3 x 5]>
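If building filter strings and evaluating them with eval(parse()) is undesirable, the same node membership can be tracked with a stack of logical masks. This is a sketch of that alternative, not the method used above: `node_rows()` is a hypothetical helper, and it assumes one tree at a time in depth-first order, with `<=` going left and `>` going right.

```r
# Hypothetical helper: row indices of `dat` that reach each node of one tree.
node_rows <- function(variableName, splitValue, dat) {
  n <- length(variableName)
  rows <- vector("list", n)
  # Stack of logical masks reserved for the nodes still to be read.
  pending <- list(rep(TRUE, nrow(dat)))
  for (i in seq_len(n)) {
    # Pop the mask that belongs to the current node.
    mask <- pending[[length(pending)]]
    pending[[length(pending)]] <- NULL
    rows[[i]] <- which(mask)
    if (!is.na(variableName[i])) {
      go_left <- dat[[variableName[i]]] <= splitValue[i]
      # Push the right mask first so that the left child is popped next.
      pending <- c(pending, list(mask & !go_left), list(mask & go_left))
    }
  }
  rows
}

# Toy data (made up for illustration) run through tree number 1.
toy <- data.frame(x1 = c(0.1, 0.5, 0.9), x2 = c(0.2, 0.3, 0.8))
rows <- node_rows(c("x2", "x1", NA, NA, NA),
                  c(0.542, 0.126, NA, NA, NA),
                  toy)
lengths(rows)  # number of observations reaching each node, in row order
```

For the data in the question, something like `lapply(split(df, df$treeNo), function(tr) node_rows(tr$variableName, tr$splitValue, dat))` would give the row indices per node for every tree.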