
SHAP Summary Plot for XGBoost model in R without displaying Mean Absolute SHAP value on the plot

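I do not want to display the mean absolute values on the SHAP summary plot in R. I want output similar to what Python produces. Which line of code would remove the mean absolute values from the summary plot in R?

I am currently using this line of code: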

shap.plot.summary.wrap1(xgb_model, X = x, top_n = 10)

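You can do this by modifying the source code of shap.plot.summary() as follows; the geom_text() layer that prints the mean absolute values is commented out: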

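library(data.table)  # setDT(), uniqueN()
library(ggplot2)     # ggplot(), theme_bw(), etc.; geom_sina() is called via ggforce::

shap.plot.summary.edited <- function(data_long,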
                          x_bound = NULL,
                          dilute = FALSE,
                          scientific = FALSE,
                          my_format = NULL){
  if (scientific){label_format = "%.1e"} else {label_format = "%.3f"}
  if (!is.null(my_format)) label_format <- my_format
  # check number of observations
  N_features <- setDT(data_long)[,uniqueN(variable)]
  if (is.null(dilute)) dilute = FALSE
  nrow_X <- nrow(data_long)/N_features # n per feature
  if (dilute!=0){
    # if nrow_X <= 10, no dilute happens
    dilute <- ceiling(min(nrow_X/10, abs(as.numeric(dilute)))) # not allowed to dilute to fewer than 10 obs/feature
    set.seed(1234)
    data_long <- data_long[sample(nrow(data_long),
                                  min(nrow(data_long)/dilute, nrow(data_long)/2))] # dilute
  }
  x_bound <- if (is.null(x_bound)) max(abs(data_long$value))*1.1 else as.numeric(abs(x_bound))
  plot1 <- ggplot(data = data_long) +
    coord_flip(ylim = c(-x_bound, x_bound)) +
    geom_hline(yintercept = 0) + # the y-axis beneath
    # sina plot:
    ggforce::geom_sina(aes(x = variable, y = value, color = stdfvalue),
                       method = "counts", maxwidth = 0.7, alpha = 0.7) +
    # mean absolute value labels (commented out so they are not drawn):
    #geom_text(data = unique(data_long[, c("variable", "mean_value")]),
    #          aes(x = variable, y=-Inf, label = sprintf(label_format, mean_value)),
    #          size = 3, alpha = 0.7,
    #          hjust = -0.2,
    #          fontface = "bold") + # bold
    # # add a "SHAP" bar notation
    # annotate("text", x = -Inf, y = -Inf, vjust = -0.2, hjust = 0, size = 3,
    #          label = expression(group("|", bar(SHAP), "|"))) +
    scale_color_gradient(low="#FFCC33", high="#6600CC",
                         breaks=c(0,1), labels=c(" Low","High "),
                         guide = guide_colorbar(barwidth = 12, barheight = 0.3)) +
    theme_bw() +
    theme(axis.line.y = element_blank(),
          axis.ticks.y = element_blank(), # remove axis line
          legend.position="bottom",
          legend.title=element_text(size=10),
          legend.text=element_text(size=8),
          axis.title.x= element_text(size = 10)) +
    # reverse the order of features, from high to low
    # also relabel the feature using `label.feature`
    scale_x_discrete(limits = rev(levels(data_long$variable))#,
                     #labels = label.feature(rev(levels(data_long$variable)))
                     )+
    labs(y = "SHAP value (impact on model output)", x = "", color = "Feature value  ")
  return(plot1)
}
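
A minimal usage sketch for the edited function, assuming the SHAPforxgboost and xgboost packages are installed in mutually compatible versions and that shap.plot.summary.edited() above has already been run in the session. The toy model fitted on mtcars is only a stand-in for the question's `xgb_model` and `x`. Since shap.plot.summary.wrap1() builds the long-format data itself, the edited function needs that table prepared explicitly with shap.prep():

```r
library(xgboost)
library(SHAPforxgboost)

# Toy stand-ins for the question's `x` and `xgb_model` (assumed; use your own).
x <- as.matrix(mtcars[, -1])
dtrain <- xgb.DMatrix(data = x, label = mtcars$mpg)
xgb_model <- xgb.train(params = list(objective = "reg:squarederror"),
                       data = dtrain, nrounds = 20)

# shap.prep() builds the long-format SHAP table that the plotting function expects.
data_long <- shap.prep(xgb_model = xgb_model, X_train = x, top_n = 10)

# Requires shap.plot.summary.edited() from above to be defined in the session.
p <- shap.plot.summary.edited(data_long)
print(p)
```

The result is an ordinary ggplot object, so it can be adjusted further with the usual ggplot2 layers and themes.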
