[英]Tidymodels prediction methods giving different results
我對使用 tidymodels 從重新采樣中獲取指標有點困惑。
我似乎從同一組重新采樣中獲得了 3 個不同的指標,具體取決於我是使用 collect_predictions() %>% metrics() 還是只是 collect_metrics()
這是一個簡單的例子......
library(tidyverse)
library(tidymodels)
starwars_df <- starwars %>% select(name:sex) %>% drop_na()
lasso_linear_reg_glmnet_spec <-
linear_reg(penalty = .1, mixture = 1) %>%
set_engine('glmnet')
basic_rec <-
recipe(mass ~ height + sex + skin_color,
data = starwars_df) %>%
step_novel(all_nominal_predictors()) %>%
step_other(all_nominal_predictors()) %>%
step_dummy(all_nominal_predictors()) %>%
step_nzv(all_predictors())
sw_wf <- workflow() %>%
add_recipe(basic_rec) %>%
add_model(lasso_linear_reg_glmnet_spec)
sw_boots <- bootstraps(starwars_df, times = 50)
resampd <- fit_resamples(
sw_wf,
sw_boots,
control = control_resamples(save_pred = TRUE)
)
以下三行給出不同的結果
resampd %>% collect_predictions(resampd, summarize = T) %>% metrics(mass, .pred)
resampd %>% collect_predictions(resampd, summarize = F) %>% metrics(mass, .pred)
resampd %>% collect_metrics()
作為一個附加問題,在上面的示例中,獲得 rmse 置信區間的最佳/正確方法是什么。 這是一種方法...
individ_metrics <- resampd %>% collect_predictions() %>% group_by(id) %>% rmse(mass, .pred)
confintr::ci_mean(individ_metrics$.estimate)
mean(individ_metrics$.estimate)
謝謝!
這些都不相同的原因是它們沒有以相同的方式聚合。 事實證明, 取一組平均值的平均值並不會給你與取整個基礎集合的平均值相同的(正確的)結果。 如果您要執行類似resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred)
之類的操作,這就像取一組均值的平均值。
事實證明,這兩件事是一樣的:
## these are the same:
resampd %>%
collect_predictions(summarize = FALSE) %>%
group_by(id) %>%
metrics(mass, .pred)
#> # A tibble: 150 × 4
#> id .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 Bootstrap01 rmse standard 16.4
#> 2 Bootstrap02 rmse standard 23.1
#> 3 Bootstrap03 rmse standard 31.6
#> 4 Bootstrap04 rmse standard 17.6
#> 5 Bootstrap05 rmse standard 9.59
#> 6 Bootstrap06 rmse standard 25.0
#> 7 Bootstrap07 rmse standard 16.3
#> 8 Bootstrap08 rmse standard 35.1
#> 9 Bootstrap09 rmse standard 25.7
#> 10 Bootstrap10 rmse standard 25.3
#> # … with 140 more rows
resampd %>% collect_metrics(summarize = FALSE)
#> # A tibble: 100 × 5
#> id .metric .estimator .estimate .config
#> <chr> <chr> <chr> <dbl> <chr>
#> 1 Bootstrap01 rmse standard 16.4 Preprocessor1_Model1
#> 2 Bootstrap01 rsq standard 0.799 Preprocessor1_Model1
#> 3 Bootstrap02 rmse standard 23.1 Preprocessor1_Model1
#> 4 Bootstrap02 rsq standard 0.193 Preprocessor1_Model1
#> 5 Bootstrap03 rmse standard 31.6 Preprocessor1_Model1
#> 6 Bootstrap03 rsq standard 0.608 Preprocessor1_Model1
#> 7 Bootstrap04 rmse standard 17.6 Preprocessor1_Model1
#> 8 Bootstrap04 rsq standard 0.836 Preprocessor1_Model1
#> 9 Bootstrap05 rmse standard 9.59 Preprocessor1_Model1
#> 10 Bootstrap05 rsq standard 0.860 Preprocessor1_Model1
#> # … with 90 more rows
使用reprex v2.0.2創建於 2022-08-23
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.