繁体   English   中英

比较使用交叉验证使用的预测变量不同的线性回归模型的性能

[英]Compare performance of linear regression models that differ by predictors used using cross validation

我想使用 tidymodels 和交叉验证来比较 3 个可以指定如下的线性回归模型:

  • ( model_A ) y ~ a
  • (模型_B ) y ~ b
  • ( model_AB ) y ~ a + b

在下文中, y将表示目标变量,而ab将表示自变量。

在不使用交叉验证的情况下,(我希望)我很清楚我必须做什么:

  1. 将我的数据拆分为训练集和测试集
set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- training(split)
  1. 我可以在一个 go 中指定、拟合和评估我的 model(例如model_AB
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a + b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

output 看起来像这样:

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard       x.xxx

我可以对其他两个模型重复步骤 2,并根据 RMSE 度量比较三个模型(因为这是本示例的选择)。

例如,我可以创建一个虚拟数据集并运行上述步骤。

library(tidyverse)
library(tidymodels)

set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- training(split)
  • 型号_A
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

结果

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        2.23
  • 型号_B
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

结果

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        3.17
  • 型号_AB
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a + b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

结果

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        1.00

我的问题是:在对可能特征列表不同的三个模型执行交叉验证后,如何评估 RMSE?

在此视频中,Julia Silge 使用相同的预测变量集使用三个不同的模型(逻辑回归、knn 和决策树)来完成这项工作。 然而,我的目标是比较预测变量集不同的模型。

有什么建议和/或参考吗?

当您想要比较许多不同的模型时,一种处理方法是使用工作流集 package。

通过这种方式,您可以指定任意数量的模型和预处理器,它将运行所有模型并以整洁的格式返回结果。

请注意我们如何使用recipe()只是表示在每个 model 中使用了哪些变量。

library(tidymodels)
set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- training(split)

lm_spec <- linear_reg()

rec_a <- recipe(y ~ a, data = data_train)
rec_b <- recipe(y ~ b, data = data_train)
rec_ab <- recipe(y ~ a + b, data = data_train)

all_models_wfs <- workflow_set(
  preproc = list(a = rec_a, b = rec_b, c = rec_ab),
  models = list(lm = lm_spec),
  cross = TRUE
)

all_models_wfs
#> # A workflow set/tibble: 3 × 4
#>   wflow_id info             option    result    
#>   <chr>    <list>           <list>    <list>    
#> 1 a_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
#> 2 b_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
#> 3 c_lm     <tibble [1 × 4]> <opts[0]> <list [0]>

all_models_fit <- workflow_map(
  all_models_wfs, 
  resamples = vfold_cv(data_test)
)

all_models_fit %>%
  collect_metrics()
#> # A tibble: 6 × 9
#>   wflow_id .config             preproc model .metric .esti…¹  mean     n std_err
#>   <chr>    <chr>               <chr>   <chr> <chr>   <chr>   <dbl> <int>   <dbl>
#> 1 a_lm     Preprocessor1_Mode… recipe  line… rmse    standa… 2.26     10 0.0289 
#> 2 a_lm     Preprocessor1_Mode… recipe  line… rsq     standa… 0.627    10 0.00772
#> 3 b_lm     Preprocessor1_Mode… recipe  line… rmse    standa… 3.10     10 0.0213 
#> 4 b_lm     Preprocessor1_Mode… recipe  line… rsq     standa… 0.298    10 0.00761
#> 5 c_lm     Preprocessor1_Mode… recipe  line… rmse    standa… 1.01     10 0.00651
#> 6 c_lm     Preprocessor1_Mode… recipe  line… rsq     standa… 0.926    10 0.00206
#> # … with abbreviated variable name ¹​.estimator

代表 package (v2.0.1) 于 2022 年 9 月 12 日创建

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM