简体   繁体   中英

Tidymodels: Filter workflowsets based on results

Is there a clean way to filter a workflow set object? For example, I want to keep only the rows where the mean `roc_auc` is >= 0.8. I guess I could get this result by applying `rank_results()` together with some joins, but maybe there is a cleaner way to do this?

Thanks in advance! M.

library(titanic)
library(tidyverse)
library(tidymodels)
library(finetune)
library(themis)
#> Registered S3 methods overwritten by 'themis':
#>   method                  from   
#>   bake.step_downsample    recipes
#>   bake.step_upsample      recipes
#>   prep.step_downsample    recipes
#>   prep.step_upsample      recipes
#>   tidy.step_downsample    recipes
#>   tidy.step_upsample      recipes
#>   tunable.step_downsample recipes
#>   tunable.step_upsample   recipes
#> 
#> Attaching package: 'themis'
#> The following objects are masked from 'package:recipes':
#> 
#>     step_downsample, step_upsample
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness

# Presumably switches tidymodels console colours to suit a dark editor theme.
options(tidymodels.dark = TRUE)

## Data preparation and resampling folds
## (no separate test split is made here; all evaluation uses the CV folds)

# Convert the outcome and categorical columns to factors and drop the columns
# that are not used by the recipes below.
titanic_train <- as_tibble(titanic_train) %>%
  mutate(Survived = factor(Survived),
         Pclass = factor(Pclass, ordered = TRUE),
         Sex = factor(Sex)) %>%
  select(!c(Name, Ticket, Cabin, Embarked))

# vfold_cv() assigns rows to folds at random; fix the seed so the folds, and
# therefore every tuning result computed from them, are reproducible.
set.seed(123)
titanic_folds <- vfold_cv(titanic_train, v = 5, repeats = 5)

## Model specifications

# Random forest (ranger engine): fixed at 200 trees; mtry and min_n are tuned.
rf_model <- rand_forest(
  trees = 200,
  mtry = tune(),
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("ranger")

# Boosted trees (xgboost engine): fixed at 200 trees; six hyperparameters
# are marked for tuning.
xgb_model <- boost_tree(
  mtry = tune(),
  trees = 200,
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune(),
  sample_size = tune()
) %>%
  set_mode("classification") %>%
  set_engine("xgboost")

# Naive Bayes (naivebayes engine): no tuning parameters, so it is only
# resampled, not tuned.
naive_bayes_model <- naive_Bayes() %>%
  set_engine("naivebayes") %>%
  set_mode("classification")

# Base preprocessing: PassengerId is kept as an identifier rather than a
# predictor, missing predictor values are imputed from 5 nearest neighbours,
# nominal predictors are dummy-encoded, and rows are down-sampled to balance
# the `Survived` classes (seed fixed for reproducibility).
base_rec <- recipe(Survived ~ ., data = titanic_train) %>%
  update_role(PassengerId, new_role = "id") %>%
  step_impute_knn(all_predictors(), neighbors = 5) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_downsample(Survived, seed = 123)

# Same steps as `base_rec`, plus centering/scaling of every numeric predictor
# (including the dummy columns created just before) prior to down-sampling.
# The id role name matches `base_rec` so both recipes treat PassengerId alike.
another_rec <- recipe(Survived ~ ., data = titanic_train) %>%
  update_role(PassengerId, new_role = "id") %>%
  step_impute_knn(all_predictors(), neighbors = 5) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_downsample(Survived, seed = 123)

# Cross every recipe with every model spec (cross = TRUE): 2 x 3 = 6
# workflows, whose ids take the form "<recipe name>_<model name>".
titanic_models <- workflow_set(
  cross = TRUE,
  models = list(rf = rf_model, xgb = xgb_model, bayes = naive_bayes_model),
  preproc = list(base = base_rec, another = another_rec)
)

titanic_models
#> # A workflow set/tibble: 6 x 4
#>   wflow_id      info                 option    result    
#>   <chr>         <list>               <list>    <list>    
#> 1 base_rf       <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 2 base_xgb      <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 3 base_bayes    <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 4 another_rf    <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 5 another_xgb   <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 6 another_bayes <tibble[,4] [1 × 4]> <opts[0]> <list [0]>

# Leave two cores free for the rest of the system, but never request fewer
# than one worker: detectCores() may return NA, or a value of 2 or less.
num_cores <- max(1L, parallel::detectCores() - 2L, na.rm = TRUE)

cl <- parallel::makeCluster(num_cores)
doParallel::registerDoParallel(cl = cl)

# Tune (or, for models with no tuning parameters, resample) every workflow in
# the set using ANOVA racing over the CV folds. `finally` makes sure the
# worker processes are shut down and the sequential foreach backend is
# restored even if tuning fails part-way; otherwise later foreach calls would
# be dispatched to a stopped cluster.
titanic_models_result <- tryCatch(
  titanic_models %>%
    workflow_map(
      "tune_race_anova",
      resamples = titanic_folds,
      grid = 4,
      metrics = metric_set(accuracy, roc_auc),
      verbose = TRUE,
      control = control_race(
        verbose = TRUE,
        save_pred = TRUE,
        save_workflow = TRUE
      )
    ),
  finally = {
    parallel::stopCluster(cl)
    foreach::registerDoSEQ()
  }
)
#> i 1 of 6 tuning:     base_rf
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 1 of 6 tuning:     base_rf (38.3s)
#> i 2 of 6 tuning:     base_xgb
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 2 of 6 tuning:     base_xgb (7s)
#> i    No tuning parameters. `fit_resamples()` will be attempted
#> i 3 of 6 resampling: base_bayes
#> ✔ 3 of 6 resampling: base_bayes (3.3s)
#> i 4 of 6 tuning:     another_rf
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 4 of 6 tuning:     another_rf (29.8s)
#> i 5 of 6 tuning:     another_xgb
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 5 of 6 tuning:     another_xgb (6s)
#> i    No tuning parameters. `fit_resamples()` will be attempted
#> i 6 of 6 resampling: another_bayes
#> ✔ 6 of 6 resampling: another_bayes (2.4s)

# Print the tuned workflow set; each `result` now holds racing (<race>) or
# plain resampling (<rsmp>) results. Output is commented out so the script
# can be sourced.
titanic_models_result
#>   wflow_id      info                 option    result   
#>   <chr>         <list>               <list>    <list>   
#> 1 base_rf       <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 2 base_xgb      <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 3 base_bayes    <tibble[,4] [1 × 4]> <opts[4]> <rsmp[+]>
#> 4 another_rf    <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 5 another_xgb   <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 6 another_bayes <tibble[,4] [1 × 4]> <opts[4]> <rsmp[+]>

# Rank every evaluated candidate across all workflows by mean ROC AUC; each
# output row is one (workflow, candidate .config, metric) combination.
# Output is commented out so the script can be sourced.
rank_results(titanic_models_result, rank_metric = "roc_auc")
#> # A tibble: 20 x 9
#>    wflow_id  .config    .metric  mean std_err     n preprocessor model 
#>    <chr>     <fct>      <chr>   <dbl>   <dbl> <int> <chr>        <chr> 
#>  1 base_rf   Preproces… accura… 0.816 0.00579    25 recipe       rand_...
#>  2 base_rf   Preproces… roc_auc 0.870 0.00626    25 recipe       rand_...
#>  3 another_… Preproces… accura… 0.816 0.00569    25 recipe       rand_...
#>  4 another_… Preproces… roc_auc 0.870 0.00628    25 recipe       rand_...
#>  5 another_… Preproces… accura… 0.816 0.00611    25 recipe       rand_...
#>  6 another_… Preproces… roc_auc 0.869 0.00633    25 recipe       rand_...
#>  7 another_… Preproces… accura… 0.817 0.00461    25 recipe       rand_...
#>  8 another_… Preproces… roc_auc 0.869 0.00619    25 recipe       rand_...
#>  9 base_rf   Preproces… accura… 0.816 0.00610    25 recipe       rand_...
#> 10 base_rf   Preproces… roc_auc 0.869 0.00625    25 recipe       rand_...
#> 11 base_rf   Preproces… accura… 0.817 0.00470    25 recipe       rand_...
#> 12 base_rf   Preproces… roc_auc 0.869 0.00614    25 recipe       rand_...
#> 13 base_xgb  Preproces… accura… 0.766 0.00737    25 recipe       boost...
#> 14 base_xgb  Preproces… roc_auc 0.836 0.00646    25 recipe       boost...
#> 15 another_… Preproces… accura… 0.766 0.00737    25 recipe       boost...
#> 16 another_… Preproces… roc_auc 0.836 0.00646    25 recipe       boost...
#> 17 base_bay… Preproces… accura… 0.757 0.00692    25 recipe       naive...
#> 18 base_bay… Preproces… roc_auc 0.812 0.00640    25 recipe       naive...
#> 19 another_… Preproces… accura… 0.756 0.00704    25 recipe       naive...
#> 20 another_… Preproces… roc_auc 0.811 0.00636    25 recipe       naive...
#> # ... with 1 more variable: rank <int>

Created on 2021-10-13 by the reprex package (v2.0.0)

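There are a couple of ways to approach this, but using `rank_results()` is a good option. Because a workflow set is a tibble, you can rank the results, keep the workflows whose best mean ROC AUC is at least 0.8, and then `filter()` the set by `wflow_id`:

    keep_ids <- titanic_models_result %>%
      rank_results(rank_metric = "roc_auc", select_best = TRUE) %>%
      filter(.metric == "roc_auc", mean >= 0.8) %>%
      pull(wflow_id)

    titanic_models_result %>% filter(wflow_id %in% keep_ids)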

  • You could also use `collect_metrics()`, then `filter()` to find the workflows that meet your conditions.

  • Remember that you can use the `extract_*()` functions to pull individual components out of a workflow set.

Check the workflowsets function reference to see whether anything else fits your particular use case.

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM