Is there a smooth way to filter a workflowsets
object? In my case I want, for instance, to keep only those rows where the mean
for the roc_auc
is >= 0.8.
I guess I could get this result by applying the rank_results
function with some joins,
but maybe there is a "cleaner" way to do this.
Thanks in advance! M.
library(titanic)
library(tidyverse)
library(tidymodels)
library(finetune)
library(themis)
#> Registered S3 methods overwritten by 'themis':
#> method from
#> bake.step_downsample recipes
#> bake.step_upsample recipes
#> prep.step_downsample recipes
#> prep.step_upsample recipes
#> tidy.step_downsample recipes
#> tidy.step_upsample recipes
#> tunable.step_downsample recipes
#> tunable.step_upsample recipes
#>
#> Attaching package: 'themis'
#> The following objects are masked from 'package:recipes':
#>
#> step_downsample, step_upsample
library(discrim)
#>
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#>
#> smoothness
options(tidymodels.dark = TRUE)
## Splitting Train / Test data
# Coerce the outcome and categorical columns to factors and drop the
# free-text / high-cardinality columns that won't be used as predictors.
titanic_train <- titanic_train %>%
  as_tibble() %>%
  mutate(
    Survived = factor(Survived),
    Pclass = factor(Pclass, ordered = TRUE),
    Sex = factor(Sex)
  ) %>%
  select(-Name, -Ticket, -Cabin, -Embarked)
# 5-fold cross-validation repeated 5 times -> 25 resamples per workflow.
titanic_folds <- vfold_cv(titanic_train, v = 5, repeats = 5)
## Model Definition
# Random forest via ranger: tune mtry and min_n, fix the forest at 200 trees.
rf_model <- rand_forest(
  mtry = tune(),
  trees = 200,
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("ranger")
# Gradient boosting via xgboost: tune the main hyperparameters,
# keeping the number of trees fixed at 200.
xgb_model <- boost_tree(
  mtry = tune(),
  trees = 200,
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune(),
  sample_size = tune()
) %>%
  set_mode("classification") %>%
  set_engine("xgboost")
# Naive Bayes baseline (discrim/naivebayes): no tuning parameters,
# so workflow_map() will fall back to fit_resamples() for it.
naive_bayes_model <- naive_Bayes() %>%
  set_engine("naivebayes") %>%
  set_mode("classification")
# Base recipe: mark PassengerId as an id (non-predictor), impute missing
# predictor values with 5-NN, dummy-encode nominal predictors, then
# downsample the majority class (fixed seed for reproducibility).
base_rec <- recipe(Survived ~ ., data = titanic_train)
base_rec <- update_role(base_rec, PassengerId, new_role = "id")
base_rec <- step_impute_knn(base_rec, all_predictors(), neighbors = 5)
base_rec <- step_dummy(base_rec, all_nominal_predictors())
base_rec <- step_downsample(base_rec, Survived, seed = 123)
# Same as base_rec plus normalization of numeric predictors.
# NOTE: role renamed "ID" -> "id" to match base_rec; recipe roles are
# case-sensitive strings, so mixed casing across recipes invites subtle
# mismatches when selecting columns by role later.
another_rec <- recipe(Survived ~ ., data = titanic_train) %>%
  update_role(PassengerId, new_role = "id") %>%
  step_impute_knn(all_predictors(), neighbors = 5) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_downsample(Survived, seed = 123)
# Cross every preprocessing recipe with every model spec:
# 2 recipes x 3 models = 6 workflows in the set.
titanic_models <- workflow_set(
  preproc = list(base = base_rec, another = another_rec),
  models = list(rf = rf_model, xgb = xgb_model, bayes = naive_bayes_model),
  cross = TRUE
)
titanic_models
#> [38;5;246m# A workflow set/tibble: 6 x 4[39m
#> wflow_id info option result
#> <chr> <list> <list> <list>
#> [base_rf [<tibble[,4] [1 × 4]>[39m [<opts[0]>[<list [0]>
#> [base_xgb [<tibble[,4] [1 × 4]>[39m [<opts[0]>[<list [0]>
#> [base_bayes [<tibble[,4] [1 × 4]>[39m [<opts[0]>[<list [0]>
#> [another_rf [<tibble[,4] [1 × 4]>[39m [<opts[0]>[<list [0]>
#> [another_xgb [<tibble[,4] [1 × 4]>[39m [<opts[0]>[<list [0]>
#> [another_bayes [<tibble[,4] [1 × 4]>[39m [<opts[0]>[<list [0]>
# Leave two cores for the OS/UI, but never request fewer than one worker:
# detectCores() - 2 is <= 0 on 1-2 core machines (and detectCores() can
# return NA), which would make makeCluster() fail.
num_cores <- max(1L, parallel::detectCores() - 2L, na.rm = TRUE)
cl <- parallel::makeCluster(num_cores)
doParallel::registerDoParallel(cl = cl)
# Tune every workflow with ANOVA racing over a grid of 4 candidates.
# save_pred/save_workflow keep enough information to inspect predictions
# and finalize the best workflow afterwards.
titanic_models_result <- titanic_models %>%
  workflow_map(
    "tune_race_anova",
    resamples = titanic_folds,
    grid = 4,
    metrics = metric_set(accuracy, roc_auc),
    verbose = TRUE,
    control = control_race(
      verbose = TRUE,
      save_pred = TRUE,
      save_workflow = TRUE
    )
  )
#> i 1 of 6 tuning: base_rf
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ? 1 of 6 tuning: base_rf (38.3s)
#> i 2 of 6 tuning: base_xgb
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ? 2 of 6 tuning: base_xgb (7s)
#> i No tuning parameters. `fit_resamples()` will be attempted
#> i 3 of 6 resampling: base_bayes
#> ? 3 of 6 resampling: base_bayes (3.3s)
#> i 4 of 6 tuning: another_rf
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ? 4 of 6 tuning: another_rf (29.8s)
#> i 5 of 6 tuning: another_xgb
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ? 5 of 6 tuning: another_xgb (6s)
#> i No tuning parameters. `fit_resamples()` will be attempted
#> i 6 of 6 resampling: another_bayes
#> ? 6 of 6 resampling: another_bayes (2.4s)
# Shut down the parallel workers now that tuning is finished.
parallel::stopCluster(cl)
titanic_models_result
wflow_id info option result
<chr> <list> <list> <list>
1 base_rf <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
2 base_xgb <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
3 base_bayes <tibble[,4] [1 × 4]> <opts[4]> <rsmp[+]>
4 another_rf <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
5 another_xgb <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
6 another_bayes <tibble[,4] [1 × 4]> <opts[4]> <rsmp[+]>
rank_results(titanic_models_result, rank_metric = "roc_auc")
# A tibble: 20 x 9
wflow_id .config .metric mean std_err n preprocessor model
<chr> <fct> <chr> <dbl> <dbl> <int> <chr> <chr>
1 base_rf Preproces? accura? 0.816 0.00579 25 recipe rand_...
2 base_rf Preproces? roc_auc 0.870 0.00626 25 recipe rand_...
3 another_? Preproces? accura? 0.816 0.00569 25 recipe rand_...
4 another_? Preproces? roc_auc 0.870 0.00628 25 recipe rand_...
5 another_? Preproces? accura? 0.816 0.00611 25 recipe rand_...
6 another_? Preproces? roc_auc 0.869 0.00633 25 recipe rand_...
7 another_? Preproces? accura? 0.817 0.00461 25 recipe rand_...
8 another_? Preproces? roc_auc 0.869 0.00619 25 recipe rand_...
9 base_rf Preproces? accura? 0.816 0.00610 25 recipe rand_...
10 base_rf Preproces? roc_auc 0.869 0.00625 25 recipe rand_...
11 base_rf Preproces? accura? 0.817 0.00470 25 recipe rand_...
12 base_rf Preproces? roc_auc 0.869 0.00614 25 recipe rand_...
13 base_xgb Preproces? accura? 0.766 0.00737 25 recipe boost...
14 base_xgb Preproces? roc_auc 0.836 0.00646 25 recipe boost...
15 another_? Preproces? accura? 0.766 0.00737 25 recipe boost...
16 another_? Preproces? roc_auc 0.836 0.00646 25 recipe boost...
17 base_bay? Preproces? accura? 0.757 0.00692 25 recipe naive...
18 base_bay? Preproces? roc_auc 0.812 0.00640 25 recipe naive...
19 another_? Preproces? accura? 0.756 0.00704 25 recipe naive...
20 another_? Preproces? roc_auc 0.811 0.00636 25 recipe naive...
# ... with 1 more variable: rank <int>
Created on 2021-10-13 by the reprex package (v2.0.0)
There are a couple ways you could approach this, but certainly using rank_results()
would be a good way to go.
You could use collect_metrics()
, then filter()
to find the workflows that fit your conditions.
Don't forget that you can extract_*()
various components of the workflowset(s).
Check out what's currently implemented in workflowsets to see if anything else fits your particular use case.
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any questions, please contact: yoyou2525@163.com.