[英]Variable Importance Dummy Variables R
如何確定分類預測變量的變量重要性(r 中的 vip package)? 當 model 建立在虛擬變量而不是原始分類預測變量上時,r 似乎不可能做到這一點。
我將用 Ames Housing 數據集演示我的意思。 我將使用兩個分類預測變量。 Street(兩層)和 Sale.Type(十層)。 我將它們從字符轉換為因子。
library(AmesHousing)
df <- data.frame(ames_raw)
# convert characters to factors
df <- df%>%mutate_if(is.character, as.factor)
# train and split code from caret datacamp
# Get the number of observations
n_obs <- nrow(df)
# Shuffle row indices: permuted_rows
permuted_rows <- sample(n_obs)
# Randomly order data:
df_shuffled <- df[permuted_rows, ]
# Identify row to split on: split
split <- round(n_obs * 0.7)
# Create train
train <- df_shuffled[1:split, ]
# Create test
test <- df_shuffled[(split + 1):n_obs, ]
mod_lm <- train(SalePrice ~ Street + Sale.Type,
data = df,
method = "lm")
vip(mod_lm)
變量重要性按每個級別而不是原始預測變量對它們進行排名。 我可以看到 StreetPave 很重要,但我看不出 Street 是否重要。
從caret
文檔中,我們看到線性模型中的變量重要性對應於每個協變量的 t 統計量的絕對值。 所以,我們可以手動計算它,就像我在下面的代碼中所做的那樣。
lm()
自動將分類變量轉換為虛擬變量。 因此,為了獲得每個協變量的重要性,我們必須對虛擬變量求和。 我沒有找到自動化的方法,所以如果你想將我的解決方案應用於不同的變量集,你需要小心選擇t.stats
的項目來求和。
最后,我們可以使用結果進行繪圖。 I just used the baseline function for a bar plot, but you can customize it as you want (maybe also using the ggplot2
package for better visualization).
Ps 當您提供可重現的示例時,請記住加載所有需要的包。
Pps 對虛擬對象求和可能對我們正在使用的虛擬對象的基本水平(即,我們從回歸中省略的水平)敏感。 我不知道這是否是個問題。
library(AmesHousing)
library(caret)
library(dplyr)
df = data.frame(ames_raw)
# convert characters to factors
df = df%>%mutate_if(is.character, as.factor)
# train and split code from caret datacamp
# Get the number of observations
n_obs <- nrow(df)
# Shuffle row indices: permuted_rows
permuted_rows <- sample(n_obs)
# Randomly order data:
df_shuffled <- df[permuted_rows, ]
# Identify row to split on: split
split <- round(n_obs * 0.7)
# Create train
train <- df_shuffled[1:split, ]
# Create test
test <- df_shuffled[(split + 1):n_obs, ]
mod_lm <- train(SalePrice ~ Street + Sale.Type,
data = df,
method = "lm")
# Manually computing variable importance from t-statistics of the model.
t.stats = coef(summary(mod_lm))[, "t value"]
imp.sale = sum(t.stats[-(1:2)])
imp.street = t.stats[2]
# Plotting.
barplot(c(imp.sale, imp.street), names.arg = c("Sale", "Street"))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.