GBM R function: get variable importance separately for each class

I am using the gbm function in R (gbm package) to fit stochastic gradient boosting models for multiclass classification. I am simply trying to obtain the importance of each predictor separately for each class, like in this picture from the Hastie book (The Elements of Statistical Learning, p. 382).

However, the function summary.gbm only returns the overall importance of the predictors (their importance pooled over all classes).

Does anyone know how to get the relative importance values for each class separately?

I think the short answer is that on page 379, Hastie mentions that he uses MART, which appears to be available only for S-PLUS.

I agree that the gbm package doesn't seem to let you see the relative influence separately for each class. If that's something you're interested in for a multiclass problem, you could probably get something pretty similar by building a one-vs-all gbm for each of your classes and then getting the importance measures from each of those models.

So say your classes are a, b, c, and d. You model a vs. the rest and get the importance from that model. Then you model b vs. the rest and get the importance from that model. And so on.
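A minimal sketch of that one-vs-all idea, assuming a data frame whose response column is a factor. iris is only a stand-in for your data, and one_vs_all_importance is a hypothetical helper name, not part of gbm. Each column comes from a separate bernoulli model with its own loss, so the values are comparable within a column but not necessarily across columns or with a multinomial model's numbers.

```r
library(gbm)

set.seed(42)

# Hypothetical helper: fits one bernoulli gbm per class ("this class" = 1,
# "any other class" = 0) and returns a predictor x class matrix of
# relative influence values, one column per binary model.
one_vs_all_importance <- function(data, response, n.trees = 100) {
  classes <- levels(data[[response]])
  predictors <- setdiff(names(data), response)
  sapply(classes, function(cl) {
    d <- data[predictors]
    d$is_class <- as.integer(data[[response]] == cl)  # binary 0/1 target
    fit <- gbm(is_class ~ ., data = d, distribution = "bernoulli",
               n.trees = n.trees)
    relative.influence(fit, n.trees = n.trees)       # one value per predictor
  })
}

# iris stands in for your own data (Species has 3 classes)
imp_ova <- one_vs_all_importance(iris, "Species")

# Rescale each column so the top predictor in each class is 100,
# similar in spirit to the per-class plots in the book
round(sweep(imp_ova, 2, apply(imp_ova, 2, max), "/") * 100, 1)
```

Fitting K binary models costs roughly K times as much as one model, but each column is a genuine "this class vs. the rest" importance.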

Hopefully this function helps you. For the example I used data from the ElemStatLearn package. The function figures out what the classes in a given column are, splits the data by these classes, runs gbm() on each subset, and draws a bar plot of the relative importance for each of these models.

# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)
library(gbm)

set.seed(137531)

# formula: the formula to pass to gbm()
# data:    the data set to use
# column:  the column (index or name) whose values define the classes
classPlots <- function (formula, data, column) {

    # One data frame per distinct value of the class column.
    # split() always returns a list, whereas sapply(..., which) can
    # simplify to a matrix when all classes have the same size.
    split_data <- split(data, as.character(data[[column]]))
    class_values <- names(split_data)

    # Fit a separate gbm on each subset and extract its relative influence
    object <- lapply(split_data, function(x) gbm(formula, data = x))
    rel.inf <- lapply(object, function(x) summary(x, plotit = FALSE))

    # One horizontal bar plot per class
    for (i in seq_along(rel.inf)) {
        tmp <- rel.inf[[i]]
        imp <- setNames(tmp$rel.inf, as.character(tmp$var))

        barplot(imp, horiz = TRUE, col = 'red',
                xlab = "Relative importance", main = paste0("Class = ", class_values[i]))
    }
    rel.inf
}

par(mfrow = c(1, 2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)

I did some digging into how the gbm package calculates importance. It is based on ErrorReduction, which is contained in the trees element of the result and can be accessed with pretty.gbm.tree(). Relative influence is obtained by summing this ErrorReduction over all trees for each variable. For a multiclass problem there are actually n.trees*num.classes trees in the model, so if there are 3 classes you can sum the ErrorReduction for each variable over every third tree to get the importance for one class. I have written the following functions to implement this and then plot the results:
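The tree layout described above can also be used without dplyr or purrr. This base-R sketch assumes, as the paragraph above states, that each boosting iteration stores its num.classes trees consecutively in object$trees, so tree i belongs to class ((i - 1) %% K) + 1; rel_inf_by_class is a hypothetical name. Recent gbm releases may warn that distribution = "multinomial" is discouraged, so treat the multinomial fit itself with some caution.

```r
library(gbm)

set.seed(1)
fit <- gbm(Species ~ ., data = iris, distribution = "multinomial", n.trees = 100)

# Hypothetical helper: per-class relative influence as a predictor x class matrix.
# Assumes trees are interleaved by class (tree 1 = class 1, tree 2 = class 2, ...).
rel_inf_by_class <- function(object) {
  K <- object$num.classes
  n_total <- length(object$trees)            # n.trees * K for multinomial fits
  imp <- matrix(0, nrow = length(object$var.names), ncol = K,
                dimnames = list(object$var.names, object$classes))
  for (i in seq_len(n_total)) {
    tree <- pretty.gbm.tree(object, i.tree = i)
    splits <- tree[tree$SplitVar != -1, ]    # drop terminal nodes
    if (nrow(splits) == 0) next
    k <- ((i - 1) %% K) + 1                  # class this tree belongs to
    # SplitVar is a 0-based index into object$var.names
    gain <- tapply(splits$ErrorReduction, splits$SplitVar + 1, sum)
    rows <- as.integer(names(gain))
    imp[rows, k] <- imp[rows, k] + as.vector(gain)
  }
  imp
}

imp <- rel_inf_by_class(fit)
round(imp, 1)

# If the interleaving assumption holds, summing over classes should
# agree with gbm's pooled measure
all.equal(rowSums(imp), suppressMessages(relative.influence(fit)))
```

Returning a plain matrix keeps the result easy to rescale or pass to barplot(), at the cost of the tidy long format used by the dplyr version.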

Get Variable Importance By Class

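RelInf_ByClass <- function(object, n.trees, n.classes, Scale = TRUE){
  # object:    a fitted multinomial gbm
  # n.trees:   total number of trees, i.e. length(object$trees)
  #            (= boosting iterations * number of classes)
  # n.classes: number of classes, i.e. object$num.classes
  # Scale:     if TRUE, also return Scaled_inf (top variable in each class = 100)
  library(dplyr)
  library(purrr)
  library(gbm)
  # Sum ErrorReduction per split variable within one tree
  # (terminal nodes have SplitVar == -1 and are dropped)
  Ext_ErrRed <- function(ptree){
    ptree %>% filter(SplitVar != -1) %>% group_by(SplitVar) %>% 
      summarise(Sum_ErrRed = sum(ErrorReduction))
  }
  trees_ErrRed <- map(1:n.trees, ~pretty.gbm.tree(object, .)) %>% 
    map(Ext_ErrRed)

  # Trees cycle through the classes (1, 2, ..., n.classes, 1, 2, ...),
  # so every n.classes-th tree belongs to the same class
  trees_by_class <- split(trees_ErrRed, rep(1:n.classes, n.trees/n.classes)) %>% 
    map(~bind_rows(.) %>% group_by(SplitVar) %>% 
          summarise(rel_inf = sum(Sum_ErrRed)))
  # SplitVar is a 0-based index into object$var.names
  varnames <- data.frame(Num = 0:(length(object$var.names)-1),
                         Name = object$var.names)
  classnames <- data.frame(Num = 1:object$num.classes, 
                           Name = object$classes)
  out <- trees_by_class %>% bind_rows(.id = "Class") %>%  
    mutate(Class = classnames$Name[match(Class, classnames$Num)],
           SplitVar = varnames$Name[match(SplitVar, varnames$Num)]) %>%
    group_by(Class) 
  if (Scale) {
    out <- out %>% mutate(Scaled_inf = rel_inf/max(rel_inf)*100)
  }
  out
}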

Plot Variable Importance By Class

In my real use case I have over 40 features, so I added an option to specify the number of features to plot. I also couldn't use faceting if I wanted the plots to be sorted separately for each class, which is why I used gridExtra.

plot_imp_byclass <- function(df, n) {
  library(dplyr)
  library(purrr)
  library(ggplot2)
  library(gridExtra)
  # Lollipop chart for one class, variables sorted by importance
  plot_imp_class <- function(df){
    df %>% arrange(rel_inf) %>% 
      mutate(SplitVar = factor(SplitVar, levels = .$SplitVar)) %>% 
      ggplot(aes(SplitVar, rel_inf))+
      geom_segment(aes(x = SplitVar, 
                       xend = SplitVar, 
                       y = 0, 
                       yend = rel_inf))+
      geom_point(size=3, col = "cyan") + 
      coord_flip()+
      labs(title = df$Class[[1]], x = "Variable", y = "Importance")+
      theme_classic()+
      theme(plot.title = element_text(hjust = 0.5))
  }

  # df is grouped by Class, so top_n() keeps the n most important variables
  # per class; each class gets its own sorted plot, arranged with gridExtra
  df %>% top_n(n, rel_inf) %>% split(.$Class) %>% 
    map(plot_imp_class) %>% map(ggplotGrob) %>% 
    {grid.arrange(grobs = .)}
}

Try It

gbm_iris <- gbm(Species~., data = iris)
# n.trees here is the total number of trees (100 iterations x 3 classes)
imp_byclass <- RelInf_ByClass(gbm_iris, length(gbm_iris$trees), 
                              gbm_iris$num.classes, Scale = FALSE)
plot_imp_byclass(imp_byclass, 4)

If you sum the results over all the classes, this seems to give the same results as the built-in relative.influence function.

relative.influence(gbm_iris)
# n.trees not given. Using 100 trees.
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
#      0.00000     51.88684   2226.88017    868.71085 

imp_byclass %>% group_by(SplitVar) %>% summarise(Overall_rel_inf = sum(rel_inf))
# # A tibble: 3 x 2
#   SplitVar     Overall_rel_inf
#   <fct>                  <dbl>
# 1 Petal.Length          2227. 
# 2 Petal.Width            869. 
# 3 Sepal.Width             51.9

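Statement: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.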

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM