
Is there a quicker way to make this confusion matrix table in R?

I am trying to make a confusion matrix table in R, using the following data frame:

mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat", 
"dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat", 
"dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA, 
10L), class = "data.frame")

  pred_class true_class
1        dog        cat
2        dog        cat
3       fish        dog
4        cat        cat
5        cat        cat
6        dog        dog

I have written code that does what I want: for every class (dog, cat or fish), it labels each row as a true positive, a false positive, a true negative or a false negative.

library(dplyr)

conf_mat <- mydf %>%
    mutate(
        dog_conf = case_when(
            true_class == "dog" &  pred_class == "dog" ~ "TP",
            true_class == "dog" &  pred_class != "dog" ~ "FN",
            true_class != "dog" &  pred_class == "dog" ~ "FP",
            true_class != "dog" &  pred_class != "dog" ~ "TN"
        ),
        cat_conf = case_when(
            true_class == "cat" &  pred_class == "cat" ~ "TP",
            true_class == "cat" &  pred_class != "cat" ~ "FN",
            true_class != "cat" &  pred_class == "cat" ~ "FP",
            true_class != "cat" &  pred_class != "cat" ~ "TN"
        ),
        fish_conf = case_when(
            true_class == "fish" &  pred_class == "fish" ~ "TP",
            true_class == "fish" &  pred_class != "fish" ~ "FN",
            true_class != "fish" &  pred_class == "fish" ~ "FP",
            true_class != "fish" &  pred_class != "fish" ~ "TN"
        )
    )

However, this code is very repetitive and bulky, and I'm not sure how to streamline it. Does anyone have any suggestions? Thank you.

Here is one option with map: loop over the unique class values in the dataset, create one column per class with transmute using the conditions from the OP's post, and bind those columns to the original data.

library(dplyr)
library(purrr)
library(stringr)

# Build one "<class>_conf" column per class, then bind them to the original data
map_dfc(unique(unlist(mydf)), ~
    mydf %>%
        transmute(!!str_c(.x, "_conf") := case_when(
            true_class == .x & pred_class == .x ~ "TP",
            true_class == .x & pred_class != .x ~ "FN",
            true_class != .x & pred_class == .x ~ "FP",
            true_class != .x & pred_class != .x ~ "TN"
        ))
) %>%
    bind_cols(mydf, .)

Output:

#     pred_class true_class dog_conf fish_conf cat_conf
#1         dog        cat       FP        TN       FN
#2         dog        cat       FP        TN       FN
#3        fish        dog       FN        FP       TN
#4         cat        cat       TN        TN       TP
#5         cat        cat       TN        TN       TP
#6         dog        dog       TP        TN       TN
#7        fish        dog       FN        FP       TN
#8         cat        cat       TN        TN       TP
#9         dog        dog       TP        TN       TN
#10       fish       fish       TN        TP       TN
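
The same one-vs-rest labelling can also be sketched in base R, without dplyr or purrr. This is only an illustration: the helper name label_one_vs_rest and the object names classes, conf_cols and labelled are made up for this example and do not appear in the original answer.

```r
# Sample data copied from the question
mydf <- data.frame(
  pred_class = c("dog", "dog", "fish", "cat", "cat",
                 "dog", "fish", "cat", "dog", "fish"),
  true_class = c("cat", "cat", "dog", "cat", "cat",
                 "dog", "dog", "cat", "dog", "fish"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: label every row as TP/FN/FP/TN for a single class,
# treating that class as "positive" and all other classes as "negative"
label_one_vs_rest <- function(pred, truth, cls) {
  ifelse(truth == cls,
         ifelse(pred == cls, "TP", "FN"),
         ifelse(pred == cls, "FP", "TN"))
}

# One label vector per class, named "<class>_conf"
classes <- unique(unlist(mydf))
conf_cols <- lapply(classes, function(cls) {
  label_one_vs_rest(mydf$pred_class, mydf$true_class, cls)
})
names(conf_cols) <- paste0(classes, "_conf")

labelled <- cbind(mydf, data.frame(conf_cols, stringsAsFactors = FALSE))
labelled
```

Because the conditions live in one function, adding a new class requires no extra code; the loop picks it up from the data.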

Or use merge with a key-value lookup dataset:

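# Lookup table: for a given class, is it the predicted class / the true class?
keydat <- data.frame(pred_class = c(TRUE, TRUE, FALSE, FALSE),
                     true_class = c(TRUE, FALSE, TRUE, FALSE),
                     conf = c("TP", "FP", "FN", "TN"))

un1 <- unique(unlist(mydf))
mydf[paste0(un1, "_conf")] <- lapply(un1, function(x) {
    # merge() reorders rows, so carry a row id and restore the original order
    flags <- data.frame(mydf[c("pred_class", "true_class")] == x,
                        id = seq_len(nrow(mydf)))
    out <- merge(flags, keydat, all.x = TRUE)
    out$conf[order(out$id)]
})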
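
The key-value idea can also be sketched without merge, using a named character vector as the lookup table. Indexing a named vector keeps the rows in their original order, which is convenient when the result is assigned straight back into the data frame. The names conf_key, classes and key are illustrative assumptions, not part of the original answer.

```r
# Sample data copied from the question
mydf <- data.frame(
  pred_class = c("dog", "dog", "fish", "cat", "cat",
                 "dog", "fish", "cat", "dog", "fish"),
  true_class = c("cat", "cat", "dog", "cat", "cat",
                 "dog", "dog", "cat", "dog", "fish"),
  stringsAsFactors = FALSE
)

# Lookup keyed on "<prediction is this class>.<truth is this class>"
conf_key <- c("TRUE.TRUE"   = "TP",
              "TRUE.FALSE"  = "FP",
              "FALSE.TRUE"  = "FN",
              "FALSE.FALSE" = "TN")

# Collect the classes once, before any new columns are added
classes <- unique(unlist(mydf[c("pred_class", "true_class")]))

for (cls in classes) {
  # Build one key per row, then translate keys to labels by name
  key <- paste(mydf$pred_class == cls, mydf$true_class == cls, sep = ".")
  mydf[[paste0(cls, "_conf")]] <- unname(conf_key[key])
}

mydf
```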

In addition to the excellent answer by @akrun: if you want the status of each prediction (TP/TN/FP/FN) in order to calculate further statistics or metrics, many of these are provided by the caret package, e.g.:

library(caret)
mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat",
"dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat",
"dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA,
10L), class = "data.frame")

conf_matrix <- confusionMatrix(factor(mydf$pred_class),
                               reference = factor(mydf$true_class),
                               mode = "everything")
conf_matrix
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction cat dog fish
#>       cat    3   0    0
#>       dog    2   2    0
#>      fish    0   2    1
#> 
#> Overall Statistics
#>                                          
#>                Accuracy : 0.6             
#>                  95% CI : (0.2624, 0.8784)
#>     No Information Rate : 0.5             
#>     P-Value [Acc > NIR] : 0.377           
#>                                          
#>                   Kappa : 0.3939          
#>                                          
#>  Mcnemar's Test P-Value : NA              
#>
#> Statistics by Class:
#>
#>                      Class: cat Class: dog Class: fish
#> Sensitivity              0.6000     0.5000      1.0000
#> Specificity              1.0000     0.6667      0.7778
#> Pos Pred Value           1.0000     0.5000      0.3333
#> Neg Pred Value           0.7143     0.6667      1.0000
#> Precision                1.0000     0.5000      0.3333
#> Recall                   0.6000     0.5000      1.0000
#> F1                       0.7500     0.5000      0.5000
#> Prevalence               0.5000     0.4000      0.1000
#> Detection Rate           0.3000     0.2000      0.1000
#> Detection Prevalence     0.3000     0.4000      0.3000
#> Balanced Accuracy        0.8000     0.5833      0.8889

Further explanation:

For a 2x2 table with the notation:

            Reference   
Predicted   Event   No Event
Event           A        B
No Event        C        D

With A = TP, B = FP, C = FN and D = TN, the formulas used by the package/function are:

  • Sensitivity = A/(A+C)
  • Specificity = D/(B+D)
  • Prevalence = (A+C)/(A+B+C+D)
  • PPV = (sensitivity * prevalence)/((sensitivity * prevalence) + ((1-specificity) * (1-prevalence)))
  • NPV = (specificity * (1-prevalence))/(((1-sensitivity) * prevalence) + ((specificity) * (1-prevalence)))
  • Detection Rate = A/(A+B+C+D)
  • Detection Prevalence = (A+B)/(A+B+C+D)
  • Balanced Accuracy = (sensitivity+specificity)/2
  • Precision = A/(A+B)
  • Recall = A/(A+C)
  • F1 = (1+beta^2) * precision * recall/((beta^2 * precision)+recall)
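
As a worked check, the formulas above can be applied by hand in base R to the sample data, treating each class in turn as the "Event". This is a sketch under two assumptions: beta = 1 in the F1 formula (which is consistent with the F1 values printed above), and object names such as tab, A, B, C and D are chosen here only to mirror the 2x2 notation.

```r
# Sample data copied from the question
mydf <- data.frame(
  pred_class = c("dog", "dog", "fish", "cat", "cat",
                 "dog", "fish", "cat", "dog", "fish"),
  true_class = c("cat", "cat", "dog", "cat", "cat",
                 "dog", "dog", "cat", "dog", "fish"),
  stringsAsFactors = FALSE
)

# Cross-tabulate with shared factor levels so the table is always square
classes <- sort(unique(unlist(mydf)))
tab <- table(Prediction = factor(mydf$pred_class, levels = classes),
             Reference  = factor(mydf$true_class, levels = classes))

# Per-class counts in the 2x2 notation (one value per class)
A <- diag(tab)               # TP: predicted as the class and truly the class
B <- rowSums(tab) - A        # FP: predicted as the class, truly another class
C <- colSums(tab) - A        # FN: truly the class, predicted as another class
D <- sum(tab) - A - B - C    # TN: everything else

sensitivity <- A / (A + C)
specificity <- D / (B + D)
prevalence  <- (A + C) / (A + B + C + D)
ppv <- (sensitivity * prevalence) /
  ((sensitivity * prevalence) + ((1 - specificity) * (1 - prevalence)))
npv <- (specificity * (1 - prevalence)) /
  (((1 - sensitivity) * prevalence) + (specificity * (1 - prevalence)))
precision <- A / (A + B)
recall    <- A / (A + C)
beta <- 1                    # assumption: F1 is the F-measure with beta = 1
f1 <- (1 + beta^2) * precision * recall / ((beta^2 * precision) + recall)
balanced_accuracy <- (sensitivity + specificity) / 2

round(rbind(sensitivity, specificity, ppv, npv, f1, balanced_accuracy), 4)
```

The rounded values should agree with the "Statistics by Class" block printed by confusionMatrix above, for example a sensitivity of 0.6 for cat and a specificity of 0.6667 for dog.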
