简体   繁体   English

如何使用 mlr3 重新采样存储模型获得对新数据的样本外预测?

[英]How to obtain out of sample predictions on new data with mlr3 resample stored models?

I want to use mlr3 for cross-fitting of nuisance parameters in a semi-parametric model such as TMLE or AIPW.我想使用 mlr3 交叉拟合半参数 model(如 TMLE 或 AIPW)中的有害参数。 The cross-fitting procedure is similar to k-fold cross-validation;交叉拟合过程类似于 k 折交叉验证; split the data into K sets of somewhat equal size, obtain predictions for each group using the data in the remaining groups for model training.将数据分成大小相等的 K 组,使用剩余组中的数据获得每个组的预测,以进行 model 训练。 However, with cross-fitting, I'm not interested in model evaluation.但是,通过交叉拟合,我对 model 评估不感兴趣。 Instead, I need to reuse the K models to produce out-of-sample predictions to relax certain assumptions necessary for valid statistical inference with machine learning estimators.相反,我需要重用 K 个模型来生成样本外预测,以放宽使用机器学习估计器进行有效统计推断所必需的某些假设。

I'd like to use resample from mlr3 for this.我想为此使用来自 mlr3 的resample

require(mlr3verse)

# Create some data
set.seed(5434)
n <- 250
W <- matrix(rnorm(n*3), ncol=3)
A <- rbinom(n,1, 1/(1+exp(-(.2*W[,1] - .1*W[,2] + .4*W[,3]))))
Y <- A + 2*W[,1] + W[,3] + W[,2]^2 + rnorm(n)

dat <- data.frame(W, A, Y)

# Creating a Task with 2 pre-defined folds
K <- 2
folds <- sample(rep(1:K, length.out = n),
                size = n,
                replace = FALSE)
dat[, "fold_id"] <- folds

task <- as_task_regr(dat, "Y", "foo_task")
task$col_roles$group <- "fold_id"
task$col_roles$feature <- setdiff(task$col_roles$feature, "fold_id")

# Create a light gbm learner object
learn_gbm <- lrn("regr.lightgbm")

# Repeatedely train the learner K times and store the models
cv <- rsmp("cv", folds = K)
rr <- resample(task, learn_gbm, cv, store_models = TRUE)

From here, I'd like to use the stored models to predict on modified versions of dat (ie, A is set to 1) of the K test-sets:从这里开始,我想使用存储的模型来预测 K 测试集的dat的修改版本(即 A 设置为 1):

# Creating a copy of the dat where A is always 1
# Want to obtain out-of-sample predictions of Y on this data, dat_1
dat_1 <- dat
dat_1$A <- 1

# Using the first fold as an example
predict(rr$learners[[1]], newdata = dat_1[rr$resampling$test_set(1), ])

It seems like I can't use the stored models to predict on new data and I get this error:似乎我无法使用存储的模型来预测新数据,并且出现此错误:

Error: No task stored, and no task provided

How can I get these predictions with resample() ?如何使用resample()获得这些预测?

Session info Session 信息

Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.4

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] mlr3verse_0.2.5 mlr3_0.14.0    

loaded via a namespace (and not attached):
 [1] tidyselect_1.1.2              clusterCrit_1.2.8             purrr_0.3.4                  
 [4] listenv_0.8.0                 lattice_0.20-45               mlr3cluster_0.1.4            
 [7] colorspace_2.0-3              vctrs_0.4.1                   generics_0.1.3               
[10] bbotk_0.5.4                   paradox_0.10.0                utf8_1.2.2                   
[13] rlang_1.0.4                   pillar_1.8.0                  glue_1.6.2                   
[16] withr_2.5.0                   DBI_1.1.3                     palmerpenguins_0.1.1         
[19] uuid_1.1-0                    prompt_1.0.1                  mlr3fselect_0.7.2            
[22] lifecycle_1.0.1               mlr3learners_0.5.4            munsell_0.5.0                
[25] gtable_0.3.0                  progressr_0.10.1              future_1.27.0                
[28] codetools_0.2-18              mlr3data_0.6.1                parallel_4.2.1               
[31] fansi_1.0.3                   mlr3tuningspaces_0.3.0        scales_1.2.0                 
[34] backports_1.4.1               checkmate_2.1.0               mlr3filters_0.5.0            
[37] mlr3viz_0.5.10                mlr3tuning_0.14.0             jsonlite_1.8.0               
[40] lightgbm_3.3.2                parallelly_1.32.1             ggplot2_3.3.6                
[43] digest_0.6.29                 dplyr_1.0.9                   mlr3extralearners_0.5.46-9000
[46] grid_4.2.1                    clue_0.3-61                   cli_3.3.0                    
[49] tools_4.2.1                   magrittr_2.0.3                tibble_3.1.7                 
[52] cluster_2.1.3                 mlr3misc_0.10.0               future.apply_1.9.0           
[55] crayon_1.5.1                  pkgconfig_2.0.3               Matrix_1.4-1                 
[58] ellipsis_0.3.2                data.table_1.14.2             mlr3pipelines_0.4.1          
[61] assertthat_0.2.1              rstudioapi_0.13               lgr_0.4.3                    
[64] R6_2.5.1                      globals_0.16.1                compiler_4.2.1 

You need to pass the model to the predict() function, ie您需要将 model 传递给predict() function,即

predict(rr$learners[[1]]$model, newdata = dat_1[rr$resampling$test_set(1), ])

mlr3 "models" (ie trained learners) are wrapped versions of the underlying trained models, so you have to pull them out explicitly. mlr3 “模型”(即训练有素的学习者)是底层训练模型的包装版本,因此您必须明确地将它们拉出来。

Reading https://mlr3.mlr-org.com/reference/Learner.html#method-predict-newdata- more closely, I found更仔细地阅读https://mlr3.mlr-org.com/reference/Learner.html#method-predict-newdata- ,我发现

If the learner has been fitted via resample() or benchmark(), you need to pass the corresponding task stored in the ResampleResult or BenchmarkResult, respectively.如果已经通过 resample() 或 benchmark() 拟合了学习器,则需要分别传递存储在 ResampleResult 或 BenchmarkResult 中的相应任务。

Trying,试,

rr$learners[[1]]$predict_newdata(dat_1[rr$resampling$test_set(1), ], task = rr$task)

produced,产生,

<PredictionRegr> for 125 observations:
    row_ids     truth  response
          1  6.360351 5.4017536
          2  1.104371 2.1685862
          3  1.445641 1.6610209
---                            
        123  5.533508 2.0075193
        124  1.927375 4.5808117
        125 -1.634550 0.7129312

Which looks like what I needed.这看起来像我需要的。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM