survxai explainer with an mlr3proba model
I am trying to build a survxai explainer from a survival model built with mlr3proba, but I'm having trouble creating the predict_function that the explainer needs. Has anyone tried to build something like this?
So far, my code is the following:
library(survxai)
library(survival)
library(survivalmodels)
library(mlr3proba)
library(mlr3pipelines)
library(mlr3extralearners)  # provides lrn("surv.deepsurv")

# Wrap a learner in a pipeline: factor encoding -> scaling -> learner
create_pipeops <- function(learner) {
  GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}

fit <- lrn("surv.deepsurv")
fit <- create_pipeops(fit)
survival_task <- TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)

predict_function <- function(model, newdata, times = NULL) {
  if (!is.data.frame(newdata)) {
    newdata <- data.frame(newdata)
  }
  # Building a TaskSurv requires the time and status columns to be in newdata
  surv_task <- TaskSurv$new("task", newdata, time = "time", event = "status")
  pred <- model$predict(surv_task)
  mat <- matrix(pred$data$distr, nrow = nrow(pred$data$distr))
  colnames(mat) <- colnames(pred$data$distr)
  return(mat)
}

# model is the trained mlr3 learner, because predict_function calls model$predict()
explainer <- survxai::explain(model = fit, data = veteran[, -c(3, 4)],
                              y = Surv(veteran$time, veteran$status),
                              predict_function = predict_function)
pred_breakdown <- prediction_breakdown(explainer, veteran[1, ])
It throws the following error:

Error in `[.data.table`(r6_private(backend)$.data, , event, with = FALSE) : column(s) not found: status

but I suspect that once that one is solved, more may follow. I also don't fully understand the structure of the object that the function should return.
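A quick check suggests why the TaskSurv construction fails: the data handed to explain() has its time and status columns removed. Assuming survxai passes rows of that data to predict_function as newdata (which the "column(s) not found: status" message is consistent with), TaskSurv$new(..., time = "time", event = "status") has no columns to bind. The sketch below only inspects the column names; explainer_data and missing_cols are illustrative names.

```r
library(survival)

# The same column selection that is passed as `data` to explain() above
explainer_data <- veteran[, -c(3, 4)]

# Which outcome columns would a TaskSurv built from this data be missing?
missing_cols <- setdiff(c("time", "status"), names(explainer_data))
print(missing_cols)
```

Both outcome columns are absent, so a predict_function used by the explainer should not need them, for example by predicting from the raw covariates instead of building a TaskSurv.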
In the predict_function, I included the times argument because, according to the R help page, the function must take three arguments.
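To make the three-argument contract concrete, here is a minimal sketch that uses a plain coxph model from the survival package instead of an mlr3proba learner. It assumes (consistent with the answer's use of t() on its predictions) that predict_function(model, newdata, times) should return survival probabilities as a matrix with one row per observation in newdata and one column per value of times. The function name predict_surv_matrix and the chosen covariates are hypothetical.

```r
library(survival)

# Hypothetical predict function: survival probabilities with one row per
# observation in `newdata` and one column per value in `times`
predict_surv_matrix <- function(model, newdata, times) {
  curves <- survfit(model, newdata = newdata)
  # extend = TRUE keeps requested times that lie beyond the last event time
  s <- summary(curves, times = times, extend = TRUE)$surv
  # summary() puts times in rows; matrix() also covers a single-row newdata
  t(matrix(s, nrow = length(times)))
}

cox_fit <- coxph(Surv(time, status) ~ karno + age + celltype, data = veteran)
times <- c(30, 90, 180)

# newdata without the outcome columns, mirroring what the explainer passes
m <- predict_surv_matrix(cox_fit, veteran[1:5, -c(3, 4)], times)
print(dim(m))
```

Because survfit() on a Cox model only needs the covariates, this style of function works even when newdata lacks time and status.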
Working example with randomForestSRC below; you can just change surv.rfsrc to surv.deepsurv for your example. BTW, we are planning to implement this within mlr3proba soon, or I might just add it directly to survivalmodels; still deciding!
library(mlr3proba)
#> Loading required package: mlr3
#> Warning: package 'mlr3' was built under R version 4.1.3
library(mlr3extralearners)
#>
#> Attaching package: 'mlr3extralearners'
#> The following objects are masked from 'package:mlr3':
#>
#> lrn, lrns
library(survxai)
#> Loading required package: prodlim
#> Welcome to survxai (version: 0.2.1).
#> Information about the package can be found in the GitHub repository: https://github.com/MI2DataLab/survxai
library(survival)
data(pbc, package = "randomForestSRC")
pbc <- pbc[complete.cases(pbc), ]
task <- as_task_surv(pbc, event = "status", time = "days")
split <- partition(task)
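predict_times <- function(model, data, times) {
  # distr$survival(times) has times in rows and observations in columns;
  # t() flips it to one row per observation, one column per time
  t(model$predict_newdata(data)$distr$survival(times))
}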
model <- lrn("surv.rfsrc")$train(task, row_ids = split$train)
surve_cph <- explain(
model = model, data = pbc[, -c(1, 2)],
y = Surv(pbc$days, pbc$status),
predict_function = predict_times
)
prediction_breakdown(surve_cph, pbc[1, -c(1, 2)])
#> contribution
#> bili -35.079%
#> edema -10.278%
#> ascites -5.505%
#> copper -1.084%
#> stage -0.773%
#> prothrombin -0.421%
#> albumin -0.247%
#> sgot -0.143%
#> hepatom -0.098%
#> spiders -0.086%
#> alk -0.043%
#> trig -0.041%
#> age -0.035%
Created on 2022-06-07 by the reprex package (v2.0.1)