簡體   English   中英

使用 tidymodels 預測 GAM model 時出錯

[英]Error while predicting a GAM model using tidymodels

我想要什麼:我正在嘗試使用 tidymodels 在給定的數據上擬合 GAM 模型以進行分類。

到目前為止,我已能夠擬合 logit 模型。

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip


# `df` is the data frame given by the dput() snippet in the data section
# further down; run that snippet first.
# Seed the RNG so the stratified split (and therefore the predictions
# printed below) can be reproduced. The shown output came from an
# unseeded run, so individual rows may differ.
set.seed(123)
df_split <- initial_split(df, prop = 0.75, strata = class)
df_train <- training(df_split)
df_test <- testing(df_split)

# Baseline: logistic regression (glm engine) fits and predicts fine.
log_model <- logistic_reg(mode = "classification",
                          engine = "glm") %>%
  fit(class ~ duration, data = df_train)

predict(log_model, df_test)
#> # A tibble: 26 × 1
#>    .pred_class
#>    <fct>      
#>  1 good       
#>  2 good       
#>  3 good       
#>  4 bad        
#>  5 good       
#>  6 good       
#>  7 bad        
#>  8 bad        
#>  9 good       
#> 10 bad        
#> # … with 16 more rows

我的問題:令人驚訝的是,當我改用 GAM 時,在預測(predict)這一步出現錯誤。

# Same training data, but with a GAM specification (mgcv engine).
# Under parsnip 0.1.7 the fit itself runs; it is the predict() call
# that fails with the error shown below.
gen_model <- fit(
  gen_additive_mod(mode = "classification", engine = "mgcv"),
  class ~ duration,
  data = df_train
)

predict(gen_model, df_test)
#> Error: $ operator is invalid for atomic vectors

數據:以下是 `df` 的 `dput()` 輸出:

# Reproducible data (dput() output): a 100-row tibble with
#   - class:    factor with levels "bad" and "good" (the outcome)
#   - duration: numeric predictor
df <- structure(list(class = structure(c(2L, 1L, 2L, 2L, 2L, 2L, 2L, 
                                         2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 1L, 2L, 
                                         1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 2L, 2L, 1L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 
                                         2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L, 
                                         2L, 1L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 1L), 
                                       .Label = c("bad", 
                                                  "good"), class = "factor"), 
                     duration = c(42, 31.7869911119342, 
                                  18, 24, 12, 18, 10, 9, 12, 24, 10, 27, 14.4910072591156, 12, 
                                  48, 24, 30, 18, 6, 6, 12, 48, 10, 18, 6, 12, 24.4157173759304, 
                                  18, 48, 60, 18, 15, 9, 60, 24, 24, 9, 21, 26.4959116294049, 12, 
                                  5, 12, 12, 48, 18, 48, 12, 17.4877766738646, 36, 9, 15, 39.2811119947582, 
                                  27, 21, 24, 10, 6, 12, 12, 24, 39, 18, 24, 15, 48, 12, 24, 26.7659258879721, 
                                  36, 24, 27, 9, 12, 48, 28, 21, 6, 24, 24, 24, 18, 36, 36, 30, 
                                  8.19771710922942, 36, 18, 12, 13.8241796996444, 26.0928970947862, 
                                  10, 36, 12, 12, 24, 21.3157193372026, 18, 21, 24, 24)), 
                class = c("tbl_df", 
                          "tbl", "data.frame"), 
                row.names = c(NA, -100L))

由 reprex 套件 (v2.0.1) 於 2022 年 1 月 12 日建立

Session 資訊
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.1.2 (2021-11-01)
#> os macOS Big Sur 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate es_ES.UTF-8
#> ctype es_ES.UTF-8
#> tz Europe/Madrid
#> date 2022-01-12
#> pandoc 2.14.0.3 @ /Applications/RStudio.app/Contents/MacOS/pandoc/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.1.0)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.1.0)
#> broom * 0.7.11 2022-01-03 [1] CRAN (R 4.1.2)
#> class 7.3-19 2021-05-03 [1] CRAN (R 4.1.2)
#> cli 3.1.0 2021-10-27 [1] CRAN (R 4.1.0)
#> codetools 0.2-18 2020-11-04 [1] CRAN (R 4.1.2)
#> colorspace 2.0-2 2021-06-24 [1] CRAN (R 4.1.0)
#> crayon 1.4.2 2021-10-29 [1] CRAN (R 4.1.0)
#> DBI 1.1.2 2021-12-20 [1] CRAN (R 4.1.2)
#> dials * 0.0.10 2021-09-10 [1] CRAN (R 4.1.0)
#> DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.1.0)
#> digest 0.6.29 2021-12-01 [1] CRAN (R 4.1.0)
#> dplyr * 1.0.7 2021-06-18 [1] CRAN (R 4.1.0)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.1.0)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 4.1.0)
#> fansi 1.0.0 2022-01-10 [1] CRAN (R 4.1.2)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.1.0)
#> foreach 1.5.1 2020-10-15 [1] CRAN (R 4.1.0)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.1.0)
#> furrr 0.2.3 2021-06-25 [1] CRAN (R 4.1.0)
#> future 1.23.0 2021-10-31 [1] CRAN (R 4.1.0)
#> future.apply 1.8.1 2021-08-10 [1] CRAN (R 4.1.0)
#> generics 0.1.1 2021-10-25 [1] CRAN (R 4.1.0)
#> ggplot2 * 3.3.5 2021-06-25 [1] CRAN (R 4.1.0)
#> globals 0.14.0 2020-11-22 [1] CRAN (R 4.1.0)
#> glue 1.6.0 2021-12-17 [1] CRAN (R 4.1.0)
#> gower 0.2.2 2020-06-23 [1] CRAN (R 4.1.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.1.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.1.0)
#> hardhat 0.1.6 2021-07-14 [1] CRAN (R 4.1.0)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.1.0)
#> htmltools 0.5.2 2021-08-25 [1] CRAN (R 4.1.0)
#> infer * 1.0.0 2021-08-13 [1] CRAN (R 4.1.0)
#> ipred 0.9-12 2021-09-15 [1] CRAN (R 4.1.0)
#> iterators 1.0.13 2020-10-15 [1] CRAN (R 4.1.0)
#> knitr 1.37 2021-12-16 [1] CRAN (R 4.1.0)
#> lattice 0.20-45 2021-09-22 [1] CRAN (R 4.1.2)
#> lava 1.6.10 2021-09-02 [1] CRAN (R 4.1.0)
#> lhs 1.1.3 2021-09-08 [1] CRAN (R 4.1.0)
#> lifecycle 1.0.1 2021-09-24 [1] CRAN (R 4.1.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.1.0)
#> lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.1.0)
#> magrittr 2.0.1 2020-11-17 [1] CRAN (R 4.1.0)
#> MASS 7.3-54 2021-05-03 [1] CRAN (R 4.1.2)
#> Matrix 1.4-0 2021-12-08 [1] CRAN (R 4.1.0)
#> mgcv 1.8-38 2021-10-06 [1] CRAN (R 4.1.2)
#> modeldata * 0.1.1 2021-07-14 [1] CRAN (R 4.1.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.1.0)
#> nlme 3.1-153 2021-09-07 [1] CRAN (R 4.1.2)
#> nnet 7.3-16 2021-05-03 [1] CRAN (R 4.1.2)
#> parallelly 1.30.0 2021-12-17 [1] CRAN (R 4.1.0)
#> parsnip * 0.1.7 2021-07-21 [1] CRAN (R 4.1.0)
#> pillar 1.6.4 2021-10-18 [1] CRAN (R 4.1.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.1.0)
#> plyr 1.8.6 2020-03-03 [1] CRAN (R 4.1.0)
#> pROC 1.18.0 2021-09-03 [1] CRAN (R 4.1.0)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.1.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.1.0)
#> Rcpp 1.0.7 2021-07-07 [1] CRAN (R 4.1.0)
#> recipes * 0.1.17 2021-09-27 [1] CRAN (R 4.1.0)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 4.1.0)
#> rlang 0.4.12 2021-10-18 [1] CRAN (R 4.1.0)
#> rmarkdown 2.11 2021-09-14 [1] CRAN (R 4.1.0)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 4.1.2)
#> rsample * 0.1.1 2021-11-08 [1] CRAN (R 4.1.0)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.1.0)
#> scales * 1.1.1 2020-05-11 [1] CRAN (R 4.1.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.1.0)
#> stringi 1.7.6 2021-11-29 [1] CRAN (R 4.1.0)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.1.0)
#> survival 3.2-13 2021-08-24 [1] CRAN (R 4.1.2)
#> tibble * 3.1.6 2021-11-07 [1] CRAN (R 4.1.0)
#> tidymodels * 0.1.4.9000 2022-01-12 [1] Github (tidymodels/tidymodels@8486957)
#> tidyr * 1.1.4 2021-09-27 [1] CRAN (R 4.1.0)
#> tidyselect 1.1.1 2021-04-30 [1] CRAN (R 4.1.0)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.1.0)
#> tune * 0.1.6 2021-07-21 [1] CRAN (R 4.1.0)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.1.0)
#> vctrs 0.3.8 2021-04-29 [1] CRAN (R 4.1.0)
#> withr 2.4.3 2021-11-30 [1] CRAN (R 4.1.0)
#> workflows * 0.2.4 2021-10-12 [1] CRAN (R 4.1.0)
#> workflowsets * 0.1.0 2021-07-22 [1] CRAN (R 4.1.0)
#> xfun 0.29 2021-12-14 [1] CRAN (R 4.1.0)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.1.0)
#> yardstick * 0.0.9 2021-11-22 [1] CRAN (R 4.1.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────

此問題已在 {parsnip} 的開發版本 (> 0.1.7) 中修復。您可以執行 `remotes::install_github("tidymodels/parsnip")` 來安裝它。

library(parsnip)
library(rsample)

# Same data as in the question (dput() output): a 100-row tibble with
#   - class:    factor with levels "bad" and "good" (the outcome)
#   - duration: numeric predictor
df <- structure(list(class = structure(c(2L, 1L, 2L, 2L, 2L, 2L, 2L, 
                                         2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 1L, 2L, 
                                         1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 2L, 2L, 1L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 
                                         2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L, 
                                         2L, 1L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 1L), 
                                       .Label = c("bad", 
                                                  "good"), class = "factor"), 
                     duration = c(42, 31.7869911119342, 
                                  18, 24, 12, 18, 10, 9, 12, 24, 10, 27, 14.4910072591156, 12, 
                                  48, 24, 30, 18, 6, 6, 12, 48, 10, 18, 6, 12, 24.4157173759304, 
                                  18, 48, 60, 18, 15, 9, 60, 24, 24, 9, 21, 26.4959116294049, 12, 
                                  5, 12, 12, 48, 18, 48, 12, 17.4877766738646, 36, 9, 15, 39.2811119947582, 
                                  27, 21, 24, 10, 6, 12, 12, 24, 39, 18, 24, 15, 48, 12, 24, 26.7659258879721, 
                                  36, 24, 27, 9, 12, 48, 28, 21, 6, 24, 24, 24, 18, 36, 36, 30, 
                                  8.19771710922942, 36, 18, 12, 13.8241796996444, 26.0928970947862, 
                                  10, 36, 12, 12, 24, 21.3157193372026, 18, 21, 24, 24)), 
                class = c("tbl_df", 
                          "tbl", "data.frame"), 
                row.names = c(NA, -100L))

# Seed the RNG so the stratified split (and therefore the predictions
# printed below) can be reproduced. The shown output came from an
# unseeded run, so individual rows may differ.
set.seed(123)
df_split <- initial_split(df, prop = 0.75, strata = class)
df_train <- training(df_split)
df_test <- testing(df_split)

# With parsnip > 0.1.7, predict() on a classification GAM (mgcv engine)
# returns class predictions instead of failing.
# Note: `class ~ duration` contains no smooth term, so mgcv fits
# `duration` as an ordinary linear (parametric) term. Use
# `class ~ s(duration)` if a smooth, non-linear effect is intended.
gen_model <- gen_additive_mod(mode = "classification",
                              engine = "mgcv") %>%
  fit(class ~ duration, data = df_train)

predict(gen_model, df_test)
#> # A tibble: 26 × 1
#>    .pred_class
#>    <fct>      
#>  1 bad        
#>  2 good       
#>  3 good       
#>  4 good       
#>  5 bad        
#>  6 good       
#>  7 good       
#>  8 good       
#>  9 good       
#> 10 good       
#> # … with 16 more rows

由 reprex 套件 (v2.0.1) 於 2022 年 1 月 12 日建立

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM