簡體   English   中英

使用 R 中“rpart”包中的生存樹來預測新的觀察結果

[英]Using a survival tree from the 'rpart' package in R to predict new observations

我正在嘗試使用 R 中的“rpart”包來構建生存樹,我希望使用這棵樹來預測其他觀察結果。

我知道有很多涉及 rpart 和預測的 SO 問題; 但是,我無法找到任何解決(我認為)特定於將 rpart 與“Surv”對象一起使用的問題。

我的特殊問題涉及解釋“預測”函數的結果。 一個例子很有幫助:

library(rpart)
library(OIsurv)

# Make Data:
set.seed(4)
dat = data.frame(X1 = sample(x = c(1,2,3,4,5), size = 1000, replace=T))
dat$t = rexp(1000, rate=dat$X1)
dat$t = dat$t / max(dat$t)
dat$e = rbinom(n = 1000, size = 1, prob = 1-dat$t )

# Survival Fit:
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
plot(sfit)

# Tree Fit:
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)

# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )

到現在為止還挺好。 我對這里發生的事情的理解是 rpart 試圖將指數生存曲線擬合到我的數據子集。 基於這種理解,我相信當我調用predict(tfit) ,對於每個觀察,我都會得到一個對應於該觀察指數曲線參數的數字。 因此,例如,如果predict(fit)[1]是 .46,那么這意味着對於我原始數據集中的第一個觀察,曲線由方程P(s) = exp(−λt) ,其中λ=.46 .

這似乎正是我想要的。 對於每個觀察(或任何新觀察),我可以獲得在給定時間點該觀察將存活/死亡的預測概率。 (編輯:我意識到這可能是一種誤解——這些曲線沒有給出存活/死亡的概率,而是給出一個間隔的存活概率。不過,這並沒有改變下面描述的問題。)

但是,當我嘗試使用指數公式時......

# Predict:
# an attempt to use the rates extracted from the tree to
# capture the survival curve formula in each tree node.
rates = unique(predict(tfit))
for (rate in rates) {
  grid= seq(0,1,length.out = 100)
  lines(x= grid, y= exp(-rate*(grid)), col=2)
}

陰謀

我在這里所做的是以與生存樹相同的方式分割數據集,然后使用survfit為這些分區中的每一個繪制非參數曲線。 那是黑線。 我還繪制了與將“速率”參數(我認為是)插入(我認為是)生存指數公式的結果相對應的線。

我知道非參數和參數擬合不一定是相同的,但這似乎不止於此:似乎我需要縮放 X 變量或其他東西。

基本上,我似乎不明白 rpart/survival 在幕后使用的公式。 任何人都可以幫助我從(1)rpart 模型到(2)任何任意觀察的生存方程?

生存數據以指數方式內部縮放,以便根節點中的預測速率始終固定為1.000 然后, predict()方法報告的predict()總是相對於根節點中的生存,即,某個因子更高或更低。 有關更多詳細信息vignette("longintro", package = "rpart")請參見vignette("longintro", package = "rpart")第8.4節vignette("longintro", package = "rpart") 在任何情況下,報告的Kaplan-Meier曲線都與rpart插圖中報告的曲線完全一致。

如果要直接獲取樹中Kaplan-Meier曲線的圖並獲得預測的中值生存時間,可以將rpart樹強制constpartypartykit包提供的constparty樹:

library("partykit")
(tfit2 <- as.party(tfit))
## Model formula:
## Surv(t, event = e) ~ X1
## 
## Fitted party:
## [1] root
## |   [2] X1 < 2.5
## |   |   [3] X1 < 1.5: 0.192 (n = 213)
## |   |   [4] X1 >= 1.5: 0.082 (n = 213)
## |   [5] X1 >= 2.5: 0.037 (n = 574)
## 
## Number of inner nodes:    2
## Number of terminal nodes: 3
##
plot(tfit2)

生存樹

打印輸出顯示中位存活時間和可視化相應的Kaplan-Meier曲線。 兩者也可以使用predict()方法獲得,分別將type參數設置為"response""prob"

predict(tfit2, type = "response")[1]
##          5 
## 0.03671885 
predict(tfit2, type = "prob")[[1]]
## Call: survfit(formula = y ~ 1, weights = w, subset = w > 0)
## 
##  records    n.max  n.start   events   median  0.95LCL  0.95UCL 
## 574.0000 574.0000 574.0000 542.0000   0.0367   0.0323   0.0408 

作為rpart生存樹的替代方案,您還可以考慮使用ctree ctree()條件推理(使用logrank得分)或使用來自partykit包的一般mob()基礎結構的完全參數生存樹的非參數生存樹。

@Achim Zeileis 的回答非常有幫助,但似乎沒有回答確切的 @jwdink 問題。 我將其理解為“如果 RPart 樹按最佳指數生存擬合分裂,那么這些擬合的絕對值的 Lambda 是多少,因此我們可以使用這些指數生存函數進行預測”。 RPart 摘要確實顯示了估計速率,但僅在假設整個群體的速率為 1 的情況下相對而言。要克服這一問題,可以擬合指數 survreg,從那里獲取引用的 lambda,然后將 RPart 預測速率乘以該數字(見下面的代碼)。

也就是說,這不是從樹中預測 RPart 中的存活率的方式。 我沒有直接在 RPart 中找到生存預測函數,但是正如 Achim 上面指出的那樣,partykit 使用 Kaplan-Meier 估計,即那些最終出現在各自最終葉子中的非參數生存。 我認為在生存隨機森林樹中也是如此,在最終葉子中使用 KM 曲線。

這個問題中的模擬數據使用指數分布,因此 KM 和指數生存曲線在設計上是相似的,但是對於不同的模擬或現實分布,通過 RPart 樹估計指數率並在最終葉子中使用 KM 曲線(相同的樹)將給出不同的存活率。

sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)

# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
table(dat$node)
s0 = survreg(Surv(t,e)~ 1, data =  dat, dist = "exponential") #-0.6175
e0 = exp(-summary(s0)$coefficients[1]); e0 #1.854
rates = unique(predict(tfit))
#1) plot K-M curves by node (black):
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )

#2) plot exponential survival with rates = e0 * RPart rates (red):
for (rate in rates) {
  grid= seq(0,1,length.out = 100)
  lines(x= grid, y= exp(-e0*rate*(grid)), col=2)
}
#3) plot partykit survival curves based on RPart tree (green)
library(partykit)
tfit2 <- as.party(tfit)
col_n = 1
for (node in names(table(dat$node))){
  predict_curve = predict(tfit2, newdata = dat[dat$node == node, ], type = "prob")  
  surv_esitmated = approxfun(predict_curve[[1]]$time, predict_curve[[1]]$surv)
  lines(x= grid, y= surv_esitmated(grid), col = 2+col_n)
  col_n=+1
}

在此處輸入圖片說明

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM