
Why do DALEX and tidymodels report different goodness of fit (accuracy)?

I wonder why DALEX's `model_performance()` and tidymodels' `collect_metrics()` do not report the same accuracy. Do they use different measures or different methods? I've put together the following example:

```r
library(tidymodels)   # attaches parsnip, recipes, rsample, tune, yardstick, ...
library(DALEXtra)

# simulated data: y is generated independently of x1 to x4
set.seed(1)
x1 <- rbinom(1000, 5, .1)
x2 <- rbinom(1000, 5, .4)
x3 <- rbinom(1000, 5, .9)
x4 <- rbinom(1000, 5, .6)
id <- 1:1000
y <- as.factor(rbinom(1000, 5, .5))
df <- tibble(y, x1, x2, x3, x4, id)

# create training and test set (800 / 200 rows)
set.seed(20)
split_dat <- initial_split(df, prop = 0.8)
train <- training(split_dat)
test <- testing(split_dat)
# cross-validation folds; built from the training set so the test rows
# are not reused (not used further below)
kfolds <- vfold_cv(train)

# recipe
rec_pca <- recipe(y ~ ., data = train) %>%
  update_role(id, new_role = "id variable") %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  step_pca(x1, x2, x3, threshold = 0.9, num_comp = 1)

# parsnip model specification
boost_model <- boost_tree() %>%
  set_mode("classification") %>%
  set_engine("xgboost")

# create workflow
boosted_wf <-
  workflow() %>%
  add_model(boost_model) %>%
  add_recipe(rec_pca)

# fit on the training rows, evaluate on the test rows
boosted_res <- last_fit(boosted_wf, split_dat)
collect_metrics(boosted_res)
```

The accuracy reported by `collect_metrics()` is 0.31:

```
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy multiclass     0.31  Preprocessor1_Model1
2 roc_auc  hand_till      0.512 Preprocessor1_Model1
```
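For context, both numbers in this output are close to chance. In the simulated data, `y` comes from `rbinom(1000, 5, .5)` and does not depend on `x1` to `x4`, so on new data no model can beat always predicting the most likely class, and the ROC AUC of 0.512 is close to the 0.5 of random guessing. A minimal base-R check of that accuracy baseline (the object names are only illustrative):

```r
# y ~ Binomial(5, 0.5): six classes 0..5, independent of every predictor.
# Without signal, the best achievable accuracy is the probability of the
# most likely class.
class_probs <- dbinom(0:5, size = 5, prob = 0.5)
names(class_probs) <- 0:5

chance_accuracy <- max(class_probs)   # P(y = 2) = P(y = 3) = 10/32
chance_accuracy
#> [1] 0.3125
```

Since 0.31 is essentially this baseline, the held-out estimate behaves as expected for data without signal; an accuracy clearly above it on the same simulated data suggests the model is being scored on rows it was trained on.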

Next, I refit the workflow and create a DALEX explainer:

```r
# refit the workflow on the full data set (training + test rows)
final_boosted <- generics::fit(boosted_wf, df)

# create an explanation object (column 1 of df is the outcome y)
explainer_xgb <- DALEXtra::explain_tidymodels(final_boosted,
                                              data = df[, -1],
                                              y = df$y)

perf <- model_performance(explainer_xgb)
perf
```

`model_performance()` reports the following overall fit:

```
Measures for:  multiclass
micro_F1   : 0.43 
macro_F1   : 0.5743392 
w_macro_F1 : 0.4775901 
accuracy   : 0.43 
w_macro_auc: 0.7064296
```

Note that the accuracy is 0.43 with `model_performance()` but 0.31 with `collect_metrics()`. Does anyone know why?
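One way to separate "different definitions of accuracy" from "different data" is to score `final_boosted` with yardstick on exactly the rows the explainer was given. The sketch below assumes `final_boosted`, `df` and `perf` from the code above are still in the session. If the result matches the 0.43 from `model_performance()`, the two packages agree on what accuracy means and differ only in which rows they score.

```r
library(tidymodels)

# Assumes `final_boosted` (workflow refit on all of `df`) and `df` exist.
# predict() on a workflow returns a tibble with a .pred_class column.
in_sample_acc <- predict(final_boosted, new_data = df) %>%
  bind_cols(df %>% select(y)) %>%
  accuracy(truth = y, estimate = .pred_class)

in_sample_acc
```

Both yardstick's `.pred_class` and DALEX's multiclass accuracy pick the class with the highest predicted probability, so apart from exact ties the two values are expected to be identical.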

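This is because the two numbers are computed on different data, not because accuracy is defined differently. `last_fit()` fits the workflow on the 800 training rows of `split_dat`, and `collect_metrics()` reports performance on the 200 held-out test rows. `final_boosted`, by contrast, is fit on all 1000 rows of `df`, and the explainer is given that same `df`, so `model_performance()` reports in-sample (resubstitution) accuracy. For a flexible model such as xgboost that figure is optimistic, especially here, where `y` carries no signal at all.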
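To get comparable numbers, the explainer can be built from the workflow that `last_fit()` trained and given only the test rows. This is a sketch that assumes `boosted_res` and `test` from the question are in the session; `extract_workflow()` returns the fitted workflow stored in the `last_fit()` result (equivalent to `boosted_res$.workflow[[1]]`).

```r
library(tidymodels)
library(DALEXtra)

# Workflow (recipe + xgboost) fitted by last_fit() on the 800 training rows
train_fit <- extract_workflow(boosted_res)

# Explainer on the 200 held-out rows only (column 1 of `test` is y)
explainer_test <- DALEXtra::explain_tidymodels(train_fit,
                                               data = test[, -1],
                                               y = test$y)

perf_test <- model_performance(explainer_test)
perf_test

# yardstick estimate from last_fit(), computed on the same test rows
collect_metrics(boosted_res)
```

With the same fitted model and the same rows, the DALEX accuracy is expected to match the 0.31 from `collect_metrics()` (apart from exact ties in the predicted probabilities). The held-out number is the one to report; the 0.43 only describes how well the model fits the data it was trained on.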
