How to map pdp::partial to nested randomForest models?

I would like to map the function pdp::partial over nested randomForest models, then use the output to draw 3d partial dependence plots for each group via facet_wrap(). When I map the function over the models, I get an error saying that the predictor variables cannot be found in the training data. They are there when I check the tibble, so I'm at a loss for what to do.

library(tidyverse)
library(pdp)
library(randomForest)
data(boston)
glimpse(boston)

#Make groups, nest data by groups, apply random forest model to nested data
boston %>%
  mutate(grp=ifelse(age<80, "young", "old"))%>%
  nest(data= -grp)%>%
  mutate(fit = map(data, ~ randomForest(cmedv ~ ., data = .x, importance = TRUE)))%>%
  {.->>GrpModels}

#Map pdp::partial to fitted models for two predictor variables
GrpModels%>%
  mutate(p=map2(fit,data, ~pdp::partial(fit,train=data, pred.var=c("lstat", "rm"))))%>%
  unnest(p)%>%{.->>checkpdp}

Error: Problem with mutate() column p.
i p = map2(...).
x lstat, rm not found in the training data.

This seems to work, although I'm not sure why plotting with geom_tile() does not quite do what I expected, so I used geom_point() instead. In short, I needed to supply pred.var as a list column and then pass the three inputs (fit, data, and predictor variables) to pmap.

GrpModels %>%
  # One-element list holding both predictor names; mutate() recycles it to every row
  mutate(preds = list(c("lstat", "rm"))) %>%
  # ..1, ..2, ..3 are the per-row model, training data, and predictor names
  mutate(p = pmap(list(fit, data, preds),
                  .f = ~ pdp::partial(object = ..1, train = ..2,
                                      pred.var = ..3))) %>%
  select(-data, -fit, -preds) %>%
  unnest_wider(p) %>%
  unnest(c(yhat, lstat, rm)) %>% {. ->> checkpdp} %>%
  ggplot(., aes(x = lstat, y = rm, color = yhat)) +
  # geom_tile() +
  geom_point(shape = 15, size = 2) +
  facet_wrap(~grp, scales = "free")
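
A likely reason the original map2() call failed is that the lambda body referred to fit and data by column name. Inside mutate() those names resolve to the whole list columns, so partial() received a list of tibbles as train and could not find lstat or rm in it. The pmap() version works because ..1, ..2 and ..3 refer to the per-row elements. The sketch below is one alternative, not the only fix: it keeps map2() and uses .x and .y instead. The names grp_models, w and h, the fixed seed, and grid.resolution = 20 are assumptions added for illustration. For geom_tile(), the sketch maps yhat to fill (color only affects the tile border) and supplies a per-group tile width and height, since the two groups have different grids. Whether those were the exact causes of the odd geom_tile() plot above is a guess.

```r
library(tidyverse)
library(pdp)
library(randomForest)

data(boston, package = "pdp")
set.seed(101)  # assumption: fixed seed, only so reruns give the same forests

# Fit one forest per group; .x is the current group's tibble
grp_models <- boston %>%
  mutate(grp = ifelse(age < 80, "young", "old")) %>%
  nest(data = -grp) %>%
  mutate(fit = map(data, ~ randomForest(cmedv ~ ., data = .x, importance = TRUE)))

# In the lambda, .x is one fitted model and .y is that group's training data
checkpdp <- grp_models %>%
  mutate(p = map2(fit, data, ~ as_tibble(pdp::partial(
    .x,
    train = as.data.frame(.y),
    pred.var = c("lstat", "rm"),
    grid.resolution = 20  # assumption: coarse grid to keep the run short
  )))) %>%
  select(grp, p) %>%
  unnest(p) %>%
  # Tile size per group, assuming partial() returned an evenly spaced grid
  group_by(grp) %>%
  mutate(w = diff(range(lstat)) / (n_distinct(lstat) - 1),
         h = diff(range(rm)) / (n_distinct(rm) - 1)) %>%
  ungroup()

# fill (not color) sets the interior colour of each tile
ggplot(checkpdp, aes(x = lstat, y = rm, fill = yhat, width = w, height = h)) +
  geom_tile() +
  facet_wrap(~grp, scales = "free")
```

If the per-group grids are not evenly spaced (for example when quantiles = TRUE is passed to partial()), the computed w and h will not match the grid and geom_point() as used above is the safer choice.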
