簡體   English   中英

我們如何計算標准中的多類概率?

[英]How do we calculate multiclass probabilities in yardstick?

我有一個多類分類問題,想使用pr_curve中的 yardstick 庫中的 pr_curve 構建精確召回曲線。這個 function 需要將每個 class 的概率的 tibble 提供給它,就像這樣(這是data(hpc_cv) )。 在此處輸入圖像描述 我如何從我的分類結果中到達那里,存儲為 tibble 中的列?

library(yardstick)
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")), 
               expected = as.factor(c("A", "B", "B", "C", "A", "C")))
data %>% conf_mat(truth = expected, estimate = predicted)

我還沒有找到 function 的標准(或其他地方)來計算這些。

我不確定 class 概率是如何計算的,我在想這些:

data %>% filter(predicted == "A") %>% summarise(n = n() / 6)

這個對嗎? 如果是這樣,我想知道是否有一種很好的方法可以在每次折疊的每個 class 上不使用 for 循環,並在上圖中接收像 hpc_cv 這樣的 tibble。

我不確定 class 概率是如何計算的

Class 概率由特定的 model 為每個單獨的數據點生成。

PR 曲線(以及精度和召回率)是結果有兩個類別的數據集的指標。 不過,您可以進行多類平均以獲得整體 PR 曲線 AUC。

下面有一個例子,但我建議在繼續之前閱讀一下tidymodels 書

library(nnet) # <- for mutlinom_fit
library(tidymodels)

tidymodels_prefer()

data(hpc_data, package = "modeldata")

set.seed(1)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_test  <- testing(hpc_split)

set.seed(2)
mutlinom_fit <- 
  multinom_reg() %>% 
  fit(class ~ iterations + compounds, data = hpc_train)

test_predictions <- augment(mutlinom_fit, new_data = hpc_test)

# examples of the hard class predictions and the 
# predicted probabilities: 
test_predictions %>% select(starts_with(".pred")) %>% head()
#> # A tibble: 6 × 5
#>   .pred_class .pred_VF .pred_F .pred_M .pred_L
#>   <fct>          <dbl>   <dbl>   <dbl>   <dbl>
#> 1 VF             0.641   0.279  0.0670  0.0128
#> 2 VF             0.640   0.280  0.0671  0.0128
#> 3 VF             0.628   0.287  0.0711  0.0138
#> 4 VF             0.628   0.287  0.0711  0.0138
#> 5 VF             0.626   0.288  0.0716  0.0139
#> 6 VF             0.626   0.288  0.0719  0.0140

# a confusion matrix
test_predictions %>% conf_mat(class, .pred_class)
#>           Truth
#> Prediction  VF   F   M   L
#>         VF 516 278  74  16
#>         F   18  46  36   4
#>         M    2   7  19  21
#>         L    0  11   7  28

# create some metrics:
cls_metrics <- metric_set(accuracy, precision, recall, pr_auc)
# precision, recal, and the PR AUC are caluclated using macro weighting of 4 
# different 1 vs all results. 
# See https://yardstick.tidymodels.org/articles/multiclass.html

# evaluate them
test_predictions %>% 
  # See ?metric_set for more information. We pass the truth (class), all of the
  # predicted probability columns (.pred_VF:.pred_L), and the named hard class
  # predictions. 
  cls_metrics(class, .pred_VF:.pred_L, estimate = .pred_class)
#> # A tibble: 4 × 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 accuracy  multiclass     0.562
#> 2 precision macro          0.506
#> 3 recall    macro          0.411
#> 4 pr_auc    macro          0.481

創建於 2022-12-09,使用reprex v2.0.2

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM