
randomForest Categorical Predictor Limits

I understand and appreciate that R's randomForest function can only handle categorical predictors with fewer than 54 categories (that is, at most 53). However, even after trimming my categorical predictor to fewer than 54 categories, I still get the error. The only Stack Overflow questions I've found about this limit ask how to get around it; I'm instead trying to trim my categories to stay within the function's limit, and I still get the error.

The following script creates a data frame so we can predict 'profession'. Understandably, I get the "Can not handle categorical predictors with more than 53 categories" error when trying to run randomForest() on 'df' due to the 'college_id' variable.

But when I trim my data set to only include the top 40 college IDs, I get the same error. Am I missing some basic data frame concept that retains all of the categories even though only 40 are now populated in the 'df2' data frame? What is a workaround option that I can use?

library(dplyr)
library(randomForest)

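# create data frame; stringsAsFactors = TRUE makes 'profession' and 'zip' factors,
# which R >= 4.0 no longer does by default (randomForest needs a factor response
# to fit a classification forest rather than a regression)
df <- data.frame(profession = sample(c("accountant", "lawyer", "dentist"), 10000, replace = TRUE),
                 zip = sample(c("32801", "32807", "32827", "32828"), 10000, replace = TRUE),
                 salary = sample(50000:150000, 10000, replace = TRUE),
                 college_id = as.factor(c(sample(1001:1040, 9200, replace = TRUE),
                                          sample(1050:9999, 800, replace = TRUE))),
                 stringsAsFactors = TRUE)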


# results in error, as expected
rfm <- randomForest(profession ~ ., data = df)


# arrange college_ids by count and keep only rows for the top 40 in a new 'df2' data frame
sdf <- df %>% 
  dplyr::group_by(college_id) %>% 
  dplyr::summarise(n = n()) %>% 
  dplyr::arrange(desc(n))
sdf <- sdf[1:40, ]
df2 <- dplyr::inner_join(df, sdf, by = "college_id")
df2$n <- NULL


# confirm that df2 only contains 40 categories of 'college_id'
nrow(df2[which(!duplicated(df2$college_id)), ])


# THIS IS WHAT I WANT TO RUN, BUT STILL RESULTS IN ERROR
rfm2 <- randomForest(profession ~ ., data = df2)
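One way to see what is going on is to compare, for each factor column, how many levels the factor declares with how many distinct values actually occur. The helper below, level_report(), is a hypothetical sketch written for this illustration (it is not part of randomForest or dplyr), shown on a small toy data frame rather than the question's data:

```r
# Hypothetical helper: for every factor column in a data frame, report how many
# levels the factor declares versus how many distinct values actually occur.
level_report <- function(data) {
  is_fac <- vapply(data, is.factor, logical(1))
  facs <- data[is_fac]
  data.frame(
    column    = names(facs),
    levels    = vapply(facs, nlevels, integer(1)),
    observed  = vapply(facs, function(f) length(unique(f)), integer(1)),
    row.names = NULL
  )
}

# Toy example: removing the rows with "c" does not remove the level "c"
toy <- data.frame(g = factor(c("a", "b", "c", "a")))
toy_sub <- toy[toy$g != "c", , drop = FALSE]
level_report(toy_sub)  # column g: levels = 3, observed = 2
```

Run on the question's data, level_report(df2) should show 'college_id' with 40 observed values but far more than 53 declared levels, which is the count randomForest checks.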

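You still have all the original factor levels in 'college_id'. Subsetting rows (here via inner_join) keeps a factor's full set of levels even when some of them no longer occur in the data, and randomForest counts levels, not observed values. Drop the unused levels by adding this line before you fit the forest again: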

df2$college_id <- factor(df2$college_id)
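A self-contained sketch of the same fix on toy data (hypothetical values, not the question's data). Calling factor() on an existing factor rebuilds it from its current values and so drops unused levels; base R's droplevels() does the same and also accepts a whole data frame, which helps when several factor columns were subset at once. The final stopifnot() is only an explicit guard mirroring randomForest's 53-level limit, not a randomForest function:

```r
# Toy data: 5 levels declared, but after subsetting only 2 remain in use
full <- data.frame(id = factor(c("x1", "x2", "x3", "x4", "x5", "x1", "x2")))
sub  <- full[full$id %in% c("x1", "x2"), , drop = FALSE]

nlevels(sub$id)   # 5: unused levels are still attached

# Option 1 (as in the answer): rebuild the factor from its current values
sub$id <- factor(sub$id)
nlevels(sub$id)   # 2

# Option 2: droplevels() drops unused levels in every factor column at once
sub2 <- droplevels(full[full$id %in% c("x1", "x2"), , drop = FALSE])
nlevels(sub2$id)  # 2

# Guard mirroring randomForest's limit of 53 categories per categorical predictor
stopifnot(all(vapply(Filter(is.factor, sub2), nlevels, integer(1)) <= 53))
```

Applied to the question, either df2$college_id <- factor(df2$college_id) or df2 <- droplevels(df2) should leave 'college_id' with 40 levels before calling randomForest(profession ~ ., data = df2).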
