[英]How do we calculate multiclass probabilities in yardstick?
我有一個多類分類問題,想使用pr_curve
中的 yardstick 庫中的 pr_curve 構建精確召回曲線。這個 function 需要將每個 class 的概率的 tibble 提供給它,就像這樣(這是data(hpc_cv)
)。 我如何從我的分類結果中到達那里,存儲為 tibble 中的列?
library(yardstick)
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")),
expected = as.factor(c("A", "B", "B", "C", "A", "C")))
data %>% conf_mat(truth = expected, estimate = predicted)
我還沒有找到 function 的標准(或其他地方)來計算這些。
我不確定 class 概率是如何計算的,我在想這些:
data %>% filter(predicted == "A") %>% summarise(n = n() / 6)
這個對嗎? 如果是這樣,我想知道是否有一種很好的方法可以在每次折疊的每個 class 上不使用 for 循環,並在上圖中接收像 hpc_cv 這樣的 tibble。
我不確定 class 概率是如何計算的
Class 概率由特定的 model 為每個單獨的數據點生成。
PR 曲線(以及精度和召回率)是結果有兩個類別的數據集的指標。 不過,您可以進行多類平均以獲得整體 PR 曲線 AUC。
下面有一個例子,但我建議在繼續之前閱讀一下tidymodels 書。
library(nnet) # <- for mutlinom_fit
library(tidymodels)
tidymodels_prefer()
data(hpc_data, package = "modeldata")
set.seed(1)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_test <- testing(hpc_split)
set.seed(2)
mutlinom_fit <-
multinom_reg() %>%
fit(class ~ iterations + compounds, data = hpc_train)
test_predictions <- augment(mutlinom_fit, new_data = hpc_test)
# examples of the hard class predictions and the
# predicted probabilities:
test_predictions %>% select(starts_with(".pred")) %>% head()
#> # A tibble: 6 × 5
#> .pred_class .pred_VF .pred_F .pred_M .pred_L
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 VF 0.641 0.279 0.0670 0.0128
#> 2 VF 0.640 0.280 0.0671 0.0128
#> 3 VF 0.628 0.287 0.0711 0.0138
#> 4 VF 0.628 0.287 0.0711 0.0138
#> 5 VF 0.626 0.288 0.0716 0.0139
#> 6 VF 0.626 0.288 0.0719 0.0140
# a confusion matrix
test_predictions %>% conf_mat(class, .pred_class)
#> Truth
#> Prediction VF F M L
#> VF 516 278 74 16
#> F 18 46 36 4
#> M 2 7 19 21
#> L 0 11 7 28
# create some metrics:
cls_metrics <- metric_set(accuracy, precision, recall, pr_auc)
# precision, recal, and the PR AUC are caluclated using macro weighting of 4
# different 1 vs all results.
# See https://yardstick.tidymodels.org/articles/multiclass.html
# evaluate them
test_predictions %>%
# See ?metric_set for more information. We pass the truth (class), all of the
# predicted probability columns (.pred_VF:.pred_L), and the named hard class
# predictions.
cls_metrics(class, .pred_VF:.pred_L, estimate = .pred_class)
#> # A tibble: 4 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.562
#> 2 precision macro 0.506
#> 3 recall macro 0.411
#> 4 pr_auc macro 0.481
創建於 2022-12-09,使用reprex v2.0.2
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.