[英]How do we calculate multiclass probabilities in yardstick?
我有一个多类分类问题,想使用pr_curve
中的 yardstick 库中的 pr_curve 构建精确召回曲线。这个 function 需要将每个 class 的概率的 tibble 提供给它,就像这样(这是data(hpc_cv)
)。 我如何从我的分类结果中到达那里,存储为 tibble 中的列?
library(yardstick)
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")),
expected = as.factor(c("A", "B", "B", "C", "A", "C")))
data %>% conf_mat(truth = expected, estimate = predicted)
我还没有找到 function 的标准(或其他地方)来计算这些。
我不确定 class 概率是如何计算的,我在想这些:
data %>% filter(predicted == "A") %>% summarise(n = n() / 6)
这个对吗? 如果是这样,我想知道是否有一种很好的方法可以在每次折叠的每个 class 上不使用 for 循环,并在上图中接收像 hpc_cv 这样的 tibble。
我不确定 class 概率是如何计算的
Class 概率由特定的 model 为每个单独的数据点生成。
PR 曲线(以及精度和召回率)是结果有两个类别的数据集的指标。 不过,您可以进行多类平均以获得整体 PR 曲线 AUC。
下面有一个例子,但我建议在继续之前阅读一下tidymodels 书。
library(nnet) # <- for mutlinom_fit
library(tidymodels)
tidymodels_prefer()
data(hpc_data, package = "modeldata")
set.seed(1)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_test <- testing(hpc_split)
set.seed(2)
mutlinom_fit <-
multinom_reg() %>%
fit(class ~ iterations + compounds, data = hpc_train)
test_predictions <- augment(mutlinom_fit, new_data = hpc_test)
# examples of the hard class predictions and the
# predicted probabilities:
test_predictions %>% select(starts_with(".pred")) %>% head()
#> # A tibble: 6 × 5
#> .pred_class .pred_VF .pred_F .pred_M .pred_L
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 VF 0.641 0.279 0.0670 0.0128
#> 2 VF 0.640 0.280 0.0671 0.0128
#> 3 VF 0.628 0.287 0.0711 0.0138
#> 4 VF 0.628 0.287 0.0711 0.0138
#> 5 VF 0.626 0.288 0.0716 0.0139
#> 6 VF 0.626 0.288 0.0719 0.0140
# a confusion matrix
test_predictions %>% conf_mat(class, .pred_class)
#> Truth
#> Prediction VF F M L
#> VF 516 278 74 16
#> F 18 46 36 4
#> M 2 7 19 21
#> L 0 11 7 28
# create some metrics:
cls_metrics <- metric_set(accuracy, precision, recall, pr_auc)
# precision, recal, and the PR AUC are caluclated using macro weighting of 4
# different 1 vs all results.
# See https://yardstick.tidymodels.org/articles/multiclass.html
# evaluate them
test_predictions %>%
# See ?metric_set for more information. We pass the truth (class), all of the
# predicted probability columns (.pred_VF:.pred_L), and the named hard class
# predictions.
cls_metrics(class, .pred_VF:.pred_L, estimate = .pred_class)
#> # A tibble: 4 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.562
#> 2 precision macro 0.506
#> 3 recall macro 0.411
#> 4 pr_auc macro 0.481
创建于 2022-12-09,使用reprex v2.0.2
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.