
Tidymodels prediction methods giving different results

I'm a bit confused about getting metrics from resamples using tidymodels.

I seem to be getting three different metric values from the same set of resamples, depending on whether I use collect_predictions() %>% metrics() or simply collect_metrics().

Here is a simple example...

library(tidyverse)
library(tidymodels)

starwars_df <- starwars %>% select(name:sex) %>% drop_na()

lasso_linear_reg_glmnet_spec <-
  linear_reg(penalty = .1, mixture = 1) %>%
  set_engine('glmnet')

basic_rec <-
  recipe(mass ~ height  + sex + skin_color,
         data = starwars_df) %>% 
  step_novel(all_nominal_predictors()) %>%
  step_other(all_nominal_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_nzv(all_predictors())

sw_wf <- workflow() %>% 
  add_recipe(basic_rec) %>% 
  add_model(lasso_linear_reg_glmnet_spec)

set.seed(123)  # any seed; makes the bootstrap draws reproducible
sw_boots <- bootstraps(starwars_df, times = 50)

resampd <- fit_resamples(
  sw_wf,
  sw_boots,
  control = control_resamples(save_pred = TRUE)
)
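Roughly speaking, fit_resamples() with save_pred = TRUE fits the workflow on each bootstrap analysis set and keeps the predictions for the out-of-bag rows (the assessment set). The following is a rough base-R sketch of that loop. It is an illustration under stated assumptions, not tidymodels' implementation. It uses lm() on the built-in mtcars data in place of the glmnet lasso workflow on starwars_df, and it skips the recipe steps that fit_resamples() re-estimates inside each resample. The names n_boot and oob_preds are hypothetical.

```r
set.seed(123)   # arbitrary seed, only for reproducibility
n_boot <- 5     # hypothetical number of bootstrap resamples
dat <- mtcars   # stand-in data (assumption: not the starwars data above)

oob_preds <- do.call(rbind, lapply(seq_len(n_boot), function(b) {
  # Analysis set: nrow(dat) row indices drawn with replacement
  in_bag <- sample(nrow(dat), replace = TRUE)
  # Assessment set: rows that were never drawn (out-of-bag)
  oob <- setdiff(seq_len(nrow(dat)), in_bag)
  # lm() stands in for the glmnet lasso workflow (assumption)
  fit <- lm(mpg ~ wt + hp, data = dat[in_bag, ])
  data.frame(
    id    = sprintf("Bootstrap%02d", b),
    .row  = oob,
    truth = dat$mpg[oob],
    .pred = unname(predict(fit, newdata = dat[oob, ]))
  )
}))

# Assessment-set sizes differ from resample to resample
table(oob_preds$id)
# A given original row is usually held out in several resamples
head(sort(table(oob_preds$.row), decreasing = TRUE))
```

Because assessment sizes vary and a row can be out-of-bag in several resamples, the different ways of aggregating these predictions need not agree.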

The following three lines give different results:

resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred)
resampd %>% collect_predictions(summarize = FALSE) %>% metrics(mass, .pred)
resampd %>% collect_metrics()
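The difference can be seen without refitting anything. The following base-R sketch uses a tiny, made-up stand-in for collect_predictions(summarize = FALSE) output: three hypothetical resamples over three original rows, with invented numbers. It then mimics the three calls above. rmse_plain() is a hypothetical helper written here, not yardstick's function.

```r
# Made-up stand-in for collect_predictions(summarize = FALSE):
# one row per (resample, held-out observation)
preds <- data.frame(
  id    = c("Boot1", "Boot1", "Boot2", "Boot2", "Boot3"),
  .row  = c(1, 2, 2, 3, 3),
  mass  = c(10, 20, 20, 30, 30),
  .pred = c(12, 18, 24, 30, 36)
)

# Hypothetical helper: root mean squared error
rmse_plain <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

# 1. Like collect_predictions(summarize = TRUE) %>% metrics():
#    average each original row's predictions first, then one RMSE
row_avg <- aggregate(cbind(mass, .pred) ~ .row, data = preds, FUN = mean)
rmse_row_avg <- rmse_plain(row_avg$mass, row_avg$.pred)

# 2. Like collect_predictions(summarize = FALSE) %>% metrics():
#    pool every held-out prediction, then one RMSE
rmse_pooled <- rmse_plain(preds$mass, preds$.pred)

# 3. Like collect_metrics(): one RMSE per resample, then their mean
per_resample <- sapply(split(preds, preds$id),
                       function(d) rmse_plain(d$mass, d$.pred))
rmse_mean_of_resamples <- mean(per_resample)

# approx: row_avg 2.160, pooled 3.464, mean_of_resamples 3.609
round(c(row_avg = rmse_row_avg,
        pooled = rmse_pooled,
        mean_of_resamples = rmse_mean_of_resamples), 3)
```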
 

As an additional question: what would be the best (or correct) way to get confidence intervals for the RMSE in the above example? Here is one way:

individ_metrics <- resampd %>% collect_predictions() %>% group_by(id) %>% rmse(mass, .pred) 
confintr::ci_mean(individ_metrics$.estimate)
mean(individ_metrics$.estimate)
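Two base-R intervals can be computed from the per-resample RMSEs, but they answer different questions.

- A Student-t interval via t.test() describes uncertainty in the mean RMSE. This is the same kind of interval that confintr::ci_mean() returns with its default type = "t".
- An empirical 2.5%/97.5% percentile interval describes how much a single resample's RMSE varies.

The input is hypothetical: it uses only the first five per-resample RMSEs printed in the answer. Bootstrap resamples overlap, so the estimates are not independent. Treat either interval as a rough guide.

```r
# Per-resample RMSEs, e.g. individ_metrics$.estimate; these five values are
# the first five printed in the answer, used here only as example input
rmse_by_resample <- c(16.4, 23.1, 31.6, 17.6, 9.59)

# Student-t interval for the mean of the per-resample RMSEs
t_ci <- t.test(rmse_by_resample, conf.level = 0.95)$conf.int

# Percentile interval: spread of the individual resample estimates
pct_ci <- quantile(rmse_by_resample, probs = c(0.025, 0.975))

mean(rmse_by_resample)  # 19.658
t_ci
pct_ci
```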

Thanks!

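None of those agree because they are not aggregated in the same way. Taking a mean of a set of means does not, in general, give you the same (right) result as taking the mean of the whole underlying set. RMSE adds a second wrinkle, because the square root is applied before any averaging. Each call aggregates differently:

- collect_metrics() computes one RMSE per resample and reports the mean of those.
- collect_predictions(summarize = FALSE) %>% metrics(mass, .pred) pools every held-out prediction and computes a single RMSE.
- resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred) first averages each original row's predictions across the resamples in which it was held out, then computes one RMSE. That last one is like taking a mean of a set of means.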
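Writing the three quantities out shows why they differ. The notation is introduced here for illustration, not taken from tidymodels:

- B is the number of resamples.
- A_b is the set of assessment (out-of-bag) rows of resample b, with n_b = |A_b|.
- \hat y_{bi} is the prediction for row i from the model fit on resample b.
- \mathcal{B}_i = \{b : i \in A_b\} is the set of resamples in which row i was held out.
- R is the set of rows held out at least once.

$$
\text{(i)}\quad \operatorname{RMSE}_{\text{mean}} = \frac{1}{B}\sum_{b=1}^{B}\sqrt{\frac{1}{n_b}\sum_{i\in A_b}\left(y_i-\hat y_{bi}\right)^2}
$$

$$
\text{(ii)}\quad \operatorname{RMSE}_{\text{pooled}} = \sqrt{\frac{\sum_{b=1}^{B}\sum_{i\in A_b}\left(y_i-\hat y_{bi}\right)^2}{\sum_{b=1}^{B} n_b}}
$$

$$
\text{(iii)}\quad \operatorname{RMSE}_{\text{row}} = \sqrt{\frac{1}{|R|}\sum_{i\in R}\left(y_i-\bar{\hat y}_i\right)^2},
\qquad \bar{\hat y}_i = \frac{1}{|\mathcal{B}_i|}\sum_{b\in\mathcal{B}_i}\hat y_{bi}
$$

These map onto the three calls as follows:

- (i) is what collect_metrics() reports.
- (ii) is collect_predictions(summarize = FALSE) %>% metrics().
- (iii) is collect_predictions(summarize = TRUE) %>% metrics().

Even if every n_b were equal, (i) and (ii) would generally differ, because (i) takes the square root before averaging. They coincide only when every resample has the same mean squared error.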

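The per-resample values do agree, though. These two calls give the same estimate for each bootstrap resample. metrics() also returns mae for a numeric outcome, which is why the first result has 150 rows instead of 100: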

## these are the same:
resampd %>% 
    collect_predictions(summarize = FALSE) %>% 
    group_by(id) %>% 
    metrics(mass, .pred)
#> # A tibble: 150 × 4
#>    id          .metric .estimator .estimate
#>    <chr>       <chr>   <chr>          <dbl>
#>  1 Bootstrap01 rmse    standard       16.4 
#>  2 Bootstrap02 rmse    standard       23.1 
#>  3 Bootstrap03 rmse    standard       31.6 
#>  4 Bootstrap04 rmse    standard       17.6 
#>  5 Bootstrap05 rmse    standard        9.59
#>  6 Bootstrap06 rmse    standard       25.0 
#>  7 Bootstrap07 rmse    standard       16.3 
#>  8 Bootstrap08 rmse    standard       35.1 
#>  9 Bootstrap09 rmse    standard       25.7 
#> 10 Bootstrap10 rmse    standard       25.3 
#> # … with 140 more rows
resampd %>% collect_metrics(summarize = FALSE)
#> # A tibble: 100 × 5
#>    id          .metric .estimator .estimate .config             
#>    <chr>       <chr>   <chr>          <dbl> <chr>               
#>  1 Bootstrap01 rmse    standard      16.4   Preprocessor1_Model1
#>  2 Bootstrap01 rsq     standard       0.799 Preprocessor1_Model1
#>  3 Bootstrap02 rmse    standard      23.1   Preprocessor1_Model1
#>  4 Bootstrap02 rsq     standard       0.193 Preprocessor1_Model1
#>  5 Bootstrap03 rmse    standard      31.6   Preprocessor1_Model1
#>  6 Bootstrap03 rsq     standard       0.608 Preprocessor1_Model1
#>  7 Bootstrap04 rmse    standard      17.6   Preprocessor1_Model1
#>  8 Bootstrap04 rsq     standard       0.836 Preprocessor1_Model1
#>  9 Bootstrap05 rmse    standard       9.59  Preprocessor1_Model1
#> 10 Bootstrap05 rsq     standard       0.860 Preprocessor1_Model1
#> # … with 90 more rows
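collect_metrics() with its default summarize = TRUE condenses these per-resample rows into one row per metric, reporting mean, n, and std_err (the sample standard deviation divided by sqrt(n)). The following base-R sketch reproduces that summary. It is illustrative only and uses just the ten rows printed above as input, so its numbers are not the full 50-resample result.

```r
# Shape of collect_metrics(summarize = FALSE); only the ten printed rows
per_resample <- data.frame(
  id        = rep(sprintf("Bootstrap%02d", 1:5), each = 2),
  .metric   = rep(c("rmse", "rsq"), times = 5),
  .estimate = c(16.4, 0.799, 23.1, 0.193, 31.6, 0.608,
                17.6, 0.836, 9.59, 0.860)
)

# One summary row per metric: mean, number of resamples, standard error
summ <- do.call(rbind, lapply(split(per_resample, per_resample$.metric),
  function(d) {
    data.frame(.metric = d$.metric[1],
               mean    = mean(d$.estimate),
               n       = nrow(d),
               std_err = sd(d$.estimate) / sqrt(nrow(d)))
  }))
summ
```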

Created on 2022-08-23 with reprex v2.0.2
