I am trying to make a confusion matrix table in R, using the following data frame:
mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat",
"dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat",
"dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA,
10L), class = "data.frame")
pred_class true_class
1 dog cat
2 dog cat
3 fish dog
4 cat cat
5 cat cat
6 dog dog
I have written code that does what I want: for every class (dog, cat or fish), it labels each row as a true positive, a false positive, a true negative or a false negative.
library(dplyr)

# One case_when() per class: that class is treated as "positive", every other class as "negative"
conf_mat <- mydf %>%
  mutate(
    dog_conf = case_when(
      true_class == "dog" & pred_class == "dog" ~ "TP",
      true_class == "dog" & pred_class != "dog" ~ "FN",
      true_class != "dog" & pred_class == "dog" ~ "FP",
      true_class != "dog" & pred_class != "dog" ~ "TN"
    ),
    cat_conf = case_when(
      true_class == "cat" & pred_class == "cat" ~ "TP",
      true_class == "cat" & pred_class != "cat" ~ "FN",
      true_class != "cat" & pred_class == "cat" ~ "FP",
      true_class != "cat" & pred_class != "cat" ~ "TN"
    ),
    fish_conf = case_when(
      true_class == "fish" & pred_class == "fish" ~ "TP",
      true_class == "fish" & pred_class != "fish" ~ "FN",
      true_class != "fish" & pred_class == "fish" ~ "FP",
      true_class != "fish" & pred_class != "fish" ~ "TN"
    )
  )
However, this code is very repetitive and bulky, and I'm not sure how to streamline it. Does anyone have any suggestions? Thank you.
Here is one option with map(): we loop over the unique classes in the dataset, create one column per class with transmute() using the conditions from the OP's post, and bind those columns to the original data:
library(dplyr)
library(purrr)
library(stringr)

# For each class, build a "<class>_conf" column; map_dfc() column-binds them
# (map_dfc() is superseded as of purrr 1.0.0 but still works)
map_dfc(unique(unlist(mydf)), ~
  mydf %>%
    transmute(!! str_c(.x, '_conf') :=
      case_when(true_class == .x & pred_class == .x ~ "TP",
                true_class == .x & pred_class != .x ~ "FN",
                true_class != .x & pred_class == .x ~ "FP",
                true_class != .x & pred_class != .x ~ "TN"))) %>%
  bind_cols(mydf, .)
Output (the new columns follow the order of unique(unlist(mydf)), i.e. dog, fish, cat):
#    pred_class true_class dog_conf fish_conf cat_conf
#1          dog        cat       FP        TN       FN
#2          dog        cat       FP        TN       FN
#3         fish        dog       FN        FP       TN
#4          cat        cat       TN        TN       TP
#5          cat        cat       TN        TN       TP
#6          dog        dog       TP        TN       TN
#7         fish        dog       FN        FP       TN
#8          cat        cat       TN        TN       TP
#9          dog        dog       TP        TN       TN
#10        fish       fish       TN        TP       TN
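A dependency-free variant of the same idea is to wrap the one-vs-rest rule in a small base R function and apply it to each class with lapply(). This is only a sketch: conf_status() is a hypothetical helper name, not part of any package, and the class vector is written out by hand so the new columns come out as dog, cat, fish.

```r
# Same data as in the question
mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat",
  "dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat",
  "dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA,
  10L), class = "data.frame")

# Hypothetical helper: label each row for one class `cls`, treating `cls`
# as the positive class and every other class as negative
conf_status <- function(pred, truth, cls) {
  is_pred <- pred == cls
  is_true <- truth == cls
  out <- rep("TN", length(pred))    # default: neither predicted nor truly `cls`
  out[is_true & is_pred]  <- "TP"
  out[is_true & !is_pred] <- "FN"
  out[!is_true & is_pred] <- "FP"
  out
}

# Listing the classes explicitly fixes the order of the new columns
classes <- c("dog", "cat", "fish")
mydf[paste0(classes, "_conf")] <- lapply(classes, function(cls)
  conf_status(mydf$pred_class, mydf$true_class, cls))
mydf
```

Adding a class only means adding it to `classes`; the labelling rule lives in one place.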
Or, in base R, using merge() with a key/value lookup table:
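# Lookup table: is the row predicted as the class / truly the class?
keydat <- data.frame(pred_class = c(TRUE, TRUE, FALSE, FALSE),
                     true_class = c(TRUE, FALSE, TRUE, FALSE),
                     conf = c("TP", "FP", "FN", "TN"))
un1 <- unique(unlist(mydf))
mydf[paste0(un1, "_conf")] <- lapply(un1, function(x) {
  # merge() sorts by the key columns, so carry a row id to restore the original order
  m <- data.frame(mydf[c("pred_class", "true_class")] == x, id = seq_len(nrow(mydf)))
  res <- merge(m, keydat, all.x = TRUE)
  res$conf[order(res$id)]
})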
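merge() sorts its result by the key columns by default (`sort = TRUE`), so a join-based lookup has to put the rows back in their original order. A sketch that keeps the key/value idea but skips the join is to index a 2x2 lookup matrix directly with a two-column index matrix; `keymat` is a hypothetical name used only for illustration.

```r
# Same data as in the question
mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat",
  "dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat",
  "dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA,
  10L), class = "data.frame")

# Hypothetical 2x2 lookup: rows = "predicted as the class?", columns = "truly the class?".
# matrix() fills column by column: [F,F] = TN, [T,F] = FP, [F,T] = FN, [T,T] = TP
keymat <- matrix(c("TN", "FP", "FN", "TP"), nrow = 2,
                 dimnames = list(pred = c("FALSE", "TRUE"),
                                 true = c("FALSE", "TRUE")))

un1 <- unique(unlist(mydf))
mydf[paste0(un1, "_conf")] <- lapply(un1, function(x) {
  # FALSE/TRUE + 1 gives index 1/2; each row of idx is one (row, column) pair,
  # so the looked-up labels stay in the original row order
  idx <- cbind(mydf$pred_class == x, mydf$true_class == x) + 1
  keymat[idx]
})
mydf
```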
In addition to the excellent answer by @akrun: if you want to identify the status of each prediction (TP/TN/FP/FN) in order to calculate additional statistics, many of these metrics are provided by the caret package, e.g.:
library(caret)
mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat",
"dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat",
"dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA,
10L), class = "data.frame")
conf_matrix <- confusionMatrix(factor(mydf$pred_class),
reference = factor(mydf$true_class),
mode = "everything")
conf_matrix
#> Confusion Matrix and Statistics
#>
#> Reference
#> Prediction cat dog fish
#> cat 3 0 0
#> dog 2 2 0
#> fish 0 2 1
#>
#> Overall Statistics
#>
#> Accuracy : 0.6
#> 95% CI : (0.2624, 0.8784)
#> No Information Rate : 0.5
#> P-Value [Acc > NIR] : 0.377
#>
#> Kappa : 0.3939
#>
#> Mcnemar's Test P-Value : NA
#>
#> Statistics by Class:
#>
#> Class: cat Class: dog Class: fish
#> Sensitivity 0.6000 0.5000 1.0000
#> Specificity 1.0000 0.6667 0.7778
#> Pos Pred Value 1.0000 0.5000 0.3333
#> Neg Pred Value 0.7143 0.6667 1.0000
#> Precision 1.0000 0.5000 0.3333
#> Recall 0.6000 0.5000 1.0000
#> F1 0.7500 0.5000 0.5000
#> Prevalence 0.5000 0.4000 0.1000
#> Detection Rate 0.3000 0.2000 0.1000
#> Detection Prevalence 0.3000 0.4000 0.3000
#> Balanced Accuracy 0.8000 0.5833 0.8889
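caret keeps its per-class numbers in `conf_matrix$byClass`. If you only need a few of them, or want to see where they come from, the one-vs-rest counts can be read straight off the cross table in base R. This sketch assumes rows are predictions and columns are the reference, as in caret's table, and uses shared factor levels so the table stays square even if a class is never predicted.

```r
# Same data as in the question
mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat",
  "dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat",
  "dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA,
  10L), class = "data.frame")

# Shared levels keep the table square (rows = predicted, columns = reference)
lv <- sort(unique(c(mydf$pred_class, mydf$true_class)))
tab <- table(Prediction = factor(mydf$pred_class, levels = lv),
             Reference  = factor(mydf$true_class, levels = lv))

# One-vs-rest counts for every class at once
TP <- diag(tab)
FP <- rowSums(tab) - TP   # predicted as the class, but truly another class
FN <- colSums(tab) - TP   # truly the class, but predicted as another class
TN <- sum(tab) - TP - FP - FN

sensitivity <- TP / (TP + FN)
specificity <- TN / (TN + FP)
precision   <- TP / (TP + FP)
f1          <- 2 * precision * sensitivity / (precision + sensitivity)

round(rbind(TP, FP, FN, TN, sensitivity, specificity, precision, f1), 4)
```

The sensitivity, specificity, precision and F1 rows should agree with caret's "Statistics by Class" output above.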
Further explanation:
For a 2x2 table with the notation

                   Reference
    Predicted    Event    No Event
    Event          A         B
    No Event       C         D

where A = TP, B = FP, C = FN and D = TN, the formulas used by confusionMatrix() are listed on caret's ?confusionMatrix help page.
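Written out in terms of A, B, C and D, the per-class statistics are the standard definitions below. caret's help page expresses the positive and negative predictive values through sensitivity, specificity and prevalence; when the prevalence is estimated from the same table, as it is here, those expressions simplify to the ratios shown. As a worked check for the cat class (A = 3, B = 0, C = 2, D = 5): sensitivity = 3/5 = 0.6 and negative predictive value = 5/7 ≈ 0.7143, matching the output above.

```latex
\begin{aligned}
\text{Sensitivity} = \text{Recall} &= \frac{A}{A + C} \\
\text{Specificity} &= \frac{D}{B + D} \\
\text{Pos Pred Value} = \text{Precision} &= \frac{A}{A + B} \\
\text{Neg Pred Value} &= \frac{D}{C + D} \\
\text{Prevalence} &= \frac{A + C}{A + B + C + D} \\
\text{Detection Rate} &= \frac{A}{A + B + C + D} \\
\text{Detection Prevalence} &= \frac{A + B}{A + B + C + D} \\
\text{Balanced Accuracy} &= \frac{\text{Sensitivity} + \text{Specificity}}{2} \\
F_1 &= \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
\end{aligned}
```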