
Linear Regression model building and prediction by group in R

I'm trying to build several models on subsets (groups) of the data and generate their fitted values. In other words, as my attempts below show, I'm trying to build country-specific models. Unfortunately, my attempts use the entire dataset to build each model instead of restricting it to the rows of each country. Could you please help me resolve this problem?

In the first case I'm doing a kind of (leave-one-out) cross-validation to generate the predictions. In the second case I'm not. Both attempts seem to fail.



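library(tidyverse)  # dplyr and purrr: %>%, group_by, bind_cols, map, map_dfc
library(modelr)
# install.packages("gapminder")  # run once if the package is not installed
library(gapminder)
data(gapminder)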

#CASE 1
model1 <- lm(lifeExp ~ pop, data = gapminder)
model2 <- lm(lifeExp ~ pop + gdpPercap, data = gapminder)

models <- list(fit_model1 = model1,fit_model2 = model2)

gapminder %>% group_by(continent, country) %>%
  bind_cols(
    map(1:nrow(gapminder), function(i) {
      map_dfc(models, function(model) {
        training <- gapminder[-i, ]   # every row except row i, from all countries
        fit <- lm(model, data = training)
        
        validation <- gapminder[i, ]  # the single held-out row
        predict(fit, newdata = validation)
        
      })
    }) %>%
      bind_rows()
  )


#CASE 2
model1 <- lm(lifeExp ~ pop, data = gapminder)
model2 <- lm(lifeExp ~ pop + gdpPercap, data = gapminder)

models <- list(fit_model1 = model1,fit_model2 = model2)


for (m in names(models)) {
  # predict() ignores the grouping; both models were fitted on all countries
  gapminder[[m]] <- predict(models[[m]], gapminder %>% group_by(continent, country))
  
}

The tidyverse approach to modeling by group is to use:

  • tidyr::nest() to collapse each group's rows into a nested data frame
  • dplyr::mutate() together with purrr::map() to fit one model per group
  • broom::tidy() or broom::augment() to generate model summaries and predictions
  • tidyr::unnest() and dplyr::filter() to extract summaries and predictions by group
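The list above can be sketched end to end with broom::tidy(), which returns one row per model coefficient (term, estimate, std.error, statistic, p.value). This is a minimal sketch, assuming the same gapminder data and the lifeExp ~ pop model used in the answer; the name country_coefs is hypothetical.

```r
library(tidyverse)
library(broom)
library(gapminder)

# Fit lifeExp ~ pop separately for every country and collect one row per
# coefficient (intercept and slope on pop) with its standard error, etc.
country_coefs <- gapminder %>%
  nest(data = c(year, lifeExp, pop, gdpPercap)) %>%
  mutate(model = map(data, ~ lm(lifeExp ~ pop, data = .x)),
         coefs = map(model, tidy)) %>%
  select(country, continent, coefs) %>%
  unnest(coefs)

# Keep only the slope on pop for one country of interest
country_coefs %>%
  filter(country == "Egypt", term == "pop")
```

Because every country keeps its own row per term, comparing slopes across countries or continents is just a filter() or arrange() on country_coefs.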

Here's an example. It doesn't do the same as the code in your question, but I think it will be helpful nevertheless.

This code fits the linear model lifeExp ~ pop separately for each country and generates the fitted (predicted) values for each model.

library(tidyverse)
library(broom)
library(gapminder)

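gapminder_lm <- gapminder %>% 
  # one row per country/continent; the remaining columns go into a nested tibble
  nest(data = c(year, lifeExp, pop, gdpPercap)) %>% 
  # fit lifeExp ~ pop on each country's data, then add fitted values and residuals
  mutate(model = map(data, ~lm(lifeExp ~ pop, .)), 
         fitted = map(model, augment)) %>% 
  # expand the augment() output back to one row per observation
  unnest(fitted)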

gapminder_lm

# A tibble: 1,704 x 12
   country     continent data              model  lifeExp      pop .fitted .resid .std.resid   .hat .sigma  .cooksd
   <fct>       <fct>     <list>            <list>   <dbl>    <int>   <dbl>  <dbl>      <dbl>  <dbl>  <dbl>    <dbl>
 1 Afghanistan Asia      <tibble [12 x 4]> <lm>      28.8  8425333    33.2 -4.41     -1.54   0.182    2.92 0.262   
 2 Afghanistan Asia      <tibble [12 x 4]> <lm>      30.3  9240934    33.7 -3.35     -1.15   0.161    3.11 0.128   
 3 Afghanistan Asia      <tibble [12 x 4]> <lm>      32.0 10267083    34.3 -2.27     -0.773  0.139    3.24 0.0482  
 4 Afghanistan Asia      <tibble [12 x 4]> <lm>      34.0 11537966    35.0 -0.985    -0.331  0.116    3.32 0.00720 
 5 Afghanistan Asia      <tibble [12 x 4]> <lm>      36.1 13079460    35.9  0.193     0.0641 0.0969   3.34 0.000220
 6 Afghanistan Asia      <tibble [12 x 4]> <lm>      38.4 14880372    36.9  1.50      0.496  0.0849   3.30 0.0114  
 7 Afghanistan Asia      <tibble [12 x 4]> <lm>      39.9 12881816    35.8  4.07      1.35   0.0989   3.02 0.101   
 8 Afghanistan Asia      <tibble [12 x 4]> <lm>      40.8 13867957    36.4  4.47      1.48   0.0902   2.95 0.108   
 9 Afghanistan Asia      <tibble [12 x 4]> <lm>      41.7 16317921    37.8  3.91      1.29   0.0838   3.05 0.0759  
10 Afghanistan Asia      <tibble [12 x 4]> <lm>      41.8 22227415    41.2  0.588     0.202  0.157    3.33 0.00380 
# ... with 1,694 more rows

This has the advantage of keeping everything in a tidy data frame, which can be filtered for the data of interest.
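The same per-country idea also covers the question's second case (fitted values from both models, no cross-validation). The key change is storing the two specifications as formulas rather than as already-fitted lm objects, so they can be refitted on each country's rows. This is a sketch under assumptions: it swaps in dplyr::group_modify() (dplyr 0.8.1 or later) instead of the nest()/map() pattern, and the names formulas and gapminder_fits are hypothetical.

```r
library(tidyverse)
library(gapminder)

# The question's two specifications, kept as formulas (not fitted models)
formulas <- list(fit_model1 = lifeExp ~ pop,
                 fit_model2 = lifeExp ~ pop + gdpPercap)

# For each country, fit both formulas on that country's rows only and
# store the in-sample fitted values as new columns
gapminder_fits <- gapminder %>%
  group_by(continent, country) %>%
  group_modify(function(df, key) {
    for (m in names(formulas)) {
      df[[m]] <- unname(fitted(lm(formulas[[m]], data = df)))
    }
    df
  }) %>%
  ungroup()
```

group_modify() hands each group's rows to the function without the grouping columns and adds them back afterwards, so the result has the same 1,704 rows as gapminder, ordered by continent and country.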

For example, filter for Egypt and plot actual versus predicted values:

gapminder_lm %>% 
  filter(country == "Egypt") %>% 
  ggplot(aes(lifeExp, .fitted)) + 
  geom_point() + 
  labs(title = "Egypt")

[Figure: scatter plot titled "Egypt", lifeExp (x axis) versus .fitted (y axis)]
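For the question's first case, leave-one-out cross-validation has to happen inside each country: row i is predicted from a model fitted on the other rows of the same country, not on the rest of the whole dataset. Here is a sketch under the same assumptions as before (dplyr::group_modify(), formulas instead of fitted models); loo_predict, formulas and gapminder_loo are hypothetical names. It refits each model once per row, so it runs 2 × 1,704 small regressions.

```r
library(tidyverse)
library(gapminder)

# The question's two specifications, kept as formulas so they can be refitted
formulas <- list(fit_model1 = lifeExp ~ pop,
                 fit_model2 = lifeExp ~ pop + gdpPercap)

# Leave-one-out predictions within one group: for each row i, fit the
# model on the group's other rows and predict the held-out row i
loo_predict <- function(df, formula) {
  map_dbl(seq_len(nrow(df)), function(i) {
    fit <- lm(formula, data = df[-i, ])
    unname(predict(fit, newdata = df[i, ]))
  })
}

# Apply the leave-one-out procedure separately to every country
gapminder_loo <- gapminder %>%
  group_by(continent, country) %>%
  group_modify(function(df, key) {
    for (m in names(formulas)) {
      df[[m]] <- loo_predict(df, formulas[[m]])
    }
    df
  }) %>%
  ungroup()
```

Each country has 12 observations, so every fit uses 11 rows, which is enough for the three coefficients of fit_model2.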
