
Tidymodels: How to extract variable importance from training data

I have the following code, which runs a grid search over different values of mtry and min_n. I know how to extract the parameters that give the highest accuracy (see the second code box). How can I extract the importance of each feature on the training dataset? The guides I found online only show how to do it on the test dataset using last_fit(). For example: https://www.tidymodels.org/start/case-study/#data-split

library(tidymodels)

set.seed(seed_number)
data_split <- initial_split(node_strength, prop = 0.8, strata = Group)

train <- training(data_split)
test <- testing(data_split)
train_folds <- vfold_cv(train, v = 10)

rfc <- rand_forest(mode = "classification", mtry = tune(),
                   min_n = tune(), trees = 1500) %>%
    set_engine("ranger", num.threads = 48, importance = "impurity")

rfc_recipe <- recipe(Group ~ ., data = train)

rfc_workflow <- workflow() %>%
    add_model(rfc) %>%
    add_recipe(rfc_recipe)

rfc_result <- rfc_workflow %>%
    tune_grid(train_folds, grid = 40, control = control_grid(save_pred = TRUE),
              metrics = metric_set(accuracy))


best <- 
        rfc_result %>% 
        select_best(metric = "accuracy")

To do this, create a custom extraction function and pass it to the extract argument of the control function (control_grid() or control_resamples()), as outlined in the tidymodels documentation.

For random forest variable importance, your function will look something like this:

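get_rf_imp <- function(x) {
    # x is the workflow fitted on one resample
    x %>% 
        extract_fit_parsnip() %>%   # pull out the parsnip model fit
        vip::vi()                   # tibble of Variable / Importance scores
}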
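If you would rather not depend on vip, a variant can read ranger's own impurity scores straight from the engine object. This is a minimal sketch: get_rf_imp_ranger() is a hypothetical name, iris stands in for your data, and it assumes ranger was fitted with importance = "impurity" (otherwise variable.importance is NULL). Calling the extractor on a single fitted workflow is also a quick way to see what it receives before using it during resampling.

```r
library(tidymodels)

# Hypothetical vip-free variant of get_rf_imp(): reads ranger's own
# impurity scores from the underlying engine object (needs importance = "impurity")
get_rf_imp_ranger <- function(x) {
  imp <- x %>%
    extract_fit_engine() %>%                # the raw ranger object
    purrr::pluck("variable.importance")     # named numeric vector
  tibble(Variable = names(imp), Importance = unname(imp))
}

# Quick check on a single workflow fit (iris stands in for your data)
rf_spec <- rand_forest(mode = "classification", trees = 200) %>%
  set_engine("ranger", importance = "impurity")

set.seed(123)
iris_fit <- workflow(Species ~ ., rf_spec) %>% fit(data = iris)

iris_imp <- get_rf_imp_ranger(iris_fit)
iris_imp
```

The same function can be passed as control_grid(extract = get_rf_imp_ranger); the resulting tibbles have the same Variable and Importance columns that vip::vi() produces.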

And then you can apply it to your resamples like so (notice that you get a new .extracts column):

library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
    initial_split(strata = class)
cell_train <- training(cell_split)
cell_test  <- testing(cell_split)
folds <- vfold_cv(cell_train)

rf_spec <- rand_forest(mode = "classification") %>%
    set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
    workflow(class ~ ., rf_spec) %>%
    fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .extracts       
#>    <list>             <chr>  <list>           <list>           <list>          
#>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

Created on 2022-06-19 by the reprex package (v2.0.1)
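In the question's setup the extractor goes into tune_grid() instead of fit_resamples(), next to save_pred = TRUE. Each fold's .extracts then holds one row per candidate (with mtry and min_n columns), so the sketch below keeps only the best candidate's scores by filtering between the two unnest() calls. It assumes vip is installed and uses iris with a small grid as a stand-in for node_strength; the exact layout of the extracts may vary slightly across tune versions.

```r
library(tidymodels)

# Same helper as in the answer
get_rf_imp <- function(x) {
  x %>%
    extract_fit_parsnip() %>%
    vip::vi()
}

set.seed(123)
folds <- vfold_cv(iris, v = 5)   # iris stands in for the question's training data

rf_tune_spec <- rand_forest(mode = "classification", mtry = tune(),
                            min_n = tune(), trees = 200) %>%
  set_engine("ranger", importance = "impurity")

# Keep the predictions AND extract importance in the same control object
tune_res <- workflow(Species ~ ., rf_tune_spec) %>%
  tune_grid(folds, grid = 5,
            control = control_grid(save_pred = TRUE, extract = get_rf_imp),
            metrics = metric_set(accuracy))

best <- select_best(tune_res, metric = "accuracy")

# With tune_grid(), each fold holds one extract per candidate, so keep only
# the rows matching the winning mtry/min_n before the second unnest()
best_imp <- tune_res %>%
  select(id, .extracts) %>%
  unnest(.extracts) %>%
  filter(mtry == best$mtry, min_n == best$min_n) %>%
  unnest(.extracts)

# Average each variable's importance across folds for the best candidate
best_summary <- best_imp %>%
  group_by(Variable) %>%
  summarise(Mean = mean(Importance), SD = sd(Importance))
best_summary
```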

Once you have those variable importance score extracts, you can unnest() them (currently you have to do this twice because the results are nested two levels deep) and then summarize and visualize them as you prefer:

cells_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%
    unnest(.extracts) %>%
    group_by(Variable) %>%
    summarise(Mean = mean(Importance),
              SD = sd(Importance)) %>%   # standard deviation across folds
    slice_max(Mean, n = 15) %>%
    ggplot(aes(Mean, reorder(Variable, Mean))) +
    geom_crossbar(aes(xmin = Mean - SD, xmax = Mean + SD)) +
    labs(x = "Variable importance", y = NULL)

Created on 2022-06-19 by the reprex package (v2.0.1)
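If what you want is a single set of scores for the whole training set rather than per-fold scores, another option is to plug the selected parameters into the workflow and refit once on all training rows. Note that last_fit() also fits the finalized workflow on the full training set and uses the test set only for evaluation, so impurity importance taken from its fitted workflow already reflects the training data. The sketch below uses iris as a stand-in, and the values in best are hypothetical placeholders for the output of select_best().

```r
library(tidymodels)

set.seed(123)
iris_split <- initial_split(iris, prop = 0.8, strata = Species)
iris_train <- training(iris_split)

rf_tune_spec <- rand_forest(mode = "classification", mtry = tune(),
                            min_n = tune(), trees = 200) %>%
  set_engine("ranger", importance = "impurity")
rf_wf <- workflow(Species ~ ., rf_tune_spec)

# Hypothetical placeholder for `best <- select_best(rfc_result, metric = "accuracy")`
best <- tibble(mtry = 2, min_n = 5)

# Fill in the tuned values and refit once on ALL training rows
final_fit <- rf_wf %>%
  finalize_workflow(best) %>%
  fit(data = iris_train)

# Importance from that single training-set fit, via vip ...
train_imp <- final_fit %>%
  extract_fit_parsnip() %>%
  vip::vi()

# ... or straight from the ranger object
ranger_imp <- extract_fit_engine(final_fit)$variable.importance
train_imp
```

With the question's objects this would be finalize_workflow(rfc_workflow, best) %>% fit(data = train).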
