
Linear Regression model building and prediction by group in R

I'm trying to build several models on subsets (groups) of the data and generate their fitted values. Specifically, as my attempts below show, I want the models to be country-specific. Unfortunately, in my attempts the models are always built on the entire dataset instead of being restricted to each country's rows. Could you please help me resolve this problem?

In the first case I'm doing a form of leave-one-out cross-validation to generate the predictions; in the second case I'm not. Both attempts seem to fail.



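library(dplyr)   # %>%, group_by(), bind_cols(), bind_rows()
library(purrr)   # map(), map_dfc()
library(modelr)
# install.packages("gapminder")  # run once if the package is not installed
library(gapminder)
data(gapminder)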

#CASE 1: leave-one-out predictions (hold out row i, refit on the remaining rows, predict row i)
model1 <- lm(lifeExp ~ pop, data = gapminder)
model2 <- lm(lifeExp ~ pop + gdpPercap, data = gapminder)

models <- list(fit_model1 = model1,fit_model2 = model2)

gapminder %>% group_by(continent, country) %>%
  bind_cols(
    map(1:nrow(gapminder), function(i) {
      map_dfc(models, function(model) {
        training <- gapminder[-i, ] 
        fit <- lm(model, data = training)
        
        validation <- gapminder[i, ]
        predict(fit, newdata = validation)
        
      })
    }) %>%
      bind_rows()
  )


#CASE 2: in-sample predictions, no cross-validation
model1 <- lm(lifeExp ~ pop, data = gapminder)
model2 <- lm(lifeExp ~ pop + gdpPercap, data = gapminder)

models <- list(fit_model1 = model1,fit_model2 = model2)


for (m in names(models)) {
  gapminder[[m]] <- predict(models[[m]], gapminder %>% group_by(continent, country) )
  
}

A common tidyverse approach to modeling by group is to use:

  • tidyr::nest() to collapse each group's rows into a list-column of data frames
  • dplyr::mutate() together with purrr::map() to fit one model per group
  • broom::tidy() or broom::augment() to generate model summaries and predictions
  • tidyr::unnest() and dplyr::filter() to get summaries and predictions by group
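The four steps above fit together roughly as follows when the goal is a coefficient summary rather than predictions. This is a minimal sketch that uses broom::tidy() instead of augment(); the name pop_slopes is only illustrative, and it assumes the tidyverse, broom and gapminder packages are installed.

```r
library(tidyverse)  # tidyr, dplyr, purrr
library(broom)
library(gapminder)

pop_slopes <- gapminder %>%
  # Step 1: one row per country, that country's rows nested in `data`
  nest(data = c(year, lifeExp, pop, gdpPercap)) %>%
  # Step 2: fit lifeExp ~ pop separately for each country
  mutate(model = map(data, ~ lm(lifeExp ~ pop, data = .x)),
         # Step 3: tidy() turns each model into a coefficient table
         coefs = map(model, tidy)) %>%
  select(country, continent, coefs) %>%
  # Step 4: unnest the tables and keep only the slope on pop
  unnest(coefs) %>%
  filter(term == "pop")

pop_slopes
```

Each row of pop_slopes holds one country's estimated slope on pop, with its standard error and p-value; dropping the filter() keeps the intercepts as well.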

Here's an example. It doesn't do the same as the code in your question, but I think it will be helpful nevertheless.

This code fits the linear model lifeExp ~ pop separately for each country and returns the fitted (predicted) values of each model.

library(tidyverse)
library(broom)
library(gapminder)

gapminder_lm <- gapminder %>% 
  nest(data = c(year, lifeExp, pop, gdpPercap)) %>% 
  mutate(model = map(data, ~lm(lifeExp ~ pop, .)), 
         fitted = map(model, augment)) %>% 
  unnest(fitted)

gapminder_lm

# A tibble: 1,704 x 12
   country     continent data              model  lifeExp      pop .fitted .resid .std.resid   .hat .sigma  .cooksd
   <fct>       <fct>     <list>            <list>   <dbl>    <int>   <dbl>  <dbl>      <dbl>  <dbl>  <dbl>    <dbl>
 1 Afghanistan Asia      <tibble [12 x 4]> <lm>      28.8  8425333    33.2 -4.41     -1.54   0.182    2.92 0.262   
 2 Afghanistan Asia      <tibble [12 x 4]> <lm>      30.3  9240934    33.7 -3.35     -1.15   0.161    3.11 0.128   
 3 Afghanistan Asia      <tibble [12 x 4]> <lm>      32.0 10267083    34.3 -2.27     -0.773  0.139    3.24 0.0482  
 4 Afghanistan Asia      <tibble [12 x 4]> <lm>      34.0 11537966    35.0 -0.985    -0.331  0.116    3.32 0.00720 
 5 Afghanistan Asia      <tibble [12 x 4]> <lm>      36.1 13079460    35.9  0.193     0.0641 0.0969   3.34 0.000220
 6 Afghanistan Asia      <tibble [12 x 4]> <lm>      38.4 14880372    36.9  1.50      0.496  0.0849   3.30 0.0114  
 7 Afghanistan Asia      <tibble [12 x 4]> <lm>      39.9 12881816    35.8  4.07      1.35   0.0989   3.02 0.101   
 8 Afghanistan Asia      <tibble [12 x 4]> <lm>      40.8 13867957    36.4  4.47      1.48   0.0902   2.95 0.108   
 9 Afghanistan Asia      <tibble [12 x 4]> <lm>      41.7 16317921    37.8  3.91      1.29   0.0838   3.05 0.0759  
10 Afghanistan Asia      <tibble [12 x 4]> <lm>      41.8 22227415    41.2  0.588     0.202  0.157    3.33 0.00380 
# ... with 1,694 more rows

This has the advantage of keeping everything in a tidy data frame, which can be filtered for the data of interest.

For example, filter for Egypt and plot observed versus fitted values:

gapminder_lm %>% 
  filter(country == "Egypt") %>% 
  ggplot(aes(lifeExp, .fitted)) + 
  geom_point() + 
  labs(title = "Egypt")

[Plot: lifeExp vs. .fitted for Egypt]
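The example above fits only lifeExp ~ pop and does not cross-validate. As a minimal sketch (not the answer's own method), both cases from the question could be restricted to each country by passing model formulas instead of models already fitted on the whole dataset, and doing all fitting inside dplyr::group_modify(), which hands each country's rows to a function. The names formulas, fit_by_group(), loo_by_group(), gapminder_preds and the loo_ column prefix are hypothetical, and the sketch assumes a dplyr version that provides group_modify().

```r
library(tidyverse)
library(gapminder)

# Hypothetical list mirroring the question's two models, kept as formulas
# so they are only ever fitted on one country's rows.
formulas <- list(fit_model1 = lifeExp ~ pop,
                 fit_model2 = lifeExp ~ pop + gdpPercap)

# Case 2 (hypothetical helper): in-sample fitted values within one group
fit_by_group <- function(df) {
  for (m in names(formulas)) {
    df[[m]] <- predict(lm(formulas[[m]], data = df), newdata = df)
  }
  df
}

# Case 1 (hypothetical helper): leave-one-out predictions within one group
loo_by_group <- function(df) {
  for (m in names(formulas)) {
    df[[paste0("loo_", m)]] <- map_dbl(seq_len(nrow(df)), function(i) {
      fit <- lm(formulas[[m]], data = df[-i, ])  # refit without row i
      predict(fit, newdata = df[i, ])            # predict the held-out row
    })
  }
  df
}

# group_modify() passes each country's rows (without the grouping columns)
# to the helpers, so no model ever sees another country's data.
gapminder_preds <- gapminder %>%
  group_by(continent, country) %>%
  group_modify(~ loo_by_group(fit_by_group(.x))) %>%
  ungroup()

gapminder_preds
```

Here fit_model1 and fit_model2 correspond to Case 2 (in-sample fitted values per country), while loo_fit_model1 and loo_fit_model2 correspond to Case 1 (each row predicted by its country's model refit without that row). With 12 observations per country, each leave-one-out fit uses 11 rows.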
