简体   繁体   中英

SHAP Summary Plot for XGBoost model in R without displaying Mean Absolute SHAP value on the plot

I don't want to display the Mean Absolute Values on my SHAP Summary Plot in R. I want an output similar to the one produced in python. What line of code will help remove the mean absolute values from the summary plot in R?

I'm currently using this line of code:

shap.plot.summary.wrap1(xgb_model, X = x, top_n = 10)

You can do this by slightly modifying the source code of shap.plot.summary() as below:

#' SHAP summary (sina) plot without the mean absolute SHAP labels
#'
#' Variant of `shap.plot.summary()` in which the `geom_text()` layer that
#' printed `sprintf(label_format, mean_value)` next to each feature, and the
#' accompanying "|SHAP|" annotation, are left out. Everything else (sina
#' layer, colour scale, theme, feature ordering, axis labels) is kept.
#'
#' All ggplot2 / data.table functions are namespace-qualified, so the function
#' works without `library(ggplot2)` or `library(data.table)` being attached.
#'
#' @param data_long Long-format SHAP data containing at least the columns
#'   `variable`, `value` and `stdfvalue`. The input is copied; the caller's
#'   object is not modified.
#' @param x_bound Optional symmetric limit for the SHAP axis. When `NULL`,
#'   1.1 times the largest absolute `value` is used.
#' @param dilute `FALSE`/`0`/`NULL` plots every row. Any other value thins the
#'   data to a random subset of about `nrow / dilute` rows, and never more
#'   than half of the rows; `dilute` itself is capped at
#'   `ceiling(rows per feature / 10)`.
#' @param scientific,my_format Retained only so existing calls keep working.
#'   They controlled the number format of the removed mean-value labels and
#'   are ignored here.
#' @return A ggplot object.
shap.plot.summary.edited <- function(data_long,
                                     x_bound = NULL,
                                     dilute = FALSE,
                                     scientific = FALSE,
                                     my_format = NULL) {
  # Fail early, with a clear message, if a column used by the plot is absent.
  required_cols <- c("variable", "value", "stdfvalue")
  missing_cols <- setdiff(required_cols, names(data_long))
  if (length(missing_cols) > 0) {
    stop("`data_long` is missing column(s): ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  # Work on a copy: setDT() would convert the caller's object by reference.
  data_long <- data.table::as.data.table(data_long)

  # check number of observations
  n_features <- data.table::uniqueN(data_long$variable)
  if (is.null(dilute)) dilute <- FALSE
  n_per_feature <- nrow(data_long) / n_features

  if (dilute != 0) {
    # Cap the dilution factor so that roughly 10 observations per feature
    # remain. Note that the sample below keeps at most half of the rows even
    # when the factor ends up as 1.
    dilute <- ceiling(min(n_per_feature / 10, abs(as.numeric(dilute))))

    # A fixed seed keeps the diluted plot reproducible; put the caller's RNG
    # state back afterwards so the call has no lasting side effect.
    had_seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
    old_seed <- if (had_seed) get(".Random.seed", envir = globalenv()) else NULL
    on.exit({
      if (had_seed) {
        assign(".Random.seed", old_seed, envir = globalenv())
      } else if (exists(".Random.seed", envir = globalenv(),
                        inherits = FALSE)) {
        rm(list = ".Random.seed", envir = globalenv())
      }
    }, add = TRUE)
    set.seed(1234)

    # sample() truncates a fractional size; floor() makes that explicit.
    n_keep <- floor(min(nrow(data_long) / dilute, nrow(data_long) / 2))
    keep_rows <- sample(nrow(data_long), n_keep)
    data_long <- data_long[keep_rows, ]
  }

  x_bound <- if (is.null(x_bound)) {
    max(abs(data_long$value)) * 1.1
  } else {
    as.numeric(abs(x_bound))
  }

  # Features are shown top-down in the order of the factor levels.
  feature_order <- rev(levels(data_long$variable))

  plot1 <- ggplot2::ggplot(data = data_long) +
    ggplot2::coord_flip(ylim = c(-x_bound, x_bound)) +
    ggplot2::geom_hline(yintercept = 0) + # the y-axis beneath
    # sina plot:
    ggforce::geom_sina(
      ggplot2::aes(x = variable, y = value, color = stdfvalue),
      method = "counts", maxwidth = 0.7, alpha = 0.7
    ) +
    # (mean absolute SHAP text layer and "|SHAP|" annotation intentionally
    # omitted -- this is the only difference from shap.plot.summary())
    ggplot2::scale_color_gradient(
      low = "#FFCC33", high = "#6600CC",
      breaks = c(0, 1), labels = c(" Low", "High "),
      guide = ggplot2::guide_colorbar(barwidth = 12, barheight = 0.3)
    ) +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.line.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank(), # remove axis line
      legend.position = "bottom",
      legend.title = ggplot2::element_text(size = 10),
      legend.text = ggplot2::element_text(size = 8),
      axis.title.x = ggplot2::element_text(size = 10)
    ) +
    # reverse the order of features, from high to low
    ggplot2::scale_x_discrete(limits = feature_order) +
    ggplot2::labs(y = "SHAP value (impact on model output)", x = "",
                  color = "Feature value  ")

  plot1
}

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM