簡體   English   中英

在R Tidyverse的函數中編碼多個模型

[英]Coding Multiple Models in a Function in R Tidyverse

我正在嘗試使用一些公式擬合幾個機器學習模型,並將它們作為list_column對象以小標題形式存儲。

我試圖修改“ R for Datascience”(第25章:許多模型)一書中引用的代碼,但它只給我最后的輸出。 請參閱下面的代碼以獲取更多詳細信息。 我們以gapminder包中的gapminder數據集為例。

lab_formula <- as.formula("pop ~ lifeExp ")

temp_formula <- as.formula("gdppercap ~ year")

formula_list <- list(lab_formula,temp_formula)
library(gapminder)

by_country <- gapminder %>% 
  dplyr :: group_by(country, continent) %>% 
  nest()

country_model <- function(df) {
for (i in formula_list) {
  lm(formula=formula[i], data = df)
  randomForest(formula=formula[i], data = df)
  gbm(formula=formula[i], data = df, n.minobsinnode = 2)
}
}

by_country <- by_country %>% 
  mutate(model = map(data, country_model))

by_country
# A tibble: 142 x 4
   country     continent data              model    
   <fct>       <fct>     <list>            <list>   
 1 Afghanistan Asia      <tibble [12 x 4]> <S3: gbm>
 2 Albania     Europe    <tibble [12 x 4]> <S3: gbm>
 3 Algeria     Africa    <tibble [12 x 4]> <S3: gbm>
 4 Angola      Africa    <tibble [12 x 4]> <S3: gbm>
 5 Argentina   Americas  <tibble [12 x 4]> <S3: gbm>
 6 Australia   Oceania   <tibble [12 x 4]> <S3: gbm>
 7 Austria     Europe    <tibble [12 x 4]> <S3: gbm>
 8 Bahrain     Asia      <tibble [12 x 4]> <S3: gbm>
 9 Bangladesh  Asia      <tibble [12 x 4]> <S3: gbm>
10 Belgium     Europe    <tibble [12 x 4]> <S3: gbm>
# ... with 132 more rows

There is no error code but it does not achieve my objective of training the 3 machine learning models (LM, RF, GBM) with the different variables.

您需要考慮如何存儲結果。 這是一種方法。 首先創建要應用的公式列表

library(randomForest)
library(gbm)
library(tidyverse)

lab_formula <- as.formula("pop ~ lifeExp ")
temp_formula <- as.formula("gdpPercap ~ year")
formula_list <- list(lab_formula,temp_formula)

創建一個函數,該函數一次返回僅應用於一個公式的模型列表。

country_model <- function(df, formula_list, index) {
    list(lm(formula = formula_list[[index]] , data = df), 
         randomForest(formula=formula_list[[index]], data = df),
         gbm(formula=formula_list[[index]], data = df, n.minobsinnode = 2))
}

現在它適用於每個data傳遞formula_list要應用到您的數據和公式數從列表中,

df1 <- by_country %>% 
  mutate(model1 = map(data, ~country_model(., formula_list, 1)), 
         model2 = map(data, ~country_model(., formula_list, 2)))
df1

# A tibble: 142 x 5
#   country     continent data              model1     model2    
#   <fct>       <fct>     <list>            <list>     <list>    
# 1 Afghanistan Asia      <tibble [12 × 4]> <list [3]> <list [3]>
# 2 Albania     Europe    <tibble [12 × 4]> <list [3]> <list [3]>
# 3 Algeria     Africa    <tibble [12 × 4]> <list [3]> <list [3]>
# 4 Angola      Africa    <tibble [12 × 4]> <list [3]> <list [3]>
# 5 Argentina   Americas  <tibble [12 × 4]> <list [3]> <list [3]>
# 6 Australia   Oceania   <tibble [12 × 4]> <list [3]> <list [3]>
# 7 Austria     Europe    <tibble [12 × 4]> <list [3]> <list [3]>
# 8 Bahrain     Asia      <tibble [12 × 4]> <list [3]> <list [3]>
# 9 Bangladesh  Asia      <tibble [12 × 4]> <list [3]> <list [3]>
#10 Belgium     Europe    <tibble [12 × 4]> <list [3]> <list [3]>
# … with 132 more rows

現在, model1中的每一行都有一個使用公式formula_list[[1]]的三個模型的列表,類似地,對於model2您也使用了公式formula_list[[2]]


為了使用這些模型進行預測,我們需要對randomForest模型進行不同的處理,因為它需要n.trees參數,並且當我們從函數中返回這些模型時,我們知道它是列表中的第三個模型,我們可以根據索引對其進行區分。

df1 %>%
   mutate(pred= map2(data,model1, function(x, y) 
     map(seq_along(y), function(i) 
        if (i == 3) predict(y[[i]], n.trees = y[[i]]$n.trees)
        else as.numeric(predict(y[[i]])))))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM