繁体   English   中英

我们如何计算标准中的多类概率?

[英]How do we calculate multiclass probabilities in yardstick?

我有一个多类分类问题,想使用pr_curve中的 yardstick 库中的 pr_curve 构建精确召回曲线。这个 function 需要将每个 class 的概率的 tibble 提供给它,就像这样(这是data(hpc_cv) )。 在此处输入图像描述 我如何从我的分类结果中到达那里,存储为 tibble 中的列?

library(yardstick)
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")), 
               expected = as.factor(c("A", "B", "B", "C", "A", "C")))
data %>% conf_mat(truth = expected, estimate = predicted)

我还没有找到 function 的标准(或其他地方)来计算这些。

我不确定 class 概率是如何计算的,我在想这些:

data %>% filter(predicted == "A") %>% summarise(n = n() / 6)

这个对吗? 如果是这样,我想知道是否有一种很好的方法可以在每次折叠的每个 class 上不使用 for 循环,并在上图中接收像 hpc_cv 这样的 tibble。

我不确定 class 概率是如何计算的

Class 概率由特定的 model 为每个单独的数据点生成。

PR 曲线(以及精度和召回率)是结果有两个类别的数据集的指标。 不过,您可以进行多类平均以获得整体 PR 曲线 AUC。

下面有一个例子,但我建议在继续之前阅读一下tidymodels 书

library(nnet) # <- for mutlinom_fit
library(tidymodels)

tidymodels_prefer()

data(hpc_data, package = "modeldata")

set.seed(1)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_test  <- testing(hpc_split)

set.seed(2)
mutlinom_fit <- 
  multinom_reg() %>% 
  fit(class ~ iterations + compounds, data = hpc_train)

test_predictions <- augment(mutlinom_fit, new_data = hpc_test)

# examples of the hard class predictions and the 
# predicted probabilities: 
test_predictions %>% select(starts_with(".pred")) %>% head()
#> # A tibble: 6 × 5
#>   .pred_class .pred_VF .pred_F .pred_M .pred_L
#>   <fct>          <dbl>   <dbl>   <dbl>   <dbl>
#> 1 VF             0.641   0.279  0.0670  0.0128
#> 2 VF             0.640   0.280  0.0671  0.0128
#> 3 VF             0.628   0.287  0.0711  0.0138
#> 4 VF             0.628   0.287  0.0711  0.0138
#> 5 VF             0.626   0.288  0.0716  0.0139
#> 6 VF             0.626   0.288  0.0719  0.0140

# a confusion matrix
test_predictions %>% conf_mat(class, .pred_class)
#>           Truth
#> Prediction  VF   F   M   L
#>         VF 516 278  74  16
#>         F   18  46  36   4
#>         M    2   7  19  21
#>         L    0  11   7  28

# create some metrics:
cls_metrics <- metric_set(accuracy, precision, recall, pr_auc)
# precision, recal, and the PR AUC are caluclated using macro weighting of 4 
# different 1 vs all results. 
# See https://yardstick.tidymodels.org/articles/multiclass.html

# evaluate them
test_predictions %>% 
  # See ?metric_set for more information. We pass the truth (class), all of the
  # predicted probability columns (.pred_VF:.pred_L), and the named hard class
  # predictions. 
  cls_metrics(class, .pred_VF:.pred_L, estimate = .pred_class)
#> # A tibble: 4 × 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 accuracy  multiclass     0.562
#> 2 precision macro          0.506
#> 3 recall    macro          0.411
#> 4 pr_auc    macro          0.481

创建于 2022-12-09,使用reprex v2.0.2

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM