
Is there a quicker way to make this confusion matrix table in R?

I am trying to make a confusion matrix table in R, using the following dataframe:

mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat", 
"dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat", 
"dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA, 
10L), class = "data.frame")

  pred_class true_class
1        dog        cat
2        dog        cat
3       fish        dog
4        cat        cat
5        cat        cat
6        dog        dog

I have working code that does what I want: for every class (dog, cat, or fish), it labels each row as a true positive (TP), false positive (FP), true negative (TN), or false negative (FN).

library(dplyr)

conf_mat <- mydf %>%
    mutate(
        dog_conf = case_when(
            true_class == "dog" &  pred_class == "dog" ~ "TP",
            true_class == "dog" &  pred_class != "dog" ~ "FN",
            true_class != "dog" &  pred_class == "dog" ~ "FP",
            true_class != "dog" &  pred_class != "dog" ~ "TN"
        ),
        cat_conf = case_when(
            true_class == "cat" &  pred_class == "cat" ~ "TP",
            true_class == "cat" &  pred_class != "cat" ~ "FN",
            true_class != "cat" &  pred_class == "cat" ~ "FP",
            true_class != "cat" &  pred_class != "cat" ~ "TN"
        ),
        fish_conf = case_when(
            true_class == "fish" &  pred_class == "fish" ~ "TP",
            true_class == "fish" &  pred_class != "fish" ~ "FN",
            true_class != "fish" &  pred_class == "fish" ~ "FP",
            true_class != "fish" &  pred_class != "fish" ~ "TN"
        )
    )

However, this code is very repetitive and bulky, and I'm not sure how to streamline it. Does anyone have any suggestions? Thank you.

Here is one option with map, where we loop over the unique elements of the dataset, create the columns with transmute inside the loop based on the conditions specified in the OP's post, and bind those columns to the original data:

library(dplyr)
library(purrr)
library(stringr)

map_dfc(unique(unlist(mydf)), ~
    mydf %>%
        transmute(!! str_c(.x, '_conf') :=
            case_when(
                true_class == .x & pred_class == .x ~ "TP",
                true_class == .x & pred_class != .x ~ "FN",
                true_class != .x & pred_class == .x ~ "FP",
                true_class != .x & pred_class != .x ~ "TN"
            ))) %>%
    bind_cols(mydf, .)

Output:

#     pred_class true_class dog_conf cat_conf fish_conf
#1         dog        cat       FP       FN        TN
#2         dog        cat       FP       FN        TN
#3        fish        dog       FN       TN        FP
#4         cat        cat       TN       TP        TN
#5         cat        cat       TN       TP        TN
#6         dog        dog       TP       TN        TN
#7        fish        dog       FN       TN        FP
#8         cat        cat       TN       TP        TN
#9         dog        dog       TP       TN        TN
#10       fish       fish       TN       TN        TP
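
If you then need per-class counts of each outcome, here is a small follow-up sketch (it assumes the result of the pipeline above has been assigned, e.g. as conf_mat):

# tally TP/FP/FN/TN for every *_conf column of the result above
conf_levels <- c("TP", "FP", "FN", "TN")
sapply(conf_mat[endsWith(names(conf_mat), "_conf")],
       function(x) table(factor(x, levels = conf_levels)))
#>    dog_conf cat_conf fish_conf
#> TP        2        3         1
#> FP        2        0         2
#> FN        2        2         0
#> TN        4        5         7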

Or use a lookup against a small key/value dataset. Note that the conf labels follow (pred, true) order, so (TRUE, FALSE) is FP, and a base-R match() is used here rather than merge(), since merge() would re-sort the rows and break the alignment with mydf:

keydat <- data.frame(pred_class = c(TRUE, TRUE, FALSE, FALSE),
                     true_class = c(TRUE, FALSE, TRUE, FALSE),
                     # (pred, true): TT = TP, TF = FP, FT = FN, FF = TN
                     conf = c("TP", "FP", "FN", "TN"))

un1 <- unique(unlist(mydf))
mydf[paste0(un1, "_conf")] <- lapply(un1, function(x) {
    m <- mydf == x   # logical matrix: does each cell equal class x?
    # look up each (pred, true) pair in keydat; match() preserves row order
    keydat$conf[match(paste(m[, "pred_class"], m[, "true_class"]),
                      paste(keydat$pred_class, keydat$true_class))]
})
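
As a quick sanity check (assuming the case_when() result from the question is still stored as conf_mat), the lookup reproduces the original columns:

# compare the lookup columns against the case_when() columns
all(mydf[paste0(un1, "_conf")] == conf_mat[paste0(un1, "_conf")])
#> [1] TRUE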

In addition to the excellent answer by @akrun, if you wish to identify the status of each prediction (TP/TN/FP/FN) in order to calculate additional statistics/metrics, many of these are provided by the caret package, e.g.:

library(caret)
mydf <- structure(list(
    pred_class = c("dog", "dog", "fish", "cat", "cat",
                   "dog", "fish", "cat", "dog", "fish"),
    true_class = c("cat", "cat", "dog", "cat", "cat",
                   "dog", "dog", "cat", "dog", "fish")),
    row.names = c(NA, 10L), class = "data.frame")

conf_matrix <- confusionMatrix(factor(mydf$pred_class),
                               reference = factor(mydf$true_class),
                               mode = "everything")
conf_matrix
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction cat dog fish
#>       cat    3   0    0
#>       dog    2   2    0
#>      fish    0   2    1
#> 
#> Overall Statistics
#>                                          
#>                Accuracy : 0.6             
#>                  95% CI : (0.2624, 0.8784)
#>     No Information Rate : 0.5             
#>     P-Value [Acc > NIR] : 0.377           
#>                                          
#>                   Kappa : 0.3939          
#>                                          
#>  Mcnemar's Test P-Value : NA              
#>
#> Statistics by Class:
#>
#>                      Class: cat Class: dog Class: fish
#> Sensitivity              0.6000     0.5000      1.0000
#> Specificity              1.0000     0.6667      0.7778
#> Pos Pred Value           1.0000     0.5000      0.3333
#> Neg Pred Value           0.7143     0.6667      1.0000
#> Precision                1.0000     0.5000      0.3333
#> Recall                   0.6000     0.5000      1.0000
#> F1                       0.7500     0.5000      0.5000
#> Prevalence               0.5000     0.4000      0.1000
#> Detection Rate           0.3000     0.2000      0.1000
#> Detection Prevalence     0.3000     0.4000      0.3000
#> Balanced Accuracy        0.8000     0.5833      0.8889

Further explanation:

For a 2x2 table with notation

            Reference   
Predicted   Event   No Event
Event           A        B
No Event        C        D

When "A" = TP, "B" = FP, "C" = FN, "D" = TN, the formulas used by the package/function are:

  • Sensitivity = A/(A+C)
  • Specificity = D/(B+D)
  • Prevalence = (A+C)/(A+B+C+D)
  • PPV = (sensitivity * prevalence)/((sensitivity * prevalence) + ((1-specificity) * (1-prevalence)))
  • NPV = (specificity * (1-prevalence))/(((1-sensitivity) * prevalence) + ((specificity) * (1-prevalence)))
  • Detection Rate = A/(A+B+C+D)
  • Detection Prevalence = (A+B)/(A+B+C+D)
  • Balanced Accuracy = (sensitivity+specificity)/2
  • Precision = A/(A+B)
  • Recall = A/(A+C)
  • F1 = (1+beta^2) * precision * recall/((beta^2 * precision)+recall)
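
As a quick check of these formulas, here is a small sketch that plugs in the counts for Class: dog from the confusion matrix above (A = 2, B = 2, C = 2, D = 4) and reproduces caret's per-class column:

# 2x2 counts for "dog", read off the confusion matrix above
A <- 2; B <- 2; C <- 2; D <- 4   # TP, FP, FN, TN
A / (A + C)                      # Sensitivity / Recall: 0.5000
D / (B + D)                      # Specificity: 0.6667
A / (A + B)                      # Precision / Pos Pred Value: 0.5000
(A + C) / (A + B + C + D)        # Prevalence: 0.4000
A / (A + B + C + D)              # Detection Rate: 0.2000
(A + B) / (A + B + C + D)        # Detection Prevalence: 0.4000
(A/(A + C) + D/(B + D)) / 2      # Balanced Accuracy: 0.5833
2 * (A/(A + B)) * (A/(A + C)) / (A/(A + B) + A/(A + C))  # F1 (beta = 1): 0.5000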
