简体   繁体   English

R 中有没有办法确定变量中的哪些水平在 GBM 预测模型中最重要?

[英]Is there a way in R to determine which levels within the variables are most important in the GBM predictive model?

I constructed a predictive model using the GBM package in R. I have good results and I am able to see the feature importance list to see which variables are most important to the model.我使用 R 中的 GBM 包构建了一个预测模型。我得到了很好的结果,并且我能够查看特征重要性列表以查看哪些变量对模型最重要。 I am struggling with an editor's question asking for direction of the variables.我正在努力解决编辑器询问变量方向的问题。

For instance: age variable: which age group is most important, rather than age overall?例如:年龄变量:哪个年龄组最重要,而不是整体年龄?
region: which specific region, rather than region as a variable overall? region:哪个具体的区域,而不是作为整体的区域变量?

I see some implementation of this with LIME, however the GBM package is not compatible with LIME and I am stuggling with implementing it otherwise.我在 LIME 中看到了一些实现,但是 GBM 包与 LIME 不兼容,我正在努力以其他方式实现它。 Is there a manual way to see this?有没有手动方法可以看到这个?

My current idea to run the GBM model one by one and compare results.我目前的想法是一一运行 GBM 模型并比较结果。 For instance, run with region A and all others the same, then region B, C, D, E, etc. Compare the final results and see more information about the level of each variable.例如,在区域 A 和所有其他区域相同的情况下运行,然后区域 B、C、D、E 等。比较最终结果并查看有关每个变量水平的更多信息。

Does anyone have further advice or a quicker solution?有没有人有进一步的建议或更快的解决方案? Thanks谢谢

I suppose you are using gbm and not xgboost, but in any case you can always convert data into the necessary format.我想您使用的是 gbm 而不是 xgboost,但无论如何您始终可以将数据转换为必要的格式。

You can try onehot encoding, and this is a bit better than testing the variables one by one because the model is exposed to all the variables.您可以尝试使用 onehot 编码,这比一个一个测试变量好一点,因为模型暴露于所有变量。 Below is not a very good example because I cut up a continuous variable, but hopefully in your model the categorization makes more sense:下面不是一个很好的例子,因为我分割了一个连续变量,但希望在你的模型中分类更有意义:

library(MASS)
library(gbm)
library(highcharter)

data = Pima.te
age_cat = cut(data$age,4,labels = paste0("age",1:4))
onehot_bp = model.matrix(~0+age_cat)
data$type = as.numeric(data$type)-1
fit = gbm(type ~ .,data=cbind(data[,-grep("age",colnames(data))],onehot_bp))

res = summary(fit,plotit=FALSE)

hchart(res,"bar",hcaes(x=var,y=rel.inf,color=rel.inf))

在此处输入图片说明

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM