
Difference between scikit-learn and caret GBM results?

I'm getting drastically different F1 scores with the same input data with scikit-learn and caret. Here's how I'm running a GBM model for each.

scikit-learn (F1 is a built-in scoring option):

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

est = GradientBoostingClassifier(n_estimators = 4000, learning_rate = 0.1, max_depth = 5, max_features = 'log2', random_state = 0)
cv = StratifiedKFold(n_splits = 10, shuffle = True, random_state = 0)
scores = cross_val_score(est, data, labels, scoring = 'f1', cv = cv, n_jobs = -1)

caret (F1 must be defined and called):

library(caret)
library(MLmetrics)  # provides F1_Score()

# Custom summary function so caret can compute and report F1
f1 <- function(data, lev = NULL, model = NULL) {
  f1_val <- F1_Score(y_pred = data$pred, y_true = data$obs, positive = lev[1])
  c("F1" = f1_val)
}
set.seed(0)
gbm <- train(label ~ ., 
           data = data, 
           method = "gbm",
           trControl = trainControl(method = "repeatedcv", number = 10, repeats = 3, 
                                    summaryFunction = f1, classProbs = TRUE),
           metric = "F1",
           verbose = FALSE)

From the above code, I get an F1 score of ~0.8 using scikit-learn and ~0.25 using caret. A small difference might be attributed to algorithm differences, but I must be doing something wrong with the caret modeling to get the massive difference I'm seeing here. I'd prefer not to post my data set, so hopefully the issue can be diagnosed from the code. Any help would be much appreciated.

GBT is an ensemble of decision trees. The difference comes from:

  • The number of decision trees in the ensemble (n_estimators = 4000 vs. n.trees = 100).
  • The shape (breadth, depth) of individual decision trees (max_depth = 5 vs. interaction.depth = 1).

Currently, you're comparing the F1 score of a 100 MB GradientBoostingClassifier object with a 100 kB gbm object - one GBT model contains literally thousands of times more information than the other.
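
As a rough illustration (not a tuned recommendation), you can bring the two setups closer together on the caret side by pinning the gbm tuning grid near the scikit-learn settings via tuneGrid. In the sketch below, shrinkage plays the role of learning_rate, and n.minobsinnode = 10 is an assumption (gbm's usual default), reusing the f1 summary function and data from the question:

library(caret)

# Sketch: fix caret's gbm grid so the ensemble size and tree depth are
# comparable to the scikit-learn model (4000 trees, depth 5).
grid <- expand.grid(n.trees = 4000,
                    interaction.depth = 5,
                    shrinkage = 0.1,        # analogous to learning_rate
                    n.minobsinnode = 10)    # assumption; gbm's usual default

set.seed(0)
gbm_matched <- train(label ~ .,
                     data = data,
                     method = "gbm",
                     trControl = trainControl(method = "repeatedcv", number = 10, repeats = 3,
                                              summaryFunction = f1, classProbs = TRUE),
                     tuneGrid = grid,
                     metric = "F1",
                     verbose = FALSE)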

You may wish to export both models to the standardized PMML representation using the sklearn2pmml and r2pmml packages, and look inside the resulting PMML files (plain text, so they can be opened in any text editor) to better grasp their internal structure.
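
On the R side, a minimal export sketch might look like the following; it assumes the r2pmml package is installed and accepts the fitted object produced above (if the caret train wrapper is not handled directly, exporting the underlying fit via gbm$finalModel is the usual fallback). The file name is just an example:

library(r2pmml)

# Write the fitted model out as PMML for inspection in a text editor.
r2pmml(gbm, "caret_gbm.pmml")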
