
How can I define the classification threshold as a (hyper)parameter of a learner for tuning with the mlr3 package in R?

There is a PipeOp for tuning the threshold of, say, a binary classifier, described here: https://mlr3pipelines.mlr-org.com/reference/mlr_pipeops_tunethreshold.html

Here's my failed attempt:

    RF_lrn <- lrn("classif.rfsrc", id = "rf", predict_type = "prob")
    RF_lrn$param_set$values = list(na.action = "na.impute", seed = -123)
    single_pred_rf = po("subsample", frac = 1, id = "rf_ss") %>>%
      po("learner", RF_lrn) %>>% po("tunethreshold")

That did not work in my mlr3 pipeline, and I could not find a solution explained anywhere, so here is my code:

    xgb_lrn <- lrn("classif.xgboost", id = "xgb", predict_type = "prob")
    single_pred_xgb = po("subsample", frac = 1, id = "xgb_ss") %>>%
      po("learner", xgb_lrn)

    lrnrs <- list(
      RF_lrn,
      xgb_lrn)

    lrnrs <- lapply(lrnrs, function(x) {
      GraphLearner$new(po("learner_cv", x) %>>%
        po("tunethreshold", param_vals = list(measure = "classif.prauc")))
    })
    library("GenSA")
    lrnrs$RF_lrn <- auto_tuner(
      learner = RF_lrn,
      search_space = ps(
        ntree = p_int(lower = 20, upper = 300),
        mtry = p_int(lower = 2, upper = 5),
        nodesize = p_int(lower = 1, upper = 7)
      ),
      resampling = rsmp("bootstrap", repeats = 2, ratio = 0.8),
      measure = msr("classif.prauc"),
      term_evals = 100,
      method = "random_search"
    )

This somehow works, but I want the decision threshold to be tuned as a hyperparameter in the same way as the other hyperparameters, i.e. with random search inside benchmarking/cross-validation. Any solution? Thanks in advance.

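The solution is to use po("threshold") instead of po("tunethreshold"), as suggested in the comments and in this mlr gallery example. po("threshold") exposes the cut-off as an ordinary hyperparameter (`thresholds`), so it can be tuned with the same tuner and search space as the learner's other hyperparameters, whereas po("tunethreshold") optimizes the threshold internally.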
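A minimal sketch of that approach, assuming a recent mlr3 / mlr3pipelines / mlr3tuning installation. It swaps in the built-in `sonar` task and `classif.rpart` so that it runs without extra packages; with the `classif.rfsrc` learner from the question, the learner parameters would be prefixed with its id instead (e.g. `rf.mtry` for `id = "rf"`). Inside a GraphLearner every hyperparameter is prefixed with its PipeOp id, so the threshold appears as `threshold.thresholds` and can sit in the same `ps()` search space as `cp` and `minsplit`. Two further assumptions: recent mlr3tuning versions take the tuner as `tuner = tnr("random_search")` rather than `method = "random_search"`, and the tuning measure must depend on the predicted labels (here `classif.bacc`). A probability-based measure such as `classif.prauc` or `classif.auc` does not change with the threshold, so tuning the threshold against it would have no effect.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3tuning)
library(paradox)

set.seed(123)
task = tsk("sonar")  # built-in binary task; stand-in for your own data

# Probability-predicting learner followed by po("threshold"), which turns the
# probabilities into class labels using a cut-off that we want to tune.
graph = po("learner", lrn("classif.rpart", predict_type = "prob")) %>>%
  po("threshold")
glrn = as_learner(graph)

# GraphLearner parameter ids are "<PipeOp id>.<parameter>", so the learner's
# hyperparameters and the threshold can share one search space.
search_space = ps(
  classif.rpart.cp       = p_dbl(lower = 0.001, upper = 0.1),
  classif.rpart.minsplit = p_int(lower = 2, upper = 30),
  threshold.thresholds   = p_dbl(lower = 0, upper = 1)
)

# Random search over all three parameters at once. The measure is label-based
# (balanced accuracy), so it actually reacts to the threshold.
at = auto_tuner(
  tuner = tnr("random_search"),
  learner = glrn,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.bacc"),
  term_evals = 20,
  search_space = search_space
)

at$train(task)
print(at$tuning_result)

# The AutoTuner is an ordinary learner, so wrapping it in resample() (or
# benchmark()) gives nested cross-validation: the threshold is re-tuned
# inside every outer training split.
rr = resample(task, at, rsmp("cv", folds = 3))
print(rr$aggregate(msr("classif.bacc")))
```

Because the AutoTuner behaves like any other learner, the same wrapping can be applied to each entry of a learner list (for example the random forest and xgboost graphs from the question) before passing them to `benchmark()`.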
