survxai explainer with an mlr3proba model

I am trying to build a survxai explainer from a survival model built with mlr3proba, but I'm having trouble creating the predict_function that the explainer requires. Has anyone tried to build something like this?

So far, my code is the following:

require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)

create_pipeops <- function(learner) {
  GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}

fit<-lrn("surv.deepsurv")
fit<-create_pipeops(fit)

data<-veteran
survival_task<-TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)

predict_function<-function(model, newdata, times=NULL){
  if(!is.data.frame(newdata)){
    newdata <- data.frame(newdata)
  }
  surv_task<-TaskSurv$new("task", newdata, time = "time", 
                          event = "status")
  pred<-model$predict(surv_task)
  mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
  colnames(mat)<-colnames(pred$data$distr)
  return(mat)
}

# fit is the trained GraphLearner; `learner` is never defined in this script
explainer<-survxai::explain(model = fit, data = veteran[,-c(3,4)],
                            y = Surv(veteran$time, veteran$status),
                            predict_function = predict_function)

pred_breakdown<-prediction_breakdown(explainer, veteran[1,])

It throws the following error:

Error in `[.data.table`(r6_private(backend)$.data, , event, with = FALSE) : column(s) not found: status

I suspect that once that one is solved there may be more. I don't fully understand the structure of the object that the function returns.
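The error most likely comes from the data that survxai hands to predict_function: explain() received veteran[, -c(3, 4)], and columns 3 and 4 of veteran are time and status, so the TaskSurv$new() call inside predict_function has no event column to find. A quick check, assuming survxai passes rows of that same data frame to predict_function:

```r
library(survival)

# Covariates handed to explain() in the question: columns 3 and 4
# of veteran are time and status, so they are dropped here
covariates <- veteran[, -c(3, 4)]

# Outcome columns that TaskSurv$new() would look for but cannot find
missing_cols <- setdiff(c("time", "status"), names(covariates))
print(missing_cols)
```

Both outcome names are missing, which suggests that predicting directly from the covariates (for example with the learner's predict_newdata() method) rather than building a new TaskSurv should sidestep this particular error.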

I included the times argument in predict_function because, according to the help page for survxai::explain(), the function must take three arguments.
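To make the expected shape of the return value concrete, here is a minimal sketch that does not depend on mlr3proba. It assumes survxai wants a numeric matrix of survival probabilities with one row per observation in data and one column per entry of times. The exponential survreg model, predict_exp() and the chosen covariates are hypothetical, picked only because the survival curve then has a closed form, S(t | x) = exp(-t / exp(lp)).

```r
library(survival)

# Hypothetical model: exponential AFT fit, so S(t | x) = exp(-t / exp(lp))
fit_exp <- survreg(Surv(time, status) ~ karno + age, data = veteran,
                   dist = "exponential")

# Hypothetical predict_function with the (model, data, times) signature
predict_exp <- function(model, data, times) {
  # linear predictor on the log-time scale, one value per row of data;
  # only covariates are needed, so data may lack time and status
  lp <- predict(model, newdata = data, type = "lp")
  # outer() gives one row per observation and one column per time point
  outer(lp, times, function(l, t) exp(-t / exp(l)))
}

surv_mat <- predict_exp(fit_exp, veteran[1:5, c("karno", "age")],
                        times = c(30, 90, 180))
dim(surv_mat)
```

Each row is then one individual's survival curve evaluated at the requested times, which is the kind of object a breakdown of survival predictions can compare across observations.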

Here is a working example with randomForestSRC; for your case you can simply change surv.rfsrc to surv.deepsurv. By the way, we are planning to implement this within mlr3proba soon, or I might just add it directly to survivalmodels; still deciding!

library(mlr3proba)
#> Loading required package: mlr3
#> Warning: package 'mlr3' was built under R version 4.1.3
library(mlr3extralearners)
#> 
#> Attaching package: 'mlr3extralearners'
#> The following objects are masked from 'package:mlr3':
#> 
#>     lrn, lrns
library(survxai)
#> Loading required package: prodlim
#> Welcome to survxai (version: 0.2.1).
#> Information about the package can be found in the GitHub repository: https://github.com/MI2DataLab/survxai
library(survival)
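data(pbc, package = "randomForestSRC")
# keep complete cases only; the first two columns are days and status
pbc <- pbc[complete.cases(pbc), ]
task <- as_task_surv(pbc, event = "status", time = "days")
split <- partition(task)
# survxai calls predict_function(model, data, times); distr$survival(times)
# gives one row per time point and one column per observation,
# so it is transposed to get one row per observation
predict_times <- function(model, data, times) {
  t(model$predict_newdata(data)$distr$survival(times))
}
# train on the training split only
model <- lrn("surv.rfsrc")$train(task, row_ids = split$train)
# the explainer data excludes the outcome columns (days, status)
surve_cph <- explain(
  model = model, data = pbc[, -c(1, 2)],
  y = Surv(pbc$days, pbc$status),
  predict_function = predict_times
)
prediction_breakdown(surve_cph, pbc[1, -c(1, 2)])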
#>             contribution
#> bili            -35.079%
#> edema           -10.278%
#> ascites          -5.505%
#> copper           -1.084%
#> stage            -0.773%
#> prothrombin      -0.421%
#> albumin          -0.247%
#> sgot             -0.143%
#> hepatom          -0.098%
#> spiders          -0.086%
#> alk              -0.043%
#> trig             -0.041%
#> age              -0.035%

Created on 2022-06-07 by the reprex package (v2.0.1)
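Applied to the question's veteran/DeepSurv setup, the same approach would look roughly like the untested sketch below. It assumes surv.deepsurv is available through mlr3extralearners and survivalmodels (with its Python dependencies installed), reuses the question's create_pipeops() and the answer's predict_times() unchanged, passes the trained GraphLearner itself as the model, and drops the outcome columns from the data given to explain() and prediction_breakdown().

```r
library(mlr3proba)
library(mlr3pipelines)
library(mlr3extralearners)  # assumed source of lrn("surv.deepsurv")
library(survxai)
library(survival)

# Pipeline from the question: encode factors, scale, then fit DeepSurv
create_pipeops <- function(learner) {
  GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}
fit <- create_pipeops(lrn("surv.deepsurv"))
task <- as_task_surv(veteran, time = "time", event = "status")
fit$train(task)

# Predict function from the answer: survival probabilities at `times`,
# transposed so that rows are observations
predict_times <- function(model, data, times) {
  t(model$predict_newdata(data)$distr$survival(times))
}

# Columns 3 and 4 of veteran are time and status; keep covariates only
covariates <- veteran[, -c(3, 4)]
explainer <- explain(
  model = fit, data = covariates,
  y = Surv(veteran$time, veteran$status),
  predict_function = predict_times
)
prediction_breakdown(explainer, covariates[1, ])
```

The key difference from the question's version is that predict_times() never constructs a TaskSurv, so the covariate-only data survxai supplies is enough.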
