Get AUC on training data from a fitted workflow in Tidymodels?
I'm struggling to figure out how to get the AUC from a logistic regression model with tidymodels.
Here is an example using the built-in mpg dataset:
library(tidymodels)
library(tidyverse)
# Use mpg dataset
df <- mpg
# Create an indicator variable for class="suv"
df$is_suv <- as.factor(df$class == "suv")
# Create the split object
df_split <- initial_split(df, prop=1/2)
# Create the training and testing sets
df_train <- training(df_split)
df_test <- testing(df_split)
# Create workflow
rec <-
  recipe(is_suv ~ cty + hwy + cyl, data=df_train)
glm_spec <-
  logistic_reg() %>%
  set_engine(engine = "glm")
glm_wflow <-
  workflow() %>%
  add_recipe(rec) %>%
  add_model(glm_spec)
# Fit the model
model1 <- fit(glm_wflow, df_train)
# Attach predictions to training dataset
training_results <- bind_cols(df_train, predict(model1, df_train))
# Calculate accuracy
accuracy(training_results, truth = is_suv, estimate = .pred_class)
# Calculate AUC??
roc_auc(training_results, truth = is_suv, estimate = .pred_class)
The last line returns this error:
> roc_auc(training_results, truth = is_suv, estimate = .pred_class)
Error in metric_summarizer(metric_nm = "roc_auc", metric_fn = roc_auc_vec, :
formal argument "estimate" matched by multiple actual arguments
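Since you are doing binary classification, roc_auc() expects a vector of class probabilities for the "relevant" class, not the predicted class. You can get these with predict(model1, df_train, type = "prob"). Alternatively, if you are using workflows 0.2.2 or later, you can use augment() to get both the class predictions and the probabilities without bind_cols().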
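If you are on a workflows version older than 0.2.2 (so augment() is not available), a minimal sketch of the predict(type = "prob") route might look like the following. The data preparation mirrors the question; set.seed(123) is an arbitrary addition for reproducibility, and the probability column is passed to roc_auc() unnamed, which is how yardstick's `...` argument selects it.

```r
library(tidymodels)

set.seed(123)  # arbitrary seed, added only so the split is reproducible
df <- ggplot2::mpg
df$is_suv <- as.factor(df$class == "suv")  # levels: "FALSE", "TRUE"

df_split <- initial_split(df, prop = 1/2)
df_train <- training(df_split)

glm_wflow <- workflow() %>%
  add_recipe(recipe(is_suv ~ cty + hwy + cyl, data = df_train)) %>%
  add_model(logistic_reg() %>% set_engine("glm"))

model1 <- fit(glm_wflow, df_train)

# Bind the observed data, the hard class predictions and the class
# probabilities column-wise.
training_results <- bind_cols(
  df_train,
  predict(model1, df_train),                # adds .pred_class
  predict(model1, df_train, type = "prob")  # adds .pred_FALSE and .pred_TRUE
)

# Accuracy uses the hard class predictions...
acc_tbl <- accuracy(training_results, truth = is_suv, estimate = .pred_class)

# ...while AUC uses a probability column, passed unnamed through `...`.
auc_tbl <- roc_auc(training_results, truth = is_suv, .pred_FALSE)
auc_tbl
```

Because no seed was set in the original code, the exact numbers you get will differ from run to run.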
library(tidymodels)
library(tidyverse)
# Use mpg dataset
df <- mpg
# Create an indicator variable for class="suv"
df$is_suv <- as.factor(df$class == "suv")
# Create the split object
df_split <- initial_split(df, prop=1/2)
# Create the training and testing sets
df_train <- training(df_split)
df_test <- testing(df_split)
# Create workflow
rec <-
  recipe(is_suv ~ cty + hwy + cyl, data=df_train)
glm_spec <-
  logistic_reg() %>%
  set_engine(engine = "glm")
glm_wflow <-
  workflow() %>%
  add_recipe(rec) %>%
  add_model(glm_spec)
# Fit the model
model1 <- fit(glm_wflow, df_train)
# Attach predictions to training dataset
training_results <- augment(model1, df_train)
# Calculate accuracy
accuracy(training_results, truth = is_suv, estimate = .pred_class)
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy binary 0.795
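# Calculate AUC: pass the probability column unnamed (it is picked up by `...`);
# naming it `estimate =` is what triggers the "matched by multiple actual arguments" error
roc_auc(training_results, truth = is_suv, .pred_FALSE)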
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.879
Created on 2021-04-12 by the reprex package (v1.0.0)
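The choice of .pred_FALSE follows from the factor levels: as.factor() on a logical vector orders the levels as "FALSE", "TRUE", and yardstick treats the first level as the event by default. If SUVs (TRUE) should be the event instead, a sketch like the one below, assuming yardstick 0.0.7 or later for the event_level argument, uses .pred_TRUE with event_level = "second". For a binary outcome the AUC comes out the same. The seed is again an arbitrary addition.

```r
library(tidymodels)

set.seed(123)  # arbitrary seed, only for reproducibility
df <- ggplot2::mpg
df$is_suv <- as.factor(df$class == "suv")
df_train <- training(initial_split(df, prop = 1/2))

model1 <- workflow() %>%
  add_recipe(recipe(is_suv ~ cty + hwy + cyl, data = df_train)) %>%
  add_model(logistic_reg() %>% set_engine("glm")) %>%
  fit(df_train)

training_results <- augment(model1, df_train)

# Default: the first level ("FALSE") is the event, so use .pred_FALSE.
auc_first <- roc_auc(training_results, truth = is_suv, .pred_FALSE)

# Alternative: make "TRUE" the event and use the matching column.
auc_second <- roc_auc(training_results, truth = is_suv, .pred_TRUE,
                      event_level = "second")

# Both describe the same ranking, so the AUC values agree.
all.equal(auc_first$.estimate, auc_second$.estimate)
```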